tprism.torch_expl_graph
This module contains the PyTorch explanation graph and PyTorch tensor classes.
It constructs a computational graph by traversing the given explanation graph in the forward direction.
1""" 2This module contains pytorch explanation graphs and pytorch tensors. 3 4This module constructs a computational graph by forward traversing the given explanation graph. 5 6""" 7 8 9import torch 10import torch.nn.functional as F 11import json 12import re 13import numpy as np 14from google.protobuf import json_format 15 16from itertools import chain 17import collections 18 19import inspect 20import importlib 21import glob 22import os 23import re 24import pickle 25import h5py 26import math 27 28import tprism.expl_pb2 as expl_pb2 29import tprism.op.base 30import tprism.loss.base 31import tprism.constraint 32 33from tprism.expl_graph import ComputationalExplGraph, SwitchTensorProvider 34from tprism.expl_graph import PlaceholderGraph, VocabSet 35from tprism.loader import OperatorLoader 36from tprism.placeholder import PlaceholderData 37from numpy import int64 38from torch import dtype 39from typing import Any, Dict, List, Tuple, Union 40 41 42class TorchComputationalExplGraph(ComputationalExplGraph, torch.nn.Module): 43 """ This class is a concrete explanation graph for pytorch 44 45Note: 46 A path in the explanation graph is represented as the following format: 47 ``` 48 { 49 "sw_template": [], a list of template: List[str] 50 "sw_inside": [], a list of torch.tensor 51 "prob_sw_inside": [], a list of torch.tensor (scalar) 52 "node_template": [], a list of template: List[str] 53 "node_inside": [], a list of torch.tensor 54 "node_scalar_inside": [], a list of torch.tensor (scalar) 55 } 56 ``` 57Note: 58 A goal in the explanation graph is represented as the following format: 59 ``` 60 goal_inside[sorted_id] = { 61 "template": path_template, template: List[str] 62 "inside": path_inside, torch.tensor 63 "batch_flag": path_batch_flag, bool 64 } 65 ``` 66 67 Regarding the path template of goals, the T-PRISM's assumption requires that all pathes have equal tensor size. 68 However, index symbols need not be equal for all paths. 69 So the system display a warning in case of different index symbols and only uses the index symbols in the first path. 
70 71 72 """ 73 def __init__(self, graph, tensor_provider, operator_loader, cycle_embedding_generator=None): 74 torch.nn.Module.__init__(self) 75 ComputationalExplGraph.__init__(self) 76 ## setting 77 self.operator_loader = None 78 self.goal_template = None 79 self.cycle_node = None 80 self.param_op = None 81 self.graph = graph 82 self.loss = {} 83 self.tensor_provider = tensor_provider 84 self.cycle_embedding_generator = cycle_embedding_generator 85 """ 86 if operator_loader is None: 87 operator_loader = OperatorLoader() 88 operator_loader.load_all("op/torch_") 89 """ 90 self.operator_loader = operator_loader 91 ### 92 self.build() 93 94 def build(self): 95 96 ## call super class 97 goal_template, cycle_node = self.build_explanation_graph_template( 98 self.graph, self.tensor_provider, self.operator_loader 99 ) 100 self.goal_template = goal_template 101 self.cycle_node = cycle_node 102 103 ## setting parameterized operators 104 self.param_op = torch.nn.ModuleList() 105 for k,v in self.operators.items(): 106 if issubclass(v.__class__, torch.nn.Module): 107 print("Torch parameterized operator:", k) 108 self.param_op.append(v) 109 ## 110 for name, (param,tensor_type) in self.tensor_provider.params.items(): 111 self.register_parameter(name, param) 112 113 def _distribution_forward(self, name, dist, params, param_template, op): 114 if dist == "normal": 115 mean = params[0] 116 var = params[1] 117 scale = torch.sqrt(F.softplus(var)) 118 q_z = torch.distributions.normal.Normal(mean, scale) 119 out_inside = q_z.rsample() 120 out_template = param_template[0] 121 p_z = torch.distributions.normal.Normal( 122 torch.zeros_like(mean), torch.ones_like(var) 123 ) 124 loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z) 125 self.loss[name] = loss_KL.sum() 126 else: 127 out_inside = params[0] 128 out_template = param_template[0] 129 print("[ERROR] unknown distribution:", dist) 130 return out_inside, out_template 131 132 def make_einsum_args(self, template,out_template,path_batch_flag): 133 """ 134 Example 135 template: [["i","j"],["j","k"]] 136 out_template: ["i","k"] 137 => "ij,jk->ik", out_template 138 """ 139 lhs = ",".join(map(lambda x: "".join(x), template)) 140 rhs = "".join(out_template) 141 if path_batch_flag: 142 rhs = "b" + rhs 143 out_template = ["b"] + out_template 144 einsum_eq = lhs + "->" + rhs 145 return einsum_eq, out_template 146 147 def make_einsum_args_sublist(self,template,inputs,out_template,path_batch_flag): 148 """ 149 Example 150 template: [["i","j"],["j","k"]] 151 out_template: ["i","k"] 152 => [inputs[0], [0,1], inputs[1], [1,2], [0,2]], out_template 153 """ 154 symbol_set = set([e for sublist in template for e in sublist]) 155 mapping={s:i for i,s in enumerate(symbol_set)} 156 if path_batch_flag: 157 out_template = ["b"] + out_template 158 sublistform_args=[] 159 for v,input_x in zip(template,inputs): 160 l=[mapping[el] for el in v] 161 sublistform_args.append(input_x) 162 sublistform_args.append(l) 163 sublistform_args.append([mapping[el] for el in out_template]) 164 return sublistform_args, out_template 165 166 def _apply_operator(self,op,operator_loader,out_inside, out_template, dryrun): 167 ## restore operator 168 key=str(op.name)+"_"+str(op.values) 169 if key in self.operators: 170 op_obj=self.operators[key] 171 else: 172 #cls = operator_loader.get_operator(op.name) 173 #op_obj = cls(op.values) 174 assert True, "unknown operator "+key+" has been found in forward procedure" 175 if dryrun: 176 out_inside = { 177 "type": "operator", 178 "name": op.name, 179 "path": 
out_inside} 180 else: 181 out_inside = op_obj.call(out_inside) 182 out_template = op_obj.get_output_template(out_template) 183 return out_inside,out_template 184 185 def forward_path_node(self, path, goal_inside, verbose=False, dryrun=False): 186 goal_template = self.goal_template 187 cycle_embedding_generator = self.cycle_embedding_generator 188 cycle_node=self.cycle_node 189 node_template = [] 190 node_inside = [] 191 node_scalar_inside = [] 192 path_batch_flag=False 193 for node in path.nodes: 194 temp_goal = goal_inside[node.sorted_id] 195 if node.sorted_id in cycle_node: 196 name = node.goal.name 197 template = goal_template[node.sorted_id]["template"] 198 shape = goal_template[node.sorted_id]["shape"] 199 if dryrun: 200 args = node.goal.args 201 temp_goal_inside={ 202 "type":"goal", 203 "from":"cycle_embedding_generator", 204 "name": name, 205 "args": args, 206 "id": node.sorted_id, 207 "shape":shape} 208 else: 209 temp_goal_inside = cycle_embedding_generator.forward( 210 name, shape, node.sorted_id 211 ) 212 temp_goal_template = template 213 node_inside.append(temp_goal_inside) 214 node_template.append(temp_goal_template) 215 elif temp_goal is None: 216 print(" [ERROR] cycle node is detected") 217 temp_goal = goal_inside[node.sorted_id] 218 print(g.node.sorted_id) 219 print(node) 220 print(node.sorted_id) 221 print(temp_goal) 222 quit() 223 elif len(temp_goal["template"]) > 0: 224 # tensor subgoal 225 if dryrun: 226 name = node.goal.name 227 args = node.goal.args 228 temp_goal_inside={ 229 "type":"goal", 230 "from":"goal", 231 "name": name, 232 "args": args, 233 "id": node.sorted_id,} 234 #"shape":shape} 235 else: 236 temp_goal_inside = temp_goal["inside"] 237 temp_goal_template = temp_goal["template"] 238 if temp_goal["batch_flag"]: 239 path_batch_flag = True 240 node_inside.append(temp_goal_inside) 241 node_template.append(temp_goal_template) 242 else: # scalar subgoal 243 if dryrun: 244 name = node.goal.name 245 args = node.goal.args 246 temp_goal_inside={ 247 "type":"goal", 248 "from":"goal", 249 "name": name, 250 "args": args, 251 "id": node.sorted_id,} 252 #"shape":()} 253 node_scalar_inside.append(temp_goal_inside) 254 else: 255 if type(temp_goal["inside"]) is list: 256 a = torch.tensor(temp_goal["inside"]) 257 node_scalar_inside.append(torch.squeeze(a)) 258 else: 259 node_scalar_inside.append(temp_goal["inside"]) 260 return node_template, node_inside, node_scalar_inside, path_batch_flag 261 262 def forward_path_sw(self, path, verbose=False,verbose_embedding=False, dryrun=False): 263 tensor_provider = self.tensor_provider 264 sw_template = [] 265 sw_inside = [] 266 path_batch_flag=False 267 for sw in path.tensor_switches: 268 ph = tensor_provider.get_placeholder_name(sw.name) 269 if len(ph) > 0: 270 sw_template.append(["b"] + list(sw.values)) 271 path_batch_flag = True 272 else: 273 sw_template.append(list(sw.values)) 274 if dryrun: 275 #x = tensor_provider.get_embedding(sw.name, verbose_embedding) 276 sw_var = {"type":"tensor_atom", 277 "from":"tensor_provider.get_embedding", 278 "name":sw.name,} 279 #"shape":x.shape} 280 else: 281 sw_var = tensor_provider.get_embedding(sw.name, verbose_embedding) 282 sw_inside.append(sw_var) 283 prob_sw_inside = [] 284 if dryrun: 285 for sw in path.prob_switches: 286 prob_sw_inside.append({ 287 "type":"const", 288 "name": sw.name, 289 "value": sw.inside, 290 "shape":(),}) 291 else: 292 for sw in path.prob_switches: 293 prob_sw_inside.append(sw.inside) 294 """ 295 prob_sw_inside.append({ 296 "type":"const", 297 "name": sw.name, 298 "value": 
sw.inside, 299 "shape":(),}) 300 """ 301 return sw_template, sw_inside, prob_sw_inside, path_batch_flag 302 303 def forward_path_op(self, ops, 304 sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside,path_batch_flag, verbose=False, dryrun=False): 305 306 if "distribution" in ops: 307 op = ops["distribution"] 308 dist = op.values[0] 309 name = g.node.goal.name 310 if dryrun: 311 #out_inside, out_template = self._distribution_forward_dryrun 312 out_inside={ 313 "type": "distribution", 314 "name": name, 315 "dist_type": op, 316 "path": sw_node_inside} 317 out_template= sw_node_template#TODO 318 else: 319 out_inside, out_template = self._distribution_forward( 320 name, 321 dist, 322 params=sw_node_inside, 323 param_template=sw_node_template, 324 op=op, 325 ) 326 327 else:# einsum operator 328 path_v = sorted( 329 zip(sw_node_template, sw_node_inside), key=lambda x: x[0] 330 ) 331 template = [x[0] for x in path_v] 332 inside = [x[1] for x in path_v] 333 # constructing einsum operation using template and inside 334 out_template = self._compute_output_template(template) 335 # print(template,out_template) 336 if len(template) > 0: # condition for einsum 337 if verbose: 338 einsum_eq, out_template_v = self.make_einsum_args(template,out_template,path_batch_flag) 339 print(" index:", einsum_eq) 340 print(" var. :", [x.shape for x in inside]) 341 #print(" var. :", inside) 342 if dryrun: 343 einsum_eq, out_template = self.make_einsum_args(template,out_template,path_batch_flag) 344 out_inside = { 345 "type":"einsum", 346 "name":"torch.einsum", 347 "einsum_eq":einsum_eq, 348 "path": inside} 349 else: 350 einsum_args, out_template = self.make_einsum_args_sublist(template,inside,out_template,path_batch_flag) 351 #out_inside = torch.einsum(einsum_eq, *inside) * out_inside 352 out_inside = torch.einsum(*einsum_args) 353 else: # condition for scalar 354 if dryrun: 355 out_inside = { 356 "type":"nop", 357 "name":"nop", 358 "path": inside} 359 if dryrun: 360 out_inside["path_scalar"]=node_scalar_inside+prob_sw_inside 361 else: 362 # scalar subgoal 363 for scalar_inside in node_scalar_inside: 364 out_inside = scalar_inside * out_inside 365 # prob(scalar) switch 366 for prob_inside in prob_sw_inside: 367 out_inside = prob_inside * out_inside 368 369 ## computing operaters 370 for op_name, op in ops.items(): 371 if verbose: 372 print(" operator:", op_name) 373 out_inside,out_template=self._apply_operator( 374 op, 375 operator_loader, 376 out_inside, 377 out_template, 378 dryrun) 379 ## 380 return out_inside,out_template 381 382 def forward(self, verbose=False,verbose_embedding=False, dryrun=False): 383 """ 384 Args: 385 verbose (bool): if true, this function displays an explanation graph with forward computation 386 dryrun (bool): if true, this function outputs information required for calculation as goal_inside instead of computational graph 387 388 Returns: 389 Tuple[List[Dict],Dict]: a pair of goal_inside and loss: 390 - goal_inside: tensors assigned for each goal 391 - loss: loss derived from explanation graph: key = loss name and value = loss 392 """ 393 graph = self.graph 394 tensor_provider = self.tensor_provider 395 cycle_embedding_generator = self.cycle_embedding_generator 396 goal_template = self.goal_template 397 cycle_node = self.cycle_node 398 operator_loader = self.operator_loader 399 self.loss = {} 400 # goal_template 401 # converting explanation graph to computational graph 402 goal_inside = [None] * len(graph.goals) 403 for i in range(len(graph.goals)): 404 g = graph.goals[i] 405 
if verbose: 406 print( 407 "=== tensor equation (node_id:%d, %s) ===" 408 % (g.node.sorted_id, g.node.goal.name) 409 ) 410 path_inside = [] 411 path_template = [] 412 path_batch_flag = False 413 for path in g.paths: 414 ## build template and inside for switches in the path 415 sw_template, sw_inside, prob_sw_inside, batch_flag = self.forward_path_sw( 416 path, verbose, verbose_embedding, dryrun) 417 path_batch_flag = batch_flag or path_batch_flag 418 419 ## building template and inside for nodes in the path 420 node_template, node_inside, node_scalar_inside, batch_flag = self.forward_path_node( 421 path, goal_inside, verbose, dryrun) 422 path_batch_flag = batch_flag or path_batch_flag 423 424 ## building template and inside for all elements (switches and nodes) in the path 425 sw_node_template = sw_template + node_template 426 sw_node_inside = sw_inside + node_inside 427 428 ops = {op.name: op for op in path.operators} 429 430 out_inside,out_template = self.forward_path_op(ops, 431 sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside, 432 path_batch_flag, 433 verbose, dryrun) 434 435 path_inside.append(out_inside) 436 path_template.append(out_template) 437 ## 438 ### update inside 439 path_template_list = self._get_unique_list(path_template) 440 if len(path_template_list) == 0: # non-tensor/non-probabilistic path 441 if dryrun: 442 goal_inside[i] = { 443 "template": [], 444 "inside": path_inside, 445 "batch_flag": False, 446 } 447 else: 448 goal_inside[i] = { 449 "template": [], 450 "inside": torch.tensor(1), 451 "batch_flag": False, 452 } 453 else: 454 if len(path_template_list) != 1: 455 print("[WARNING] missmatch indices:", path_template_list) 456 if dryrun: 457 goal_inside[i] = { 458 "template": path_template_list[0], 459 "inside": path_inside, 460 "batch_flag": path_batch_flag, 461 } 462 else: 463 if len(path_template_list[0]) == 0: # scalar inside 464 goal_inside[i] = { 465 "template": path_template_list[0], 466 "inside": path_inside[0], 467 "batch_flag": path_batch_flag, 468 } 469 else: 470 temp_inside=torch.sum(torch.stack(path_inside), dim=0) 471 goal_inside[i] = { 472 "template": path_template_list[0], 473 "inside": temp_inside, 474 "batch_flag": path_batch_flag, 475 } 476 if dryrun: 477 goal_inside[i]["id"]=g.node.sorted_id 478 goal_inside[i]["name"]=g.node.goal.name 479 goal_inside[i]["args"]=g.node.goal.args 480 481 self.loss.update(tensor_provider.get_loss()) 482 return goal_inside, self.loss 483 484 485 486class TorchTensorBase: 487 def __init__(self): 488 pass 489 490 491class TorchTensorOnehot(TorchTensorBase): 492 def __init__(self, provider, shape, value): 493 self.shape = shape 494 self.value = value 495 496 def __call__(self): 497 v = torch.eye(self.shape)[self.value] 498 return v 499 500 501class TorchTensor(TorchTensorBase): 502 def __init__(self, provider: 'TorchSwitchTensorProvider', name: str, shape: List[Union[int64, int]], dtype: dtype=torch.float32, tensor_type=None) -> None: 503 self.shape = shape 504 self.dtype = dtype 505 if name is None: 506 self.name = "tensor%04d" % (np.random.randint(0, 10000),) 507 else: 508 self.name = name 509 self.provider = provider 510 self.tensor_type = tensor_type 511 ### 512 self.constraint_tensor=tprism.constraint.get_constraint_tensor(shape, tensor_type, device=None, dtype=None) 513 self.param = None 514 if self.constraint_tensor is None: 515 param = torch.nn.Parameter(torch.Tensor(*shape), requires_grad=True) 516 self.param = param 517 provider.add_param(self.name, param, tensor_type) 518 
self.reset_parameters() 519 else: 520 param=list(self.constraint_tensor.parameters())[0] #TODO 521 provider.add_param(self.name, param, tensor_type) 522 523 ### 524 def reset_parameters(self) -> None: 525 if len(self.param.shape) == 2: 526 torch.nn.init.kaiming_uniform_(self.param, a=math.sqrt(5)) 527 else: 528 self.param.data.uniform_(-0.1, 0.1) 529 530 def __call__(self): 531 if self.constraint_tensor is None: 532 return self.param 533 else: 534 return self.constraint_tensor() 535 536 537class TorchGather(TorchTensorBase): 538 def __init__(self, provider: 'TorchSwitchTensorProvider', var: TorchTensor, idx: PlaceholderData) -> None: 539 self.var = var 540 self.idx = idx 541 self.provider = provider 542 543 def __call__(self): 544 if isinstance(self.idx, PlaceholderData): 545 idx = self.provider.get_embedding(self.idx) 546 else: 547 idx = self.idx 548 if isinstance(self.var, TorchTensor): 549 temp = self.var() 550 v = torch.index_select(temp, 0, idx) 551 elif isinstance(self.var, TorchTensorBase): 552 v = torch.index_select(self.var(), 0, idx) 553 elif isinstance(self.var, PlaceholderData): 554 v = self.provider.get_embedding(self.var) 555 v = v[idx] 556 else: 557 v = torch.index_select(self.var, 0, idx) 558 return v 559 560 561class TorchSwitchTensorProvider(SwitchTensorProvider): 562 def __init__(self) -> None: 563 self.tensor_onehot_class = TorchTensorOnehot 564 self.tensor_class = TorchTensor 565 self.tensor_gather_class = TorchGather 566 567 self.integer_dtype = torch.int32 568 super().__init__() 569 570 # forward 571 def get_loss(self, verbose:bool=False): 572 loss={} 573 for name, (param,tensor_type) in self.params.items(): 574 m=re.match(r"^sparse\(([0-9\.]*)\)$", tensor_type) 575 if m: 576 coeff = float(m.group(1)) 577 l=coeff*torch.norm(param,1) 578 loss["sparse_"+name]=l 579 elif tensor_type=="sparse": 580 l=torch.norm(param,1) 581 loss["sparse_"+name]=l 582 return loss 583 # forward 584 def get_embedding(self, name: Union[str,PlaceholderData], verbose:bool=False): 585 if verbose: 586 print("[INFO] get embedding:", name) 587 out = None 588 ## TODO: 589 if self.input_feed_dict is None: 590 if verbose: 591 print("[INFO] from tensor_embedding", name) 592 obj = self.tensor_embedding[name] 593 if isinstance(obj, TorchTensorBase): 594 out = obj() 595 else: 596 raise Exception("Unknoen embedding type", name, type(obj)) 597 elif type(name) is str: 598 key = self.tensor_embedding[name] 599 if type(key) is PlaceholderData: 600 if verbose: 601 print("[INFO] from PlaceholderData", name, "==>", key.name) 602 out = self.input_feed_dict[key] 603 elif isinstance(key, TorchTensorBase): 604 if verbose: 605 print("[INFO] from Tensor", name, "==>") 606 out = key() 607 else: 608 raise Exception("Unknoen embedding type", name, key) 609 elif type(name) is PlaceholderData: 610 if verbose: 611 print("[INFO] from PlaceholderData", name) 612 out = self.input_feed_dict[name] 613 else: 614 raise Exception("Unknoen embedding", name) 615 if verbose: 616 print(out) 617 print(type(out)) 618 print("sum:", out.sum()) 619 return out
```
class TorchComputationalExplGraph(ComputationalExplGraph, torch.nn.Module):
    # The public methods (__init__, build, make_einsum_args*, forward_path_*, forward)
    # are shown in their own sections below; only the private helpers follow here.

    def _distribution_forward(self, name, dist, params, param_template, op):
        if dist == "normal":
            mean = params[0]
            var = params[1]
            scale = torch.sqrt(F.softplus(var))
            q_z = torch.distributions.normal.Normal(mean, scale)
            out_inside = q_z.rsample()
            out_template = param_template[0]
            p_z = torch.distributions.normal.Normal(
                torch.zeros_like(mean), torch.ones_like(var)
            )
            loss_KL = torch.distributions.kl.kl_divergence(q_z, p_z)
            self.loss[name] = loss_KL.sum()
        else:
            out_inside = params[0]
            out_template = param_template[0]
            print("[ERROR] unknown distribution:", dist)
        return out_inside, out_template

    def _apply_operator(self, op, operator_loader, out_inside, out_template, dryrun):
        ## restore the operator object registered during build()
        key = str(op.name) + "_" + str(op.values)
        if key in self.operators:
            op_obj = self.operators[key]
        else:
            #cls = operator_loader.get_operator(op.name)
            #op_obj = cls(op.values)
            assert False, "unknown operator " + key + " has been found in forward procedure"
        if dryrun:
            out_inside = {
                "type": "operator",
                "name": op.name,
                "path": out_inside}
        else:
            out_inside = op_obj.call(out_inside)
        out_template = op_obj.get_output_template(out_template)
        return out_inside, out_template
```
This class is a concrete explanation graph implementation for PyTorch.
Note:
A path in the explanation graph is represented in the following format:
```
{
    "sw_template": [],         # a list of templates: List[str]
    "sw_inside": [],           # a list of torch.tensor
    "prob_sw_inside": [],      # a list of torch.tensor (scalar)
    "node_template": [],       # a list of templates: List[str]
    "node_inside": [],         # a list of torch.tensor
    "node_scalar_inside": [],  # a list of torch.tensor (scalar)
}
```
Note:
A goal in the explanation graph is represented in the following format:
```
goal_inside[sorted_id] = {
    "template": path_template,      # template: List[str]
    "inside": path_inside,          # torch.tensor
    "batch_flag": path_batch_flag,  # bool
}
```
Regarding the path templates of goals, T-PRISM assumes that all paths of a goal produce tensors of the same shape. However, the index symbols need not be identical across paths, so the system displays a warning when the index symbols differ and uses only the index symbols of the first path.
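The following is a minimal, self-contained sketch (with made-up tensors and templates, not taken from a real explanation graph) of the goal_inside format described above and of the index-symbol check that triggers the warning.

```
import torch

# Hypothetical goal_inside entry in the documented format (illustrative values only).
goal_entry = {
    "template": ["i", "j"],        # index symbols taken from the first path
    "inside": torch.zeros(3, 4),   # tensor whose axes correspond to the template
    "batch_flag": False,
}

# Two paths whose tensors have the same shape but different index symbols:
# only the first template is kept, and a warning is emitted.
path_template = [["i", "j"], ["i", "k"]]
unique_templates = [list(t) for t in dict.fromkeys(tuple(t) for t in path_template)]
if len(unique_templates) != 1:
    print("[WARNING] mismatched indices:", unique_templates)
print("kept:", unique_templates[0])
```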
```
    def __init__(self, graph, tensor_provider, operator_loader, cycle_embedding_generator=None):
        torch.nn.Module.__init__(self)
        ComputationalExplGraph.__init__(self)
        ## setting
        self.operator_loader = None
        self.goal_template = None
        self.cycle_node = None
        self.param_op = None
        self.graph = graph
        self.loss = {}
        self.tensor_provider = tensor_provider
        self.cycle_embedding_generator = cycle_embedding_generator
        """
        if operator_loader is None:
            operator_loader = OperatorLoader()
            operator_loader.load_all("op/torch_")
        """
        self.operator_loader = operator_loader
        ###
        self.build()
```
Initialize internal Module state, shared by both nn.Module and ScriptModule.
```
    def build(self):

        ## call super class
        goal_template, cycle_node = self.build_explanation_graph_template(
            self.graph, self.tensor_provider, self.operator_loader
        )
        self.goal_template = goal_template
        self.cycle_node = cycle_node

        ## setting parameterized operators
        self.param_op = torch.nn.ModuleList()
        for k, v in self.operators.items():
            if issubclass(v.__class__, torch.nn.Module):
                print("Torch parameterized operator:", k)
                self.param_op.append(v)
        ##
        for name, (param, tensor_type) in self.tensor_provider.params.items():
            self.register_parameter(name, param)
```
```
    def make_einsum_args(self, template, out_template, path_batch_flag):
        """
        Example
            template:     [["i","j"],["j","k"]]
            out_template: ["i","k"]
            =>  "ij,jk->ik", out_template
        """
        lhs = ",".join(map(lambda x: "".join(x), template))
        rhs = "".join(out_template)
        if path_batch_flag:
            rhs = "b" + rhs
            out_template = ["b"] + out_template
        einsum_eq = lhs + "->" + rhs
        return einsum_eq, out_template
```
Example:
```
template:     [["i","j"],["j","k"]]
out_template: ["i","k"]
=>  "ij,jk->ik", out_template
```
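As a supplement, here is a small runnable sketch (with illustrative operand shapes; the template values mirror the example above) showing how the equation string built by this method is consumed by torch.einsum.

```
import torch

# Templates from the example above (illustrative only).
template = [["i", "j"], ["j", "k"]]   # index symbols of the two operands
out_template = ["i", "k"]             # indices that survive the contraction

# The same string construction performed by make_einsum_args when path_batch_flag is False.
einsum_eq = ",".join("".join(t) for t in template) + "->" + "".join(out_template)
assert einsum_eq == "ij,jk->ik"

x = torch.randn(3, 4)                 # axes "i", "j"
w = torch.randn(4, 5)                 # axes "j", "k"
out = torch.einsum(einsum_eq, x, w)   # matrix product with shape (3, 5)
print(out.shape)
```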
```
    def make_einsum_args_sublist(self, template, inputs, out_template, path_batch_flag):
        """
        Example
            template:     [["i","j"],["j","k"]]
            out_template: ["i","k"]
            =>  [inputs[0], [0,1], inputs[1], [1,2], [0,2]], out_template
        """
        symbol_set = set([e for sublist in template for e in sublist])
        mapping = {s: i for i, s in enumerate(symbol_set)}
        if path_batch_flag:
            out_template = ["b"] + out_template
        sublistform_args = []
        for v, input_x in zip(template, inputs):
            l = [mapping[el] for el in v]
            sublistform_args.append(input_x)
            sublistform_args.append(l)
        sublistform_args.append([mapping[el] for el in out_template])
        return sublistform_args, out_template
```
Example:
```
template:     [["i","j"],["j","k"]]
out_template: ["i","k"]
=>  [inputs[0], [0,1], inputs[1], [1,2], [0,2]], out_template
```
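For comparison, the sketch below (again with illustrative operands) shows the sublist-form call to torch.einsum that the returned arguments correspond to, and checks it against the equivalent string form.

```
import torch

x = torch.randn(3, 4)   # symbol "i" -> axis id 0, symbol "j" -> axis id 1
w = torch.randn(4, 5)   # symbol "j" -> axis id 1, symbol "k" -> axis id 2

# [operand, its axis ids, operand, its axis ids, output axis ids],
# i.e. the sublist ("interleaved") form accepted by torch.einsum.
sublist_args = [x, [0, 1], w, [1, 2], [0, 2]]
out = torch.einsum(*sublist_args)

# Equivalent to the string form produced by make_einsum_args.
assert torch.allclose(out, torch.einsum("ij,jk->ik", x, w))
print(out.shape)
```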
```
    def forward_path_node(self, path, goal_inside, verbose=False, dryrun=False):
        goal_template = self.goal_template
        cycle_embedding_generator = self.cycle_embedding_generator
        cycle_node = self.cycle_node
        node_template = []
        node_inside = []
        node_scalar_inside = []
        path_batch_flag = False
        for node in path.nodes:
            temp_goal = goal_inside[node.sorted_id]
            if node.sorted_id in cycle_node:
                name = node.goal.name
                template = goal_template[node.sorted_id]["template"]
                shape = goal_template[node.sorted_id]["shape"]
                if dryrun:
                    args = node.goal.args
                    temp_goal_inside = {
                        "type": "goal",
                        "from": "cycle_embedding_generator",
                        "name": name,
                        "args": args,
                        "id": node.sorted_id,
                        "shape": shape}
                else:
                    temp_goal_inside = cycle_embedding_generator.forward(
                        name, shape, node.sorted_id
                    )
                temp_goal_template = template
                node_inside.append(temp_goal_inside)
                node_template.append(temp_goal_template)
            elif temp_goal is None:
                print(" [ERROR] cycle node is detected")
                temp_goal = goal_inside[node.sorted_id]
                print(node)
                print(node.sorted_id)
                print(temp_goal)
                quit()
            elif len(temp_goal["template"]) > 0:
                # tensor subgoal
                if dryrun:
                    name = node.goal.name
                    args = node.goal.args
                    temp_goal_inside = {
                        "type": "goal",
                        "from": "goal",
                        "name": name,
                        "args": args,
                        "id": node.sorted_id,}
                        #"shape":shape}
                else:
                    temp_goal_inside = temp_goal["inside"]
                temp_goal_template = temp_goal["template"]
                if temp_goal["batch_flag"]:
                    path_batch_flag = True
                node_inside.append(temp_goal_inside)
                node_template.append(temp_goal_template)
            else:  # scalar subgoal
                if dryrun:
                    name = node.goal.name
                    args = node.goal.args
                    temp_goal_inside = {
                        "type": "goal",
                        "from": "goal",
                        "name": name,
                        "args": args,
                        "id": node.sorted_id,}
                        #"shape":()}
                    node_scalar_inside.append(temp_goal_inside)
                else:
                    if type(temp_goal["inside"]) is list:
                        a = torch.tensor(temp_goal["inside"])
                        node_scalar_inside.append(torch.squeeze(a))
                    else:
                        node_scalar_inside.append(temp_goal["inside"])
        return node_template, node_inside, node_scalar_inside, path_batch_flag
```
```
    def forward_path_sw(self, path, verbose=False, verbose_embedding=False, dryrun=False):
        tensor_provider = self.tensor_provider
        sw_template = []
        sw_inside = []
        path_batch_flag = False
        for sw in path.tensor_switches:
            ph = tensor_provider.get_placeholder_name(sw.name)
            if len(ph) > 0:
                sw_template.append(["b"] + list(sw.values))
                path_batch_flag = True
            else:
                sw_template.append(list(sw.values))
            if dryrun:
                #x = tensor_provider.get_embedding(sw.name, verbose_embedding)
                sw_var = {"type": "tensor_atom",
                          "from": "tensor_provider.get_embedding",
                          "name": sw.name,}
                          #"shape":x.shape}
            else:
                sw_var = tensor_provider.get_embedding(sw.name, verbose_embedding)
            sw_inside.append(sw_var)
        prob_sw_inside = []
        if dryrun:
            for sw in path.prob_switches:
                prob_sw_inside.append({
                    "type": "const",
                    "name": sw.name,
                    "value": sw.inside,
                    "shape": (),})
        else:
            for sw in path.prob_switches:
                prob_sw_inside.append(sw.inside)
                """
                prob_sw_inside.append({
                    "type":"const",
                    "name": sw.name,
                    "value": sw.inside,
                    "shape":(),})
                """
        return sw_template, sw_inside, prob_sw_inside, path_batch_flag
```
```
    def forward_path_op(self, ops,
            sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside,
            path_batch_flag, verbose=False, dryrun=False):

        if "distribution" in ops:
            op = ops["distribution"]
            dist = op.values[0]
            name = g.node.goal.name  # NOTE: relies on the current goal `g`, which is not defined in this method's scope
            if dryrun:
                #out_inside, out_template = self._distribution_forward_dryrun
                out_inside = {
                    "type": "distribution",
                    "name": name,
                    "dist_type": op,
                    "path": sw_node_inside}
                out_template = sw_node_template  # TODO
            else:
                out_inside, out_template = self._distribution_forward(
                    name,
                    dist,
                    params=sw_node_inside,
                    param_template=sw_node_template,
                    op=op,
                )

        else:  # einsum operator
            path_v = sorted(
                zip(sw_node_template, sw_node_inside), key=lambda x: x[0]
            )
            template = [x[0] for x in path_v]
            inside = [x[1] for x in path_v]
            # constructing einsum operation using template and inside
            out_template = self._compute_output_template(template)
            # print(template,out_template)
            if len(template) > 0:  # condition for einsum
                if verbose:
                    einsum_eq, out_template_v = self.make_einsum_args(template, out_template, path_batch_flag)
                    print(" index:", einsum_eq)
                    print(" var. :", [x.shape for x in inside])
                    #print(" var. :", inside)
                if dryrun:
                    einsum_eq, out_template = self.make_einsum_args(template, out_template, path_batch_flag)
                    out_inside = {
                        "type": "einsum",
                        "name": "torch.einsum",
                        "einsum_eq": einsum_eq,
                        "path": inside}
                else:
                    einsum_args, out_template = self.make_einsum_args_sublist(template, inside, out_template, path_batch_flag)
                    #out_inside = torch.einsum(einsum_eq, *inside) * out_inside
                    out_inside = torch.einsum(*einsum_args)
            else:  # condition for scalar
                if dryrun:
                    out_inside = {
                        "type": "nop",
                        "name": "nop",
                        "path": inside}
            if dryrun:
                out_inside["path_scalar"] = node_scalar_inside + prob_sw_inside
            else:
                # scalar subgoal
                for scalar_inside in node_scalar_inside:
                    out_inside = scalar_inside * out_inside
                # prob(scalar) switch
                for prob_inside in prob_sw_inside:
                    out_inside = prob_inside * out_inside

        ## computing operators
        for op_name, op in ops.items():
            if verbose:
                print(" operator:", op_name)
            out_inside, out_template = self._apply_operator(
                op,
                self.operator_loader,
                out_inside,
                out_template,
                dryrun)
        ##
        return out_inside, out_template
```
```
    def forward(self, verbose=False, verbose_embedding=False, dryrun=False):
        """
        Args:
            verbose (bool): if true, this function displays an explanation graph with forward computation
            dryrun (bool): if true, this function outputs information required for calculation as goal_inside instead of a computational graph

        Returns:
            Tuple[List[Dict], Dict]: a pair of goal_inside and loss:
                - goal_inside: tensors assigned for each goal
                - loss: loss derived from the explanation graph: key = loss name and value = loss
        """
        graph = self.graph
        tensor_provider = self.tensor_provider
        cycle_embedding_generator = self.cycle_embedding_generator
        goal_template = self.goal_template
        cycle_node = self.cycle_node
        operator_loader = self.operator_loader
        self.loss = {}
        # goal_template
        # converting explanation graph to computational graph
        goal_inside = [None] * len(graph.goals)
        for i in range(len(graph.goals)):
            g = graph.goals[i]
            if verbose:
                print(
                    "=== tensor equation (node_id:%d, %s) ==="
                    % (g.node.sorted_id, g.node.goal.name)
                )
            path_inside = []
            path_template = []
            path_batch_flag = False
            for path in g.paths:
                ## build template and inside for switches in the path
                sw_template, sw_inside, prob_sw_inside, batch_flag = self.forward_path_sw(
                    path, verbose, verbose_embedding, dryrun)
                path_batch_flag = batch_flag or path_batch_flag

                ## building template and inside for nodes in the path
                node_template, node_inside, node_scalar_inside, batch_flag = self.forward_path_node(
                    path, goal_inside, verbose, dryrun)
                path_batch_flag = batch_flag or path_batch_flag

                ## building template and inside for all elements (switches and nodes) in the path
                sw_node_template = sw_template + node_template
                sw_node_inside = sw_inside + node_inside

                ops = {op.name: op for op in path.operators}

                out_inside, out_template = self.forward_path_op(ops,
                    sw_node_template, sw_node_inside, node_scalar_inside, prob_sw_inside,
                    path_batch_flag,
                    verbose, dryrun)

                path_inside.append(out_inside)
                path_template.append(out_template)
            ##
            ### update inside
            path_template_list = self._get_unique_list(path_template)
            if len(path_template_list) == 0:  # non-tensor/non-probabilistic path
                if dryrun:
                    goal_inside[i] = {
                        "template": [],
                        "inside": path_inside,
                        "batch_flag": False,
                    }
                else:
                    goal_inside[i] = {
                        "template": [],
                        "inside": torch.tensor(1),
                        "batch_flag": False,
                    }
            else:
                if len(path_template_list) != 1:
                    print("[WARNING] mismatched indices:", path_template_list)
                if dryrun:
                    goal_inside[i] = {
                        "template": path_template_list[0],
                        "inside": path_inside,
                        "batch_flag": path_batch_flag,
                    }
                else:
                    if len(path_template_list[0]) == 0:  # scalar inside
                        goal_inside[i] = {
                            "template": path_template_list[0],
                            "inside": path_inside[0],
                            "batch_flag": path_batch_flag,
                        }
                    else:
                        temp_inside = torch.sum(torch.stack(path_inside), dim=0)
                        goal_inside[i] = {
                            "template": path_template_list[0],
                            "inside": temp_inside,
                            "batch_flag": path_batch_flag,
                        }
            if dryrun:
                goal_inside[i]["id"] = g.node.sorted_id
                goal_inside[i]["name"] = g.node.goal.name
                goal_inside[i]["args"] = g.node.goal.args

        self.loss.update(tensor_provider.get_loss())
        return goal_inside, self.loss
```
Arguments:
- verbose (bool): if true, this function displays an explanation graph with forward computation
- dryrun (bool): if true, this function outputs information required for calculation as goal_inside instead of a computational graph
Returns:
Tuple[List[Dict], Dict]: a pair of goal_inside and loss:
- goal_inside: tensors assigned for each goal
- loss: loss derived from the explanation graph (key = loss name, value = loss)
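A minimal usage sketch, assuming `model` is an already constructed TorchComputationalExplGraph (the graph, tensor provider, and operator loader come from the surrounding T-PRISM pipeline and are not shown here):

```
import torch

def run_forward(model):
    # Build the computational graph and collect the auxiliary losses gathered during
    # the forward pass (e.g. KL terms from distribution paths, sparsity penalties
    # from the tensor provider).
    goal_inside, losses = model.forward(verbose=False, dryrun=False)
    regularization = sum(losses.values()) if losses else torch.tensor(0.0)
    # goal_inside[i]["inside"] holds the tensor assigned to the i-th goal;
    # a task-specific loss would normally be computed from selected goals here.
    return goal_inside, regularization
```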
```
class TorchTensorOnehot(TorchTensorBase):
    def __init__(self, provider, shape, value):
        self.shape = shape
        self.value = value

    def __call__(self):
        v = torch.eye(self.shape)[self.value]
        return v
```
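A quick illustration (with toy values) of the one-hot construction in `__call__`: row `value` of a `shape`-by-`shape` identity matrix.

```
import torch

shape, value = 5, 2
v = torch.eye(shape)[value]   # same expression as TorchTensorOnehot.__call__
print(v)                      # tensor([0., 0., 1., 0., 0.])
```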
```
class TorchTensor(TorchTensorBase):
    # __init__ is shown in its own section below.

    def reset_parameters(self) -> None:
        if len(self.param.shape) == 2:
            torch.nn.init.kaiming_uniform_(self.param, a=math.sqrt(5))
        else:
            self.param.data.uniform_(-0.1, 0.1)

    def __call__(self):
        if self.constraint_tensor is None:
            return self.param
        else:
            return self.constraint_tensor()
```
```
    def __init__(self, provider: 'TorchSwitchTensorProvider', name: str, shape: List[Union[int64, int]], dtype: dtype = torch.float32, tensor_type=None) -> None:
        self.shape = shape
        self.dtype = dtype
        if name is None:
            self.name = "tensor%04d" % (np.random.randint(0, 10000),)
        else:
            self.name = name
        self.provider = provider
        self.tensor_type = tensor_type
        ###
        self.constraint_tensor = tprism.constraint.get_constraint_tensor(shape, tensor_type, device=None, dtype=None)
        self.param = None
        if self.constraint_tensor is None:
            param = torch.nn.Parameter(torch.Tensor(*shape), requires_grad=True)
            self.param = param
            provider.add_param(self.name, param, tensor_type)
            self.reset_parameters()
        else:
            param = list(self.constraint_tensor.parameters())[0]  # TODO
            provider.add_param(self.name, param, tensor_type)

        ###
```
```
class TorchGather(TorchTensorBase):
    def __init__(self, provider: 'TorchSwitchTensorProvider', var: TorchTensor, idx: PlaceholderData) -> None:
        self.var = var
        self.idx = idx
        self.provider = provider

    def __call__(self):
        if isinstance(self.idx, PlaceholderData):
            idx = self.provider.get_embedding(self.idx)
        else:
            idx = self.idx
        if isinstance(self.var, TorchTensor):
            temp = self.var()
            v = torch.index_select(temp, 0, idx)
        elif isinstance(self.var, TorchTensorBase):
            v = torch.index_select(self.var(), 0, idx)
        elif isinstance(self.var, PlaceholderData):
            v = self.provider.get_embedding(self.var)
            v = v[idx]
        else:
            v = torch.index_select(self.var, 0, idx)
        return v
```
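An illustration (with toy values) of the row gathering performed by `__call__` when the variable is a plain tensor and the index is already an integer tensor:

```
import torch

var = torch.arange(12.0).reshape(4, 3)   # an embedding-like matrix, one row per entity
idx = torch.tensor([2, 0, 2])            # indices normally supplied via a placeholder
rows = torch.index_select(var, 0, idx)   # shape (3, 3): rows 2, 0, 2 of var
print(rows)
```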
```
class TorchSwitchTensorProvider(SwitchTensorProvider):
    # get_loss and get_embedding are shown in their own sections below.

    def __init__(self) -> None:
        self.tensor_onehot_class = TorchTensorOnehot
        self.tensor_class = TorchTensor
        self.tensor_gather_class = TorchGather

        self.integer_dtype = torch.int32
        super().__init__()
```
This class provides information about switches.
Attributes:
- tensor_embedding (Dict[str, Tensor]): embedding tensor
- sw_info (Dict[str, SwitchTensor]): switch information
- ph_graph (PlaceholderGraph): associated placeholder graph
- input_feed_dict (Dict[PlaceholderData, Tensor]): feed_dict to replace a placeholder with a tensor
- params (Dict[str,Tuple[Parameter,str]]): pytorch parameters associated with all switches provided by this provider
```
    def get_loss(self, verbose: bool = False):
        loss = {}
        for name, (param, tensor_type) in self.params.items():
            m = re.match(r"^sparse\(([0-9\.]*)\)$", tensor_type)
            if m:
                coeff = float(m.group(1))
                l = coeff * torch.norm(param, 1)
                loss["sparse_" + name] = l
            elif tensor_type == "sparse":
                l = torch.norm(param, 1)
                loss["sparse_" + name] = l
        return loss
```
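A small sketch (with an illustrative parameter and tensor type, not a real switch) of the sparsity penalty this method adds for a parameter registered with tensor_type `"sparse(0.01)"`:

```
import re
import torch

tensor_type = "sparse(0.01)"
param = torch.nn.Parameter(torch.randn(4, 4))

m = re.match(r"^sparse\(([0-9\.]*)\)$", tensor_type)
if m:
    coeff = float(m.group(1))
    penalty = coeff * torch.norm(param, 1)   # weighted L1 norm added to the loss dict
    print(float(penalty))
```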
```
    def get_embedding(self, name: Union[str, PlaceholderData], verbose: bool = False):
        if verbose:
            print("[INFO] get embedding:", name)
        out = None
        ## TODO:
        if self.input_feed_dict is None:
            if verbose:
                print("[INFO] from tensor_embedding", name)
            obj = self.tensor_embedding[name]
            if isinstance(obj, TorchTensorBase):
                out = obj()
            else:
                raise Exception("Unknown embedding type", name, type(obj))
        elif type(name) is str:
            key = self.tensor_embedding[name]
            if type(key) is PlaceholderData:
                if verbose:
                    print("[INFO] from PlaceholderData", name, "==>", key.name)
                out = self.input_feed_dict[key]
            elif isinstance(key, TorchTensorBase):
                if verbose:
                    print("[INFO] from Tensor", name, "==>")
                out = key()
            else:
                raise Exception("Unknown embedding type", name, key)
        elif type(name) is PlaceholderData:
            if verbose:
                print("[INFO] from PlaceholderData", name)
            out = self.input_feed_dict[name]
        else:
            raise Exception("Unknown embedding", name)
        if verbose:
            print(out)
            print(type(out))
            print("sum:", out.sum())
        return out
```