tprism.torch_expl_graph

This module contains PyTorch explanation graphs and PyTorch tensors.

This module constructs a computational graph by forward traversal of the given explanation graph.

  1"""
  2This module contains PyTorch explanation graphs and PyTorch tensors.
  3
  4This module constructs a computational graph by forward traversal of the given explanation graph.
  5
  6"""
  7
  8
  9import torch
 10import torch.nn.functional as F
 11import json
 12import re
 13import numpy as np
 14from google.protobuf import json_format
 15
 16from itertools import chain
 17import collections
 18
 19import inspect
 20import importlib
 21import glob
 22import os
 23import re
 24import pickle
 25import h5py
 26import math
 27
 28import tprism.expl_pb2 as expl_pb2
 29import tprism.op.base
 30import tprism.loss.base
 31import tprism.constraint
 32
 33from tprism.expl_graph import ComputationalExplGraph, SwitchTensorProvider
 34from tprism.expl_graph import PlaceholderGraph, VocabSet
 35from tprism.loader import OperatorLoader
 36from tprism.placeholder import PlaceholderData
 37from numpy import int64
 38from torch import dtype
 39from typing import Any, Dict, List, Tuple, Union
 40
 41
 42class TorchComputationalExplGraph(ComputationalExplGraph, torch.nn.Module):
 43    """ This class is a concrete explanation graph for pytorch
 44
 45Note:
 46    A path in the explanation graph is represented in the following format:
 47        ```
 48        {
 49            "sw_template": [],         a list of template: List[str]
 50            "sw_inside": [],           a list of torch.tensor
 51            "prob_sw_inside": [],      a list of torch.tensor (scalar) 
 52            "node_template": [],       a list of template: List[str]
 53            "node_inside": [],         a list of torch.tensor
 54            "node_scalar_inside": [],  a list of torch.tensor (scalar)
 55        }
 56        ```
 57Note:
 58    A goal in the explanation graph is represented in the following format:
 59        ```
 60         goal_inside[sorted_id] = {
 61            "template": path_template,      template: List[str]
 62            "inside": path_inside,          torch.tensor
 63            "batch_flag": path_batch_flag,  bool
 64         }
 65        ```
 66    
 67    Regarding the path templates of goals, T-PRISM assumes that all paths of a goal have the same tensor size.
 68    However, the index symbols need not be identical across paths.
 69    If the index symbols differ, the system displays a warning and uses only the index symbols of the first path.
 70
 71
 72    """
 73    def __init__(self, graph, tensor_provider, operator_loader, cycle_embedding_generator=None):
 74        torch.nn.Module.__init__(self)
 75        ComputationalExplGraph.__init__(self)
 76        ## setting
 77        self.operator_loader = None
 78        self.goal_template = None
 79        self.cycle_node = None
 80        self.param_op = None
 81        self.graph = graph
 82        self.loss = {}
 83        self.tensor_provider = tensor_provider
 84        self.cycle_embedding_generator = cycle_embedding_generator
 85        """
 86        if operator_loader is None:
 87            operator_loader = OperatorLoader()
 88            operator_loader.load_all("op/torch_")
 89        """
 90        self.operator_loader = operator_loader
 91        ###
 92        self.build()
 93        
 94    def build(self):
 95
 96        ## call super class        
 97        goal_template, cycle_node = self.build_explanation_graph_template(
 98            self.graph, self.tensor_provider, self.operator_loader
 99        )
100        self.goal_template = goal_template
101        self.cycle_node = cycle_node
102        
103        ## setting parameterized operators
104        self.param_op = torch.nn.ModuleList()
105        for k,v in self.operators.items():
106            if issubclass(v.__class__, torch.nn.Module):
107                print("Torch parameterized operator:", k)
108                self.param_op.append(v)
109        ##
110        for name, (param,tensor_type) in self.tensor_provider.params.items():
111            self.register_parameter(name, param)
112        
113    def _distribution_forward(self, name, dist, params, param_template, op):
114        if dist == "normal":
115            mean = params[0]
116            var = params[1]
117            scale = torch.sqrt(F.softplus(var))
118            q_z = torch.distributions.normal.Normal(mean, scale)
119            out_inside = q_z.rsample()
120            out_template = param_template[0]
121            p_z = torch.distributions.normal.Normal(
122                torch.zeros_like(mean), torch.ones_like(var)
123            )
124            loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z)
125            self.loss[name] = loss_KL.sum()
126        else:
127            out_inside = params[0]
128            out_template = param_template[0]
129            print("[ERROR] unknown distribution:", dist)
130        return out_inside, out_template
131
132    def make_einsum_args(self, template,out_template,path_batch_flag):
133        """
134        Example
135        template: [["i","j"],["j","k"]]
136        out_template: ["i","k"]
137        => "ij,jk->ik", out_template
138        """
139        lhs = ",".join(map(lambda x: "".join(x), template))
140        rhs = "".join(out_template)
141        if path_batch_flag:
142            rhs = "b" + rhs
143            out_template = ["b"] + out_template
144        einsum_eq = lhs + "->" + rhs
145        return einsum_eq, out_template
146    
147    def make_einsum_args_sublist(self,template,inputs,out_template,path_batch_flag):
148        """
149        Example
150        template: [["i","j"],["j","k"]]
151        out_template: ["i","k"]
152        => [inputs[0], [0,1], inputs[1], [1,2], [0,2]], out_template
153        """
154        symbol_set = set([e for sublist in template for e in sublist])
155        mapping={s:i for i,s in enumerate(symbol_set)}
156        if path_batch_flag:
157            out_template = ["b"] + out_template
158        sublistform_args=[]
159        for v,input_x in zip(template,inputs):
160            l=[mapping[el] for el in v]
161            sublistform_args.append(input_x)
162            sublistform_args.append(l)
163        sublistform_args.append([mapping[el] for el in out_template])
164        return sublistform_args, out_template
165
166    def _apply_operator(self,op,operator_loader,out_inside, out_template, dryrun):
167        ## restore operator
168        key=str(op.name)+"_"+str(op.values)
169        if key in self.operators:
170            op_obj=self.operators[key]
171        else:
172            #cls = operator_loader.get_operator(op.name)
173            #op_obj = cls(op.values)
174            assert False, "unknown operator " + key + " encountered in the forward procedure"
175        if dryrun:
176            out_inside = {
177                "type": "operator",
178                "name": op.name,
179                "path": out_inside}
180        else:
181            out_inside = op_obj.call(out_inside)
182        out_template = op_obj.get_output_template(out_template)
183        return out_inside,out_template
184    
185    def forward_path_node(self, path, goal_inside, verbose=False, dryrun=False):
186        goal_template = self.goal_template
187        cycle_embedding_generator = self.cycle_embedding_generator
188        cycle_node=self.cycle_node
189        node_template = []
190        node_inside = []
191        node_scalar_inside = []
192        path_batch_flag=False
193        for node in path.nodes:
194            temp_goal = goal_inside[node.sorted_id]
195            if node.sorted_id in cycle_node:
196                name = node.goal.name
197                template = goal_template[node.sorted_id]["template"]
198                shape = goal_template[node.sorted_id]["shape"]
199                if dryrun:
200                    args = node.goal.args
201                    temp_goal_inside={
202                        "type":"goal",
203                        "from":"cycle_embedding_generator",
204                        "name": name,
205                        "args": args,
206                        "id": node.sorted_id,
207                        "shape":shape}
208                else:
209                    temp_goal_inside = cycle_embedding_generator.forward(
210                        name, shape, node.sorted_id
211                    )
212                temp_goal_template = template
213                node_inside.append(temp_goal_inside)
214                node_template.append(temp_goal_template)
215            elif temp_goal is None:
216                print("  [ERROR] cycle node is detected")
217                temp_goal = goal_inside[node.sorted_id]
218                print(node.goal.name)
219                print(node)
220                print(node.sorted_id)
221                print(temp_goal)
222                quit()
223            elif len(temp_goal["template"]) > 0:
224                # tensor subgoal
225                if dryrun:
226                    name = node.goal.name
227                    args = node.goal.args
228                    temp_goal_inside={
229                        "type":"goal",
230                        "from":"goal",
231                        "name": name,
232                        "args": args,
233                        "id": node.sorted_id,}
234                        #"shape":shape}
235                else:
236                    temp_goal_inside = temp_goal["inside"]
237                temp_goal_template = temp_goal["template"]
238                if temp_goal["batch_flag"]:
239                    path_batch_flag = True
240                node_inside.append(temp_goal_inside)
241                node_template.append(temp_goal_template)
242            else:  # scalar subgoal
243                if dryrun:
244                    name = node.goal.name
245                    args = node.goal.args
246                    temp_goal_inside={
247                        "type":"goal",
248                        "from":"goal",
249                        "name": name,
250                        "args": args,
251                        "id": node.sorted_id,}
252                        #"shape":()}
253                    node_scalar_inside.append(temp_goal_inside)
254                else:
255                    if type(temp_goal["inside"]) is list:
256                        a = torch.tensor(temp_goal["inside"])
257                        node_scalar_inside.append(torch.squeeze(a))
258                    else:
259                        node_scalar_inside.append(temp_goal["inside"])
260        return node_template, node_inside, node_scalar_inside, path_batch_flag
261
262    def forward_path_sw(self, path, verbose=False,verbose_embedding=False, dryrun=False):
263        tensor_provider = self.tensor_provider
264        sw_template = []
265        sw_inside = []
266        path_batch_flag=False
267        for sw in path.tensor_switches:
268            ph = tensor_provider.get_placeholder_name(sw.name)
269            if len(ph) > 0:
270                sw_template.append(["b"] + list(sw.values))
271                path_batch_flag = True
272            else:
273                sw_template.append(list(sw.values))
274            if dryrun:
275                #x = tensor_provider.get_embedding(sw.name, verbose_embedding)
276                sw_var = {"type":"tensor_atom",
277                        "from":"tensor_provider.get_embedding",
278                        "name":sw.name,}
279                        #"shape":x.shape}
280            else:
281                sw_var = tensor_provider.get_embedding(sw.name, verbose_embedding)
282            sw_inside.append(sw_var)
283        prob_sw_inside = []
284        if dryrun:
285            for sw in path.prob_switches:
286                prob_sw_inside.append({
287                    "type":"const",
288                    "name": sw.name,
289                    "value": sw.inside,
290                    "shape":(),})
291        else:
292            for sw in path.prob_switches:
293                prob_sw_inside.append(sw.inside)
294                """
295                prob_sw_inside.append({
296                    "type":"const",
297                    "name": sw.name,
298                    "value": sw.inside,
299                    "shape":(),})
300                """
301        return sw_template, sw_inside, prob_sw_inside, path_batch_flag
302    
303    def forward_path_op(self, ops, goal_name,
304             sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside, path_batch_flag, verbose=False, dryrun=False):
305      
306        if "distribution" in ops:
307            op = ops["distribution"]
308            dist = op.values[0]
309            name = goal_name
310            if dryrun:
311                #out_inside, out_template = self._distribution_forward_dryrun
312                out_inside={
313                    "type": "distribution",
314                    "name": name,
315                    "dist_type": op,
316                    "path": sw_node_inside}
317                out_template= sw_node_template#TODO
318            else:
319                out_inside, out_template = self._distribution_forward(
320                    name,
321                    dist,
322                    params=sw_node_inside,
323                    param_template=sw_node_template,
324                    op=op,
325                )
326            
327        else:# einsum operator
328            path_v = sorted(
329                zip(sw_node_template, sw_node_inside), key=lambda x: x[0]
330            )
331            template = [x[0] for x in path_v]
332            inside = [x[1] for x in path_v]
333            # constructing einsum operation using template and inside
334            out_template = self._compute_output_template(template)
335            # print(template,out_template)
336            if len(template) > 0:  # condition for einsum
337                if verbose:
338                    einsum_eq, out_template_v = self.make_einsum_args(template,out_template,path_batch_flag)
339                    print("  index:", einsum_eq)
340                    print("  var. :", [x.shape for x in inside])
341                    #print("  var. :", inside)
342                if dryrun:
343                    einsum_eq, out_template = self.make_einsum_args(template,out_template,path_batch_flag)
344                    out_inside = {
345                        "type":"einsum",
346                        "name":"torch.einsum",
347                        "einsum_eq":einsum_eq,
348                        "path": inside}
349                else:
350                    einsum_args, out_template = self.make_einsum_args_sublist(template,inside,out_template,path_batch_flag)
351                    #out_inside = torch.einsum(einsum_eq, *inside) * out_inside
352                    out_inside = torch.einsum(*einsum_args)
353            else:  # condition for scalar
354                if dryrun:
355                    out_inside = {
356                        "type":"nop",
357                        "name":"nop",
358                        "path": inside}
359            if dryrun:
360                out_inside["path_scalar"]=node_scalar_inside+prob_sw_inside
361            else:
362                # scalar subgoal
363                for scalar_inside in node_scalar_inside:
364                    out_inside = scalar_inside * out_inside
365                # prob(scalar) switch
366                for prob_inside in prob_sw_inside:
367                    out_inside = prob_inside * out_inside
368                
369            ## computing operators
370            for op_name, op in ops.items():
371                if verbose:
372                    print("  operator:", op_name)
373                out_inside,out_template=self._apply_operator(
374                    op,
375                    self.operator_loader,
376                    out_inside,
377                    out_template,
378                    dryrun)
379                ##
380        return out_inside,out_template
381
382    def forward(self, verbose=False,verbose_embedding=False, dryrun=False):
383        """
384        Args:
385            verbose (bool): if true, displays the explanation graph together with the forward computation
386            dryrun (bool): if true, outputs the information required for the computation as goal_inside instead of building the computational graph
387
388        Returns:
389            Tuple[List[Dict],Dict]: a pair of goal_inside and loss:
390                - goal_inside: the tensor assigned to each goal
391                - loss: losses derived from the explanation graph (key: loss name, value: loss value)
392        """
393        graph = self.graph
394        tensor_provider = self.tensor_provider
395        cycle_embedding_generator = self.cycle_embedding_generator
396        goal_template = self.goal_template
397        cycle_node = self.cycle_node
398        operator_loader = self.operator_loader
399        self.loss = {}
400        # goal_template
401        # converting explanation graph to computational graph
402        goal_inside = [None] * len(graph.goals)
403        for i in range(len(graph.goals)):
404            g = graph.goals[i]
405            if verbose:
406                print(
407                    "=== tensor equation (node_id:%d, %s) ==="
408                    % (g.node.sorted_id, g.node.goal.name)
409                )
410            path_inside = []
411            path_template = []
412            path_batch_flag = False
413            for path in g.paths:
414                ## build template and inside for switches in the path
415                sw_template, sw_inside, prob_sw_inside, batch_flag = self.forward_path_sw(
416                        path, verbose, verbose_embedding, dryrun)
417                path_batch_flag = batch_flag or path_batch_flag
418                
419                ## building template and inside for nodes in the path
420                node_template, node_inside, node_scalar_inside, batch_flag = self.forward_path_node(
421                        path, goal_inside, verbose, dryrun)
422                path_batch_flag = batch_flag or path_batch_flag
423
424                ## building template and inside for all elements (switches and nodes) in the path
425                sw_node_template = sw_template + node_template
426                sw_node_inside = sw_inside + node_inside
427
428                ops = {op.name: op for op in path.operators}
429                
430                out_inside, out_template = self.forward_path_op(ops, g.node.goal.name,
431                        sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside,
432                        path_batch_flag,
433                        verbose, dryrun)
434                
435                path_inside.append(out_inside)
436                path_template.append(out_template)
437                ##
438            ### update inside
439            path_template_list = self._get_unique_list(path_template)
440            if len(path_template_list) == 0: # non-tensor/non-probabilistic path
441                if dryrun:
442                    goal_inside[i] = {
443                        "template": [],
444                        "inside": path_inside,
445                        "batch_flag": False,
446                    }
447                else:
448                    goal_inside[i] = {
449                        "template": [],
450                        "inside": torch.tensor(1),
451                        "batch_flag": False,
452                    }
453            else:
454                if len(path_template_list) != 1:
455                    print("[WARNING] mismatched indices:", path_template_list)
456                if dryrun:
457                    goal_inside[i] = {
458                        "template": path_template_list[0],
459                        "inside": path_inside,
460                        "batch_flag": path_batch_flag,
461                    }
462                else:
463                    if len(path_template_list[0]) == 0: # scalar inside
464                        goal_inside[i] = {
465                            "template": path_template_list[0],
466                            "inside": path_inside[0],
467                            "batch_flag": path_batch_flag,
468                        }
469                    else:
470                        temp_inside=torch.sum(torch.stack(path_inside), dim=0)
471                        goal_inside[i] = {
472                            "template": path_template_list[0],
473                            "inside": temp_inside,
474                            "batch_flag": path_batch_flag,
475                        }
476            if dryrun:
477                goal_inside[i]["id"]=g.node.sorted_id
478                goal_inside[i]["name"]=g.node.goal.name
479                goal_inside[i]["args"]=g.node.goal.args
480
481        self.loss.update(tensor_provider.get_loss())
482        return goal_inside, self.loss
483
484
485    
486class TorchTensorBase:
487    def __init__(self):
488        pass
489
490
491class TorchTensorOnehot(TorchTensorBase):
492    def __init__(self, provider, shape, value):
493        self.shape = shape
494        self.value = value
495
496    def __call__(self):
497        v = torch.eye(self.shape)[self.value]
498        return v
499
500
501class TorchTensor(TorchTensorBase):
502    def __init__(self, provider: 'TorchSwitchTensorProvider', name: str, shape: List[Union[int64, int]], dtype: dtype=torch.float32, tensor_type=None) -> None:
503        self.shape = shape
504        self.dtype = dtype
505        if name is None:
506            self.name = "tensor%04d" % (np.random.randint(0, 10000),)
507        else:
508            self.name = name
509        self.provider = provider
510        self.tensor_type = tensor_type
511        ###
512        self.constraint_tensor=tprism.constraint.get_constraint_tensor(shape, tensor_type, device=None, dtype=None)
513        self.param = None
514        if self.constraint_tensor is None:
515            param = torch.nn.Parameter(torch.Tensor(*shape), requires_grad=True)
516            self.param = param
517            provider.add_param(self.name, param, tensor_type)
518            self.reset_parameters()
519        else:
520            param=list(self.constraint_tensor.parameters())[0] #TODO
521            provider.add_param(self.name, param, tensor_type)
522
523        ###
524    def reset_parameters(self) -> None:
525        if len(self.param.shape) == 2:
526            torch.nn.init.kaiming_uniform_(self.param, a=math.sqrt(5))
527        else:
528            self.param.data.uniform_(-0.1, 0.1)
529
530    def __call__(self):
531        if self.constraint_tensor is None:
532            return self.param
533        else:
534            return self.constraint_tensor()
535
536
537class TorchGather(TorchTensorBase):
538    def __init__(self, provider: 'TorchSwitchTensorProvider', var: TorchTensor, idx: PlaceholderData) -> None:
539        self.var = var
540        self.idx = idx
541        self.provider = provider
542
543    def __call__(self):
544        if isinstance(self.idx, PlaceholderData):
545            idx = self.provider.get_embedding(self.idx)
546        else:
547            idx = self.idx
548        if isinstance(self.var, TorchTensor):
549            temp = self.var()
550            v = torch.index_select(temp, 0, idx)
551        elif isinstance(self.var, TorchTensorBase):
552            v = torch.index_select(self.var(), 0, idx)
553        elif isinstance(self.var, PlaceholderData):
554            v = self.provider.get_embedding(self.var)
555            v = v[idx]
556        else:
557            v = torch.index_select(self.var, 0, idx)
558        return v
559
560
561class TorchSwitchTensorProvider(SwitchTensorProvider):
562    def __init__(self) -> None:
563        self.tensor_onehot_class = TorchTensorOnehot
564        self.tensor_class = TorchTensor
565        self.tensor_gather_class = TorchGather
566
567        self.integer_dtype = torch.int32
568        super().__init__()
569
570    # forward
571    def get_loss(self, verbose:bool=False):
572        loss={}
573        for name, (param,tensor_type) in self.params.items():
574            m=re.match(r"^sparse\(([0-9\.]*)\)$", tensor_type)
575            if m:
576                coeff = float(m.group(1))
577                l=coeff*torch.norm(param,1)
578                loss["sparse_"+name]=l
579            elif tensor_type=="sparse":
580                l=torch.norm(param,1)
581                loss["sparse_"+name]=l
582        return loss
583    # forward
584    def get_embedding(self, name: Union[str,PlaceholderData], verbose:bool=False):
585        if verbose:
586            print("[INFO] get embedding:", name)
587        out = None
588        ## TODO:
589        if self.input_feed_dict is None:
590            if verbose:
591                print("[INFO] from tensor_embedding", name)
592            obj = self.tensor_embedding[name]
593            if isinstance(obj, TorchTensorBase):
594                out = obj()
595            else:
596                raise Exception("Unknown embedding type", name, type(obj))
597        elif type(name) is str:
598            key = self.tensor_embedding[name]
599            if type(key) is PlaceholderData:
600                if verbose:
601                    print("[INFO] from PlaceholderData", name, "==>", key.name)
602                out = self.input_feed_dict[key]
603            elif isinstance(key, TorchTensorBase):
604                if verbose:
605                    print("[INFO] from Tensor", name, "==>")
606                out = key()
607            else:
608                raise Exception("Unknown embedding type", name, key)
609        elif type(name) is PlaceholderData:
610            if verbose:
611                print("[INFO] from PlaceholderData", name)
612            out = self.input_feed_dict[name]
613        else:
614            raise Exception("Unknown embedding", name)
615        if verbose:
616            print(out)
617            print(type(out))
618            print("sum:", out.sum())
619        return out
class TorchComputationalExplGraph(tprism.expl_graph.ComputationalExplGraph, torch.nn.modules.module.Module):

This class is a concrete explanation graph for PyTorch.

Note:

A path in the explanation graph is represented in the following format:

{
    "sw_template": [],         a list of template: List[str]
    "sw_inside": [],           a list of torch.tensor
    "prob_sw_inside": [],      a list of torch.tensor (scalar) 
    "node_template": [],       a list of template: List[str]
    "node_inside": [],         a list of torch.tensor
    "node_scalar_inside": [],  a list of torch.tensor (scalar)
}

Note:

A goal in the explanation graph is represented in the following format:

 goal_inside[sorted_id] = {
    "template": path_template,      template: List[str]
    "inside": path_inside,          torch.tensor
    "batch_flag": path_batch_flag,  bool
 }

Regarding the path templates of goals, T-PRISM assumes that all paths of a goal have the same tensor size. However, the index symbols need not be identical across paths. If the index symbols differ, the system displays a warning and uses only the index symbols of the first path.
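
A minimal usage sketch (an illustration only, not taken from the T-PRISM sources): it assumes that graph, tensor_provider, and operator_loader have already been prepared by the T-PRISM loaders, and only shows how the goal_inside entries described above are consumed.

```
from tprism.torch_expl_graph import TorchComputationalExplGraph

# graph, tensor_provider, and operator_loader are assumed to be set up elsewhere
# (e.g. by the T-PRISM loaders); their construction is not shown here.
model = TorchComputationalExplGraph(graph, tensor_provider, operator_loader)
goal_inside, loss = model.forward(verbose=False)

for goal in goal_inside:
    # "template" is a list of index symbols, "inside" is the assigned tensor
    # (or a scalar), and "batch_flag" tells whether a batch index "b" was added.
    print(goal["template"], goal["batch_flag"], type(goal["inside"]))
```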

TorchComputationalExplGraph(graph, tensor_provider, operator_loader, cycle_embedding_generator=None)

Initializes the explanation-graph module: stores the explanation graph, tensor provider, operator loader, and optional cycle embedding generator, and then calls build().

operator_loader
goal_template
cycle_node
param_op
graph
loss
tensor_provider
cycle_embedding_generator


def build(self):
def make_einsum_args(self, template, out_template, path_batch_flag):

Example:
    template: [["i","j"],["j","k"]]
    out_template: ["i","k"]
    => "ij,jk->ik", out_template
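
The equation string shown above can be passed directly to torch.einsum. The following sketch is illustrative only; the tensors A and B are assumptions used to demonstrate the example, not part of this module.

```
import torch

# Equation built from template [["i","j"],["j","k"]] and out_template ["i","k"].
eq = "ij,jk->ik"
A = torch.randn(2, 3)          # indexed by "i", "j"
B = torch.randn(3, 4)          # indexed by "j", "k"
out = torch.einsum(eq, A, B)   # shape (2, 4), matching out_template ["i","k"]
```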

def make_einsum_args_sublist(self, template, inputs, out_template, path_batch_flag):

Example:
    template: [["i","j"],["j","k"]]
    out_template: ["i","k"]
    => [inputs[0], [0,1], inputs[1], [1,2], [0,2]], out_template
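
The sublist ("interleaved") form shown above is the variant of torch.einsum that this method targets. The sketch below mirrors the example; A and B are illustrative tensors, not part of this module.

```
import torch

A = torch.randn(2, 3)   # "i" -> 0, "j" -> 1
B = torch.randn(3, 4)   # "j" -> 1, "k" -> 2
# Equivalent to torch.einsum("ij,jk->ik", A, B).
out = torch.einsum(A, [0, 1], B, [1, 2], [0, 2])
```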

def forward_path_node(self, path, goal_inside, verbose=False, dryrun=False):
def forward_path_sw(self, path, verbose=False, verbose_embedding=False, dryrun=False):
def forward_path_op(self, ops, goal_name, sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside, path_batch_flag, verbose=False, dryrun=False):
304    def forward_path_op(self, ops,
305             sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside,path_batch_flag, verbose=False, dryrun=False):
306      
307        if "distribution" in ops:
308            op = ops["distribution"]
309            dist = op.values[0]
310            name = goal_name  # name of the goal whose distribution is built, supplied by the caller
311            if dryrun:
312                #out_inside, out_template = self._distribution_forward_dryrun
313                out_inside={
314                    "type": "distribution",
315                    "name": name,
316                    "dist_type": op,
317                    "path": sw_node_inside}
318                out_template= sw_node_template#TODO
319            else:
320                out_inside, out_template = self._distribution_forward(
321                    name,
322                    dist,
323                    params=sw_node_inside,
324                    param_template=sw_node_template,
325                    op=op,
326                )
327            
328        else:  # einsum operator
329            path_v = sorted(
330                zip(sw_node_template, sw_node_inside), key=lambda x: x[0]
331            )
332            template = [x[0] for x in path_v]
333            inside = [x[1] for x in path_v]
334            # constructing einsum operation using template and inside
335            out_template = self._compute_output_template(template)
336            # print(template,out_template)
337            if len(template) > 0:  # condition for einsum
338                if verbose:
339                    einsum_eq, out_template_v = self.make_einsum_args(template,out_template,path_batch_flag)
340                    print("  index:", einsum_eq)
341                    print("  var. :", [x.shape for x in inside])
342                    #print("  var. :", inside)
343                if dryrun:
344                    einsum_eq, out_template = self.make_einsum_args(template,out_template,path_batch_flag)
345                    out_inside = {
346                        "type":"einsum",
347                        "name":"torch.einsum",
348                        "einsum_eq":einsum_eq,
349                        "path": inside}
350                else:
351                    einsum_args, out_template = self.make_einsum_args_sublist(template,inside,out_template,path_batch_flag)
352                    #out_inside = torch.einsum(einsum_eq, *inside) * out_inside
353                    out_inside = torch.einsum(*einsum_args)
354            else:  # condition for scalar (no tensor elements in the path)
355                out_inside = torch.tensor(1.0)  # scalar identity; scalar subgoals/switches are multiplied in below
356                if dryrun:
357                    out_inside = {
358                        "type":"nop",
359                        "name":"nop", "path": inside}
360            if dryrun:
361                out_inside["path_scalar"]=node_scalar_inside+prob_sw_inside
362            else:
363                # scalar subgoal
364                for scalar_inside in node_scalar_inside:
365                    out_inside = scalar_inside * out_inside
366                # prob(scalar) switch
367                for prob_inside in prob_sw_inside:
368                    out_inside = prob_inside * out_inside
369                
370            ## computing operators
371            for op_name, op in ops.items():
372                if verbose:
373                    print("  operator:", op_name)
374                out_inside, out_template = self._apply_operator(
375                    op,
376                    self.operator_loader,
377                    out_inside,
378                    out_template,
379                    dryrun)
380                ##
381        return out_inside, out_template
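The einsum branch sorts the path elements by their index templates, derives an output template, and contracts all tensors with a single `torch.einsum` call. A minimal sketch with two path tensors whose templates are `["i","j"]` and `["j","k"]`:

```python
import torch

# Two path tensors with index templates ["i","j"] and ["j","k"];
# the shared index "j" is summed out, leaving the output template ["i","k"].
A = torch.randn(3, 4)
B = torch.randn(4, 5)
out = torch.einsum("ij,jk->ik", A, B)   # shape: (3, 5)
# With path_batch_flag set, batched operands carry a leading "b" axis,
# giving an equation of the form "bij,jk->bik".
```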
def forward(self, verbose=False, verbose_embedding=False, dryrun=False):
383    def forward(self, verbose=False, verbose_embedding=False, dryrun=False):
384        """
385        Args:
386            verbose (bool): if true, this function displays the explanation graph along with the forward computation
387            dryrun (bool): if true, this function outputs the information required for the computation as goal_inside instead of building the computational graph
388
389        Returns:
390            Tuple[List[Dict], Dict]: a pair of goal_inside and loss:
391                - goal_inside: tensors assigned to each goal
392                - loss: losses derived from the explanation graph (key = loss name, value = loss)
393        """
394        graph = self.graph
395        tensor_provider = self.tensor_provider
396        cycle_embedding_generator = self.cycle_embedding_generator
397        goal_template = self.goal_template
398        cycle_node = self.cycle_node
399        operator_loader = self.operator_loader
400        self.loss = {}
401        # goal_template
402        # converting explanation graph to computational graph
403        goal_inside = [None] * len(graph.goals)
404        for i in range(len(graph.goals)):
405            g = graph.goals[i]
406            if verbose:
407                print(
408                    "=== tensor equation (node_id:%d, %s) ==="
409                    % (g.node.sorted_id, g.node.goal.name)
410                )
411            path_inside = []
412            path_template = []
413            path_batch_flag = False
414            for path in g.paths:
415                ## build template and inside for switches in the path
416                sw_template, sw_inside, prob_sw_inside, batch_flag = self.forward_path_sw(
417                        path, verbose, verbose_embedding, dryrun)
418                path_batch_flag = batch_flag or path_batch_flag
419                
420                ## building template and inside for nodes in the path
421                node_template, node_inside, node_scalar_inside, batch_flag = self.forward_path_node(
422                        path, goal_inside, verbose, dryrun)
423                path_batch_flag = batch_flag or path_batch_flag
424
425                ## building template and inside for all elements (switches and nodes) in the path
426                sw_node_template = sw_template + node_template
427                sw_node_inside = sw_inside + node_inside
428
429                ops = {op.name: op for op in path.operators}
430                
431                out_inside, out_template = self.forward_path_op(ops,
432                        sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside,
433                        path_batch_flag, goal_name=g.node.goal.name,
434                        verbose=verbose, dryrun=dryrun)
435                
436                path_inside.append(out_inside)
437                path_template.append(out_template)
438                ##
439            ### update inside
440            path_template_list = self._get_unique_list(path_template)
441            if len(path_template_list) == 0: # non-tensor/non-probabilistic path
442                if dryrun:
443                    goal_inside[i] = {
444                        "template": [],
445                        "inside": path_inside,
446                        "batch_flag": False,
447                    }
448                else:
449                    goal_inside[i] = {
450                        "template": [],
451                        "inside": torch.tensor(1),
452                        "batch_flag": False,
453                    }
454            else:
455                if len(path_template_list) != 1:
456                    print("[WARNING] mismatched indices:", path_template_list)
457                if dryrun:
458                    goal_inside[i] = {
459                        "template": path_template_list[0],
460                        "inside": path_inside,
461                        "batch_flag": path_batch_flag,
462                    }
463                else:
464                    if len(path_template_list[0]) == 0: # scalar inside
465                        goal_inside[i] = {
466                            "template": path_template_list[0],
467                            "inside": path_inside[0],
468                            "batch_flag": path_batch_flag,
469                        }
470                    else:
471                        temp_inside=torch.sum(torch.stack(path_inside), dim=0)
472                        goal_inside[i] = {
473                            "template": path_template_list[0],
474                            "inside": temp_inside,
475                            "batch_flag": path_batch_flag,
476                        }
477            if dryrun:
478                goal_inside[i]["id"]=g.node.sorted_id
479                goal_inside[i]["name"]=g.node.goal.name
480                goal_inside[i]["args"]=g.node.goal.args
481
482        self.loss.update(tensor_provider.get_loss())
483        return goal_inside, self.loss
Arguments:
  • verbose (bool): if true, this function displays the explanation graph along with the forward computation
  • dryrun (bool): if true, this function outputs the information required for the computation as goal_inside instead of building the computational graph
Returns:
  • Tuple[List[Dict], Dict]: a pair of goal_inside and loss:
    - goal_inside: tensors assigned to each goal
    - loss: losses derived from the explanation graph (key = loss name, value = loss)
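A minimal usage sketch of `forward`, assuming `graph`, `provider`, and `op_loader` have already been produced by the T-PRISM loading pipeline (the variable names are illustrative):

```python
# Build the computational graph and run one forward traversal (illustrative names).
model = TorchComputationalExplGraph(graph, provider, op_loader)
goal_inside, loss = model.forward(verbose=False, dryrun=False)

# Each goal entry carries its index template, inside tensor, and batch flag.
g0 = goal_inside[0]
print(g0["template"], g0["batch_flag"])

# Regularization losses collected from the tensor provider, keyed by loss name.
total_reg = sum(loss.values()) if loss else 0.0
```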

Inherited Members
tprism.expl_graph.ComputationalExplGraph
operators
build_explanation_graph_template
class TorchTensorBase:
487class TorchTensorBase:
488    def __init__(self):
489        pass
class TorchTensorOnehot(TorchTensorBase):
492class TorchTensorOnehot(TorchTensorBase):
493    def __init__(self, provider, shape, value):
494        self.shape = shape
495        self.value = value
496
497    def __call__(self):
498        v = torch.eye(self.shape)[self.value]
499        return v
TorchTensorOnehot(provider, shape, value)
shape
value
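`TorchTensorOnehot.__call__` obtains the one-hot vector by indexing an identity matrix; a minimal standalone sketch of the same trick:

```python
import torch

# torch.eye(n)[k] is the k-th row of the identity matrix, i.e. a one-hot vector.
shape, value = 5, 2
onehot = torch.eye(shape)[value]
print(onehot)  # tensor([0., 0., 1., 0., 0.])
```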
class TorchTensor(TorchTensorBase):
502class TorchTensor(TorchTensorBase):
503    def __init__(self, provider: 'TorchSwitchTensorProvider', name: str, shape: List[Union[int64, int]], dtype: dtype=torch.float32, tensor_type=None) -> None:
504        self.shape = shape
505        self.dtype = dtype
506        if name is None:
507            self.name = "tensor%04d" % (np.random.randint(0, 10000),)
508        else:
509            self.name = name
510        self.provider = provider
511        self.tensor_type = tensor_type
512        ###
513        self.constraint_tensor = tprism.constraint.get_constraint_tensor(shape, tensor_type, device=None, dtype=None)
514        self.param = None
515        if self.constraint_tensor is None:
516            param = torch.nn.Parameter(torch.Tensor(*shape), requires_grad=True)
517            self.param = param
518            provider.add_param(self.name, param, tensor_type)
519            self.reset_parameters()
520        else:
521            param = list(self.constraint_tensor.parameters())[0]  # TODO
522            provider.add_param(self.name, param, tensor_type)
523
524        ###
525    def reset_parameters(self) -> None:
526        if len(self.param.shape) == 2:
527            torch.nn.init.kaiming_uniform_(self.param, a=math.sqrt(5))
528        else:
529            self.param.data.uniform_(-0.1, 0.1)
530
531    def __call__(self):
532        if self.constraint_tensor is None:
533            return self.param
534        else:
535            return self.constraint_tensor()
TorchTensor( provider: TorchSwitchTensorProvider, name: str, shape: List[Union[numpy.int64, int]], dtype: torch.dtype = torch.float32, tensor_type=None)
shape
dtype
provider
tensor_type
constraint_tensor
param
def reset_parameters(self) -> None:
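A standalone sketch of the initialization rule in `reset_parameters`: 2-D parameters get Kaiming-uniform initialization, any other shape a small uniform range (the shape below is illustrative):

```python
import math
import torch

# Mirror of the TorchTensor initialization rule for a free-standing parameter.
param = torch.nn.Parameter(torch.empty(3, 4))
if param.dim() == 2:
    torch.nn.init.kaiming_uniform_(param, a=math.sqrt(5))
else:
    param.data.uniform_(-0.1, 0.1)
```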
class TorchGather(TorchTensorBase):
538class TorchGather(TorchTensorBase):
539    def __init__(self, provider: 'TorchSwitchTensorProvider', var: TorchTensor, idx: PlaceholderData) -> None:
540        self.var = var
541        self.idx = idx
542        self.provider = provider
543
544    def __call__(self):
545        if isinstance(self.idx, PlaceholderData):
546            idx = self.provider.get_embedding(self.idx)
547        else:
548            idx = self.idx
549        if isinstance(self.var, TorchTensor):
550            temp = self.var()
551            v = torch.index_select(temp, 0, idx)
552        elif isinstance(self.var, TorchTensorBase):
553            v = torch.index_select(self.var(), 0, idx)
554        elif isinstance(self.var, PlaceholderData):
555            v = self.provider.get_embedding(self.var)
556            v = v[idx]
557        else:
558            v = torch.index_select(self.var, 0, idx)
559        return v
TorchGather( provider: TorchSwitchTensorProvider, var: TorchTensor, idx: tprism.placeholder.PlaceholderData)
var
idx
provider
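`TorchGather.__call__` boils down to selecting rows along the first axis with an integer index tensor; a minimal sketch with illustrative shapes:

```python
import torch

# A 4x3 "embedding" matrix and a batch of row indices.
emb = torch.arange(12.0).reshape(4, 3)
idx = torch.tensor([2, 0, 2])
rows = torch.index_select(emb, 0, idx)   # shape: (3, 3); rows 2, 0, 2 of emb
```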
class TorchSwitchTensorProvider(tprism.expl_graph.SwitchTensorProvider):
562class TorchSwitchTensorProvider(SwitchTensorProvider):
563    def __init__(self) -> None:
564        self.tensor_onehot_class = TorchTensorOnehot
565        self.tensor_class = TorchTensor
566        self.tensor_gather_class = TorchGather
567
568        self.integer_dtype = torch.int32
569        super().__init__()
570
571    # forward
572    def get_loss(self, verbose: bool = False):
573        loss = {}
574        for name, (param, tensor_type) in self.params.items():
575            m = re.match(r"^sparse\(([0-9\.]*)\)$", tensor_type)
576            if m:
577                coeff = float(m.group(1))
578                l = coeff * torch.norm(param, 1)
579                loss["sparse_" + name] = l
580            elif tensor_type == "sparse":
581                l = torch.norm(param, 1)
582                loss["sparse_" + name] = l
583        return loss
584    # forward
585    def get_embedding(self, name: Union[str,PlaceholderData], verbose:bool=False):
586        if verbose:
587            print("[INFO] get embedding:", name)
588        out = None
589        ## TODO:
590        if self.input_feed_dict is None:
591            if verbose:
592                print("[INFO] from tensor_embedding", name)
593            obj = self.tensor_embedding[name]
594            if isinstance(obj, TorchTensorBase):
595                out = obj()
596            else:
597                raise Exception("Unknown embedding type", name, type(obj))
598        elif type(name) is str:
599            key = self.tensor_embedding[name]
600            if type(key) is PlaceholderData:
601                if verbose:
602                    print("[INFO] from PlaceholderData", name, "==>", key.name)
603                out = self.input_feed_dict[key]
604            elif isinstance(key, TorchTensorBase):
605                if verbose:
606                    print("[INFO] from Tensor", name, "==>")
607                out = key()
608            else:
609                raise Exception("Unknown embedding type", name, key)
610        elif type(name) is PlaceholderData:
611            if verbose:
612                print("[INFO] from PlaceholderData", name)
613            out = self.input_feed_dict[name]
614        else:
615            raise Exception("Unknown embedding", name)
616        if verbose:
617            print(out)
618            print(type(out))
619            print("sum:", out.sum())
620        return out

This class provides information about switches.

Attributes:
  • tensor_embedding (Dict[str, Tensor]): embedding tensor
  • sw_info (Dict[str, SwitchTensor]): switch information
  • ph_graph (PlaceholderGraph): associated placeholder graph
  • input_feed_dict (Dict[PlaceholderData, Tensor]): feed_dict to replace a placeholder with a tensor
  • params (Dict[str,Tuple[Parameter,str]]): pytorch parameters associated with all switches provided by this provider
tensor_onehot_class
tensor_class
tensor_gather_class
integer_dtype
def get_loss(self, verbose: bool = False):
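`get_loss` turns a `sparse` or `sparse(coeff)` tensor type into an L1 penalty on the corresponding parameter; a minimal standalone sketch of that rule (the parameter and type string below are illustrative):

```python
import re
import torch

param = torch.nn.Parameter(torch.randn(4, 4))   # illustrative parameter
tensor_type = "sparse(0.01)"

m = re.match(r"^sparse\(([0-9\.]*)\)$", tensor_type)
if m:
    penalty = float(m.group(1)) * torch.norm(param, 1)   # scaled L1 norm
elif tensor_type == "sparse":
    penalty = torch.norm(param, 1)                       # plain L1 norm
```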
def get_embedding( self, name: Union[str, tprism.placeholder.PlaceholderData], verbose: bool = False):
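A simplified sketch of the lookup logic in `get_embedding`; the helper `lookup` is hypothetical and error handling is omitted:

```python
from tprism.placeholder import PlaceholderData

def lookup(provider, name):
    # No feed_dict: the stored tensor object (TorchTensorBase) is called directly.
    if provider.input_feed_dict is None:
        return provider.tensor_embedding[name]()
    # With a feed_dict: a string name may map to a PlaceholderData key, which is
    # resolved through input_feed_dict; otherwise the stored tensor object is called.
    key = provider.tensor_embedding[name] if isinstance(name, str) else name
    if isinstance(key, PlaceholderData):
        return provider.input_feed_dict[key]
    return key()
```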