tprism.torch_util

This module contains utilities for visualizing PyTorch autograd graphs with Graphviz.

 1"""
 2This module contains utilities related to graphviz
 3"""
 4
 5import torch
 6import torch.optim as optim
 7import torch.nn.functional as F
 8from graphviz import Digraph
 9
10
11def make_dot(var, params):
12    """ Produces Graphviz representation of PyTorch autograd graph
13
14    Blue nodes are the Variables that require grad, orange are Tensors
15    saved for backward in torch.autograd.Function
16
17    Args:
18        var: output Variable
19        params: dict of (name, Variable) to add names to node that
20            require grad (TODO: make optional)
21    """
22    param_map = {id(v): k for k, v in params.items()}
23    print(param_map)
24
25    node_attr = dict(
26        style="filled",
27        shape="box",
28        align="left",
29        fontsize="12",
30        ranksep="0.1",
31        height="0.2",
32    )
33    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
34    seen = set()
35
36    def size_to_str(size):
37        return "(" + (", ").join(["%d" % v for v in size]) + ")"
38
39    def add_nodes(var):
40        if var not in seen:
41            if torch.is_tensor(var):
42                dot.node(str(id(var)), size_to_str(var.size()), fillcolor="orange")
43            elif hasattr(var, "variable"):
44                u = var.variable
45                node_name = "%s\n %s" % (param_map.get(id(u)), size_to_str(u.size()))
46                dot.node(str(id(var)), node_name, fillcolor="lightblue")
47            else:
48                dot.node(str(id(var)), str(type(var).__name__))
49            seen.add(var)
50            if hasattr(var, "next_functions"):
51                for u in var.next_functions:
52                    if u[0] is not None:
53                        dot.edge(str(id(u[0])), str(id(var)))
54                        add_nodes(u[0])
55            if hasattr(var, "saved_tensors"):
56                for t in var.saved_tensors:
57                    dot.edge(str(id(t)), str(id(var)))
58                    add_nodes(t)
59
60    add_nodes(var.grad_fn)
61    return dot
62
63
64def draw_graph(y, name):
65    s = make_dot(y, {})
66    s.format = "png"
67    s.render(name)
def make_dot(var, params):

Produces a Graphviz representation of a PyTorch autograd graph.

Blue nodes are the Variables that require grad; orange nodes are Tensors saved for backward in torch.autograd.Function.

Arguments:
  • var: output Variable
  • params: dict of (name, Variable) used to add names to nodes that require grad (TODO: make optional)
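For orientation, a minimal usage sketch follows. The two-layer model and the output file name are illustrative assumptions, not part of this module; make_dot itself only needs a tensor that carries a grad_fn and a name-to-parameter dict.

import torch
import torch.nn as nn

from tprism.torch_util import make_dot

# Illustrative model: any module whose output carries a grad_fn works here.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
y = model(torch.randn(1, 4))

# Passing named parameters lets make_dot label the blue (requires-grad)
# nodes with their parameter names.
dot = make_dot(y, dict(model.named_parameters()))
dot.format = "png"
# Requires the Graphviz system binaries; writes the DOT source to
# "autograd_graph" and the image to "autograd_graph.png".
dot.render("autograd_graph")
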
def draw_graph(y, name):
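
As the module source above shows, draw_graph is a thin wrapper: it builds the graph with make_dot(y, {}) (an empty params dict, so parameter nodes are not labeled with their names) and renders it as a PNG. A minimal sketch, again with an illustrative model and file name:

import torch
import torch.nn as nn

from tprism.torch_util import draw_graph

model = nn.Linear(3, 1)
y = model(torch.randn(1, 3))

# Writes the DOT source to "linear_graph" and the rendered image to
# "linear_graph.png" in the current directory (Graphviz binaries required).
draw_graph(y, "linear_graph")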