tprism.torch_util
This module contains utilities for drawing PyTorch autograd graphs with Graphviz.
"""
This module contains utilities for drawing PyTorch autograd graphs with Graphviz.
"""

import torch
import torch.optim as optim
import torch.nn.functional as F
from graphviz import Digraph


def make_dot(var, params=None):
    """Produces a Graphviz representation of a PyTorch autograd graph.

    The graph is walked from ``var.grad_fn`` through each node's
    ``next_functions`` and ``saved_tensors``. Blue nodes are autograd nodes
    exposing a ``variable`` attribute (the Variables that require grad),
    orange nodes are Tensors saved for backward in torch.autograd.Function,
    and every other node is labelled with its autograd function's class name.

    Args:
        var: output Variable whose ``grad_fn`` is the root of the graph.
        params: optional dict of (name, Variable) used to label the blue
            nodes. A blue node whose Variable is not in ``params`` is
            labelled with its size only. Defaults to no names.

    Returns:
        A ``graphviz.Digraph`` describing the autograd graph.
    """
    if params is None:
        params = {}
    param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(
        style="filled",
        shape="box",
        align="left",
        fontsize="12",
        ranksep="0.1",
        height="0.2",
    )
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # Holds the visited objects themselves (not just their ids) so they stay
    # alive for the whole walk: node names are built from id(), and an id can
    # be reused by a new object once the old one is freed.
    seen = set()
    done = object()  # sentinel: a node's children are exhausted

    def size_to_str(size):
        return "(" + ", ".join("%d" % v for v in size) + ")"

    def add_node(fn):
        """Emit the Graphviz node for ``fn`` and mark it as seen."""
        if torch.is_tensor(fn):
            dot.node(str(id(fn)), size_to_str(fn.size()), fillcolor="orange")
        elif hasattr(fn, "variable"):
            u = fn.variable
            name = param_map.get(id(u))
            size = size_to_str(u.size())
            # Without a known name, show the size alone rather than "None".
            label = size if name is None else "%s\n %s" % (name, size)
            dot.node(str(id(fn)), label, fillcolor="lightblue")
        else:
            dot.node(str(id(fn)), str(type(fn).__name__))
        seen.add(fn)

    def children(fn):
        """Yield the inputs of ``fn``: next functions, then saved tensors."""
        if hasattr(fn, "next_functions"):
            for u in fn.next_functions:
                if u[0] is not None:
                    yield u[0]
        if hasattr(fn, "saved_tensors"):
            yield from fn.saved_tensors

    def add_nodes(root):
        # Iterative depth-first walk with an explicit stack instead of
        # recursion, so deep graphs do not hit Python's recursion limit.
        # Nodes and edges are emitted in the same order as a recursive
        # pre-order traversal would emit them.
        add_node(root)
        stack = [(root, children(root))]
        while stack:
            parent, pending = stack[-1]
            child = next(pending, done)
            if child is done:
                stack.pop()
                continue
            dot.edge(str(id(child)), str(id(parent)))
            if child not in seen:
                add_node(child)
                stack.append((child, children(child)))

    add_nodes(var.grad_fn)
    return dot


def draw_graph(y, name, fmt="png"):
    """Render the autograd graph of ``y`` to a file.

    Args:
        y: output Variable whose autograd graph is drawn (see ``make_dot``).
        name: file name passed to ``graphviz.Digraph.render``.
        fmt: Graphviz output format; defaults to ``"png"``.

    Returns:
        Whatever ``graphviz.Digraph.render`` returns (the rendered file path).
    """
    s = make_dot(y, {})
    s.format = fmt
    return s.render(name)
def
make_dot(var, params):
def make_dot(var, params=None):
    """Produces a Graphviz representation of a PyTorch autograd graph.

    The graph is walked from ``var.grad_fn`` through each node's
    ``next_functions`` and ``saved_tensors``. Blue nodes are autograd nodes
    exposing a ``variable`` attribute (the Variables that require grad),
    orange nodes are Tensors saved for backward in torch.autograd.Function,
    and every other node is labelled with its autograd function's class name.

    Args:
        var: output Variable whose ``grad_fn`` is the root of the graph.
        params: optional dict of (name, Variable) used to label the blue
            nodes. A blue node whose Variable is not in ``params`` is
            labelled with its size only. Defaults to no names.

    Returns:
        A ``graphviz.Digraph`` describing the autograd graph.
    """
    if params is None:
        params = {}
    param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(
        style="filled",
        shape="box",
        align="left",
        fontsize="12",
        ranksep="0.1",
        height="0.2",
    )
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    # Holds the visited objects themselves (not just their ids) so they stay
    # alive for the whole walk: node names are built from id(), and an id can
    # be reused by a new object once the old one is freed.
    seen = set()
    done = object()  # sentinel: a node's children are exhausted

    def size_to_str(size):
        return "(" + ", ".join("%d" % v for v in size) + ")"

    def add_node(fn):
        """Emit the Graphviz node for ``fn`` and mark it as seen."""
        if torch.is_tensor(fn):
            dot.node(str(id(fn)), size_to_str(fn.size()), fillcolor="orange")
        elif hasattr(fn, "variable"):
            u = fn.variable
            name = param_map.get(id(u))
            size = size_to_str(u.size())
            # Without a known name, show the size alone rather than "None".
            label = size if name is None else "%s\n %s" % (name, size)
            dot.node(str(id(fn)), label, fillcolor="lightblue")
        else:
            dot.node(str(id(fn)), str(type(fn).__name__))
        seen.add(fn)

    def children(fn):
        """Yield the inputs of ``fn``: next functions, then saved tensors."""
        if hasattr(fn, "next_functions"):
            for u in fn.next_functions:
                if u[0] is not None:
                    yield u[0]
        if hasattr(fn, "saved_tensors"):
            yield from fn.saved_tensors

    def add_nodes(root):
        # Iterative depth-first walk with an explicit stack instead of
        # recursion, so deep graphs do not hit Python's recursion limit.
        # Nodes and edges are emitted in the same order as a recursive
        # pre-order traversal would emit them.
        add_node(root)
        stack = [(root, children(root))]
        while stack:
            parent, pending = stack[-1]
            child = next(pending, done)
            if child is done:
                stack.pop()
                continue
            dot.edge(str(id(child)), str(id(parent)))
            if child not in seen:
                add_node(child)
                stack.append((child, children(child)))

    add_nodes(var.grad_fn)
    return dot
Produces a Graphviz representation of a PyTorch autograd graph.
Blue nodes are the Variables that require grad; orange nodes are Tensors saved for backward in torch.autograd.Function.
Arguments:
- var: output Variable
- params: dict of (name, Variable) used to add names to the nodes that require grad (TODO: make optional)
def
draw_graph(y, name):