1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helpers to convert variables to constants in TensorFlow 2.0.""" 16 17import collections 18import numpy as np 19 20from tensorflow.core.framework import attr_value_pb2 21from tensorflow.core.framework import graph_pb2 22from tensorflow.core.framework import tensor_shape_pb2 23from tensorflow.core.framework import variable_pb2 24from tensorflow.core.protobuf import config_pb2 25from tensorflow.core.protobuf import meta_graph_pb2 26from tensorflow.core.protobuf import rewriter_config_pb2 27from tensorflow.python.eager import context 28from tensorflow.python.framework import dtypes 29from tensorflow.python.framework import errors 30from tensorflow.python.framework import graph_util 31from tensorflow.python.framework import ops 32from tensorflow.python.framework import tensor_util 33from tensorflow.python.grappler import tf_optimizer 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import variables 36from tensorflow.python.platform import tf_logging as logging 37from tensorflow.python.training.saver import export_meta_graph 38from tensorflow.python.util import lazy_loader 39from tensorflow.python.util import object_identity 40 41# Lazy load the single eager module to avoid introducing new dependencies for 42# graph_util:convert_variables_to_constants (eg in 43# tensorflow/contrib/session_bundle:session_bundle_py_test). 44wrap_function = lazy_loader.LazyLoader( 45 "wrap_function", globals(), 46 "tensorflow.python.eager.wrap_function") 47 48# Used in _FunctionConverterDataInGraph(). 49VAR_ASSIGN_COLLECTION = "extra_var_assign_ops" 50_CONDITIONAL_OPS = set(["If", "StatelessIf"]) 51_LOOP_OPS = set(["While", "StatelessWhile"]) 52_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS) 53 54 55class _TensorData( 56 collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])): 57 """Data about a tensor that was converted to a constant.""" 58 __slots__ = () 59 60 @property 61 def dtype_attr(self): 62 return attr_value_pb2.AttrValue(type=self.dtype) 63 64 65class _EndPoint(collections.namedtuple("_EndPoint", ["convertible", "index"])): 66 """An endpoint in a graph.""" 67 __slots__ = () 68 69 def __str__(self): 70 return "{}[{}]".format(self.convertible, self.index) 71 72 73class _Edge(collections.namedtuple("_Edge", ["source", "destination"])): 74 """A directed graph edge.""" 75 __slots__ = () 76 77 def __str__(self): 78 return "{} -> {}".format(self.source, self.destination) 79 80 81class _Convertible(object): 82 """An entity that can have variables converted to constants.""" 83 84 def __init__(self, enclosing_graph): 85 self._enclosing_graph = enclosing_graph 86 self._outgoing_edges = [] 87 self._converted_self = None 88 89 def converted_self(self): 90 """A copy of this Convertible to be modified during conversion. 91 92 Returns: 93 Implementations should return the copied instance, which in turn should 94 be contained in converted_enclosing_graph(). This instance is the one that 95 will be modified during conversion. Its main use will be in the 96 implementations of convert_variable_to_constant(). 97 """ 98 raise NotImplementedError 99 100 def convert_variable_to_constant(self, incoming_edge, tensor_data): 101 """Converts a variable in this Convertible and its dependencies. 102 103 This method should make sure that a converted copy of itself is present in 104 the converted graph, and that all Convertibles depending on this one also go 105 through the same process. 106 107 Args: 108 incoming_edge: The graph edge into this Convertible that is being 109 converted to a constant. 110 tensor_data: The tensor representing the constant. 111 """ 112 raise NotImplementedError 113 114 def create_edges(self): 115 """Calls add_outgoing_edge for all edges known to this Convertible. 116 117 This is used to build the graph dependencies, so that conversion of 118 variables to constants can be properly propagated through the graph. Usually 119 this method will call add_outgoing_edge() to all the Convertible inputs. 120 """ 121 raise NotImplementedError 122 123 def add_outgoing_edge(self, edge): 124 """Adds an outgoing edge to the Convertible's list of edges. 125 126 Args: 127 edge: The outgoing edge (its source should be 'self'). 128 """ 129 self._outgoing_edges.append(edge) 130 131 @property 132 def converted_enclosing_graph(self): 133 """The graph being converted.""" 134 return self._enclosing_graph.converted_self() 135 136 @property 137 def outgoing_edges(self): 138 """The list of edges starting at this Convertible.""" 139 return self._outgoing_edges 140 141 142class _Function(_Convertible): 143 """A library function Convertible. 144 145 Edges into functions are edges from node _inputs_ into function _inputs_: 146 Functions get their input from their callers, not from node outputs, and the 147 callers in turn get those values as inputs. 148 """ 149 150 def __init__(self, function, enclosing_graph): 151 super(_Function, self).__init__(enclosing_graph) 152 self._function = function 153 self._nodes = { 154 n.name: 155 _Node.new(node=n, function=self, enclosing_graph=enclosing_graph) 156 for n in function.node_def 157 } 158 159 def __str__(self): 160 return self.function.signature.name 161 162 @property 163 def function(self): 164 return self._function 165 166 @property 167 def nodes(self): 168 return self._nodes 169 170 def converted_self(self): 171 """The Function copy to be converted. 172 173 The copy will be renamed according to the graph's converted_function_name 174 map, to ensure the name does not match anything currently in TensorFlow's 175 function cache. 176 177 Returns: 178 The function instance to be converted. 179 """ 180 if self._converted_self is None: 181 old_name = self.function.signature.name 182 new_name = self._enclosing_graph.converted_function_names[old_name] 183 self.converted_enclosing_graph.rename_function(old_name, new_name) 184 self._converted_self = self.converted_enclosing_graph.functions[new_name] 185 return self._converted_self 186 187 def convert_variable_to_constant(self, incoming_edge, tensor_data): 188 """Converts one function argument into a constant. 189 190 Args: 191 incoming_edge: The edge into the argument to be converted. 192 tensor_data: The constant value. 193 """ 194 function = self.converted_self().function 195 index = incoming_edge.destination.index 196 function.signature.input_arg[index].type = tensor_data.dtype 197 198 # TODO(b/176982859): Find a more satisfying way to update shape information 199 # than clearing it, or migrate users to a workflow that does not require 200 # freezing. 201 if "_input_shapes" in function.attr: 202 function.attr["_input_shapes"].list.shape[index].unknown_rank = True 203 del function.attr["_input_shapes"].list.shape[index].dim[:] 204 arg_attrs = function.arg_attr[index].attr 205 if "_output_shapes" in arg_attrs: 206 arg_attrs["_output_shapes"].list.shape[0].unknown_rank = True 207 del arg_attrs["_output_shapes"].list.shape[0].dim[:] 208 209 for edge in self.outgoing_edges: 210 if edge.source.index == index: 211 edge.destination.convertible.convert_variable_to_constant( 212 edge, tensor_data) 213 214 def create_edges(self): 215 for n in self._nodes.values(): 216 n.create_edges() 217 218 219class _Node(_Convertible): 220 """A Convertible NodeDef.""" 221 222 def __init__(self, node, function, enclosing_graph): 223 super(_Node, self).__init__(enclosing_graph) 224 self._node = node 225 self._function = function 226 227 def __str__(self): 228 return self._node.name 229 230 @staticmethod 231 def new(node, function, enclosing_graph): 232 """Creates a new _Node base on its operation type.""" 233 if node.op in ["VariableV2", "VarHandleOp", "Placeholder"]: 234 return _VarHandle(node, function, enclosing_graph) 235 elif node.op == "Case": 236 return _Case(node, function, enclosing_graph) 237 elif node.op == "Merge": 238 return _Merge(node, function, enclosing_graph) 239 elif node.op == "PartitionedCall": 240 return _PartitionedCall(node, function, enclosing_graph) 241 elif node.op == "StatefulPartitionedCall": 242 return _PartitionedCall(node, function, enclosing_graph) 243 elif node.op == "ReadVariableOp": 244 return _ReadVariable(node, function, enclosing_graph) 245 elif node.op == "ResourceGather": 246 return _ResourceGather(node, function, enclosing_graph) 247 elif node.op == "ResourceGatherNd": 248 return _ResourceGatherNd(node, function, enclosing_graph) 249 elif node.op in ["If", "StatelessIf"]: 250 return _If(node, function, enclosing_graph) 251 elif node.op in ["While", "StatelessWhile"]: 252 return _While(node, function, enclosing_graph) 253 elif node.op in [ 254 "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]: 255 return _Intermediate(node, function, enclosing_graph) 256 else: 257 return _Node(node, function, enclosing_graph) 258 259 @property 260 def node(self): 261 return self._node 262 263 @property 264 def container(self): 265 """The node container (either a graph or a function).""" 266 if self._function is not None: 267 return self._function.function 268 return self._enclosing_graph.graph_def 269 270 def converted_self(self): 271 """The NodeDef to be converted. 272 273 Returns: 274 The NodeDef to be converted, which can come from either a graph for a 275 function. Derived classes should call this (via 'super') to make sure the 276 node is retrieved from the right place. 277 """ 278 if self._converted_self is None: 279 source = self._function or self._enclosing_graph 280 self._converted_self = source.converted_self().nodes[self._node.name] 281 return self._converted_self 282 283 def convert_variable_to_constant(self, incoming_edge, tensor_data): 284 pass 285 286 def create_edges(self): 287 for index, name in enumerate(self._node.input): 288 # Discard edges from control inputs. 289 if name[0] == "^": 290 continue 291 source = self.resolve_input(name) 292 source.convertible.add_outgoing_edge( 293 _Edge(source, _EndPoint(self, index))) 294 295 def resolve_input(self, input_name): 296 """Resolves an input into its _EndPoint. 297 298 A NodeDef's input name can refer to either global NodeDefs (in the 299 GraphDef's node list), a NodeDef in a function's node list, or a Function 300 (in the GraphDef's function library). The name can also carry semantic 301 information, depending on whether it starts with "^". This method handles 302 all that logic in order to find the object to which the input name refers 303 to. 304 305 Args: 306 input_name: The input name to resolve. 307 308 Returns: 309 The object referred to by 'input_name'. 310 """ 311 312 # The logic below oversimplifies the semantics, but is good enough for the 313 # purposes of converting to constants. The introduction of new types of 314 # operations may change this, forcing the code to be more generic. 315 # 316 # In particular, we are assuming that the lack of an index suffix means 317 # ":0", when it could mean "all the outputs of a node." This works now 318 # because converting to constants relies very little on output types, and 319 # when it does it specializes its treatment in dedicated classes. 320 name_elts = input_name.split(":") 321 source_name = name_elts[0] 322 if source_name[0] == "^": 323 source_name = source_name[1:] 324 source_index = 0 325 if len(name_elts) > 1 and name_elts[-1].isnumeric(): 326 source_index = int(name_elts[-1]) 327 328 if self._function is None: 329 return _EndPoint(self._enclosing_graph.nodes[source_name], source_index) 330 331 if source_index != 0 or source_name in self._function.nodes: 332 return _EndPoint(self._function.nodes[source_name], source_index) 333 334 inputs = [i.name for i in self._function.function.signature.input_arg] 335 return _EndPoint(self._function, inputs.index(source_name)) 336 337 def update_dtype(self, attr_name, index, dtype): 338 """Changes the type of a given input. 339 340 Args: 341 attr_name: The NodeDef attribute containing the type to change. 342 index: The index of the input type to change. 343 dtype: The type to change to. 344 """ 345 attr = self._node.attr[attr_name] 346 num_types = 0 347 # Check for various 'oneof' possibilities, and update the type if 348 # index in range. 349 if attr.HasField("list"): 350 types = attr.list.type 351 num_types = len(types) 352 if num_types > index: 353 types[index] = dtype 354 return 355 elif attr.HasField("type"): 356 num_types = 1 357 if index == 0: 358 attr.type = dtype 359 return 360 raise ValueError(f"`index` {index:d} is out of range for " 361 f"node({self._node.name}).attr({attr_name}), which has " 362 f"{num_types:d} elements.") 363 364 365class _Intermediate(_Node): 366 """Specialization of _Node to intermediate ops.""" 367 368 def convert_variable_to_constant(self, incoming_edge, tensor_data): 369 node = self.converted_self() 370 node.update_dtype("T", incoming_edge.destination.index, tensor_data.dtype) 371 if "_output_shapes" in node.node.attr: 372 del node.node.attr["_output_shapes"] 373 for edge in self.outgoing_edges: 374 edge.destination.convertible.convert_variable_to_constant( 375 edge, tensor_data) 376 377 378class _Merge(_Node): 379 """Specialization of _Node to Merge ops.""" 380 381 def convert_variable_to_constant(self, incoming_edge, tensor_data): 382 # The Merge operation has a single type for all its inputs, the number of 383 # which is reflected in the "N" attribute. For the time being, we assume 384 # that unilaterally changing all of them at once is ok. 385 super(_Merge, self).convert_variable_to_constant( 386 _Edge(incoming_edge.source, 387 _Edge(incoming_edge.destination.convertible, 0)), tensor_data) 388 389 390class _VarHandle(_Node): 391 """Specialization of _Node to VarHandleOp.""" 392 393 def convert_variable_to_constant(self, incoming_edge, tensor_data): 394 tensor_proto = tensor_util.make_tensor_proto(tensor_data.numpy, 395 tensor_data.dtype, 396 tensor_data.numpy.shape) 397 398 node = self.converted_self().node 399 node.Clear() 400 node.name = self._node.name 401 node.op = "Const" 402 node.attr["dtype"].CopyFrom(tensor_data.dtype_attr) 403 node.attr["value"].tensor.CopyFrom(tensor_proto) 404 405 for edge in self.outgoing_edges: 406 edge.destination.convertible.convert_variable_to_constant( 407 edge, tensor_data) 408 409 410class _ResourceGather(_Node): 411 """Specialization of _Node to ResourceGather.""" 412 413 def convert_variable_to_constant(self, incoming_edge, tensor_data): 414 # We currently skip the conversion if this is inside a function. 415 if self._function is not None: 416 return 417 if self._node.attr["batch_dims"].i != 0: 418 raise ValueError("batch_dims must be 0 for freeze_graph, but got " 419 f"node({self._node.name}).attr('batch_dims') = " 420 f"{self._node.attr['batch_dims'].i}.") 421 axis_node_name = self._node.name + "/axis" 422 axis_dtype = self._node.attr["Tindices"] 423 axis_data = np.array(self._node.attr["batch_dims"].i) 424 output_axis_node = self.converted_self().container.node.add() 425 output_axis_node.name = axis_node_name 426 output_axis_node.op = "Const" 427 output_axis_node.attr["dtype"].CopyFrom(axis_dtype) 428 tensor = tensor_util.make_tensor_proto( 429 axis_data, dtype=axis_dtype.type, shape=axis_data.shape) 430 output_axis_node.attr["value"].tensor.CopyFrom(tensor) 431 432 output_node = self.converted_self().node 433 output_node.Clear() 434 output_node.name = self._node.name 435 output_node.op = "GatherV2" 436 output_node.input.extend( 437 [self._node.input[0], self._node.input[1], axis_node_name]) 438 output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"]) 439 output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"]) 440 output_node.attr["Taxis"].CopyFrom(axis_dtype) 441 if "_class" in self._node.attr: 442 output_node.attr["_class"].CopyFrom(self._node.attr["_class"]) 443 444 445class _ResourceGatherNd(_Node): 446 """Specialization of _Node to ResourceGatherNd.""" 447 448 def convert_variable_to_constant(self, incoming_edge, tensor_data): 449 output_node = self.converted_self().node 450 output_node.Clear() 451 output_node.name = self._node.name 452 output_node.op = "GatherNd" 453 output_node.input.extend([self._node.input[0], self._node.input[1]]) 454 output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"]) 455 output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"]) 456 if "_class" in self._node.attr: 457 output_node.attr["_class"].CopyFrom(self._node.attr["_class"]) 458 459 460class _ReadVariable(_Node): 461 """Specialization of _Node to ReadVariableOp.""" 462 463 def convert_variable_to_constant(self, incoming_edge, tensor_data): 464 node = self.converted_self().node 465 node.Clear() 466 node.name = self._node.name 467 node.op = "Identity" 468 469 node.input.append(self._node.input[0]) 470 node.attr["T"].CopyFrom(self._node.attr["dtype"]) 471 if "_class" in self._node.attr: 472 node.attr["_class"].CopyFrom(self._node.attr["_class"]) 473 474 # If the ReadVariableOp is part of a function, then every node having the 475 # ReadVariableOp one as its input will refer to it using a ":value" 476 # syntax. We need to change that to ":output". 477 if self._function is not None: 478 for edge in self.outgoing_edges: 479 index = edge.destination.index 480 dest = edge.destination.convertible.converted_self() 481 if isinstance(dest, _Node): 482 input_name_parts = dest.node.input[index].split(":") 483 if len(input_name_parts) > 1 and input_name_parts[1] == "value": 484 input_name_parts[1] = "output" 485 dest.node.input[index] = ":".join(input_name_parts) 486 487 488class _FunctionCaller(_Node): 489 """A base class for Convertibles that reference functions.""" 490 491 def __init__(self, node, function, enclosing_graph, first_function_input, 492 type_attribute, function_attributes): 493 """Initializes a _FunctionCaller. 494 495 Args: 496 node: As in _Node. 497 function: As in _Node. 498 enclosing_graph: As in _Node. 499 first_function_input: The index of the first NodeDef input that is tied to 500 the function inputs. It is assumed that the rest of the NodeDef inputs 501 map one to one to function inputs. 502 type_attribute: The name of the NodeDef attribute that defines the input 503 types. It is assumed that the types listed here map one-to-one with the 504 function inputs (that is, they do _not_ specify types for inputs that 505 are not passed to functions). 506 function_attributes: The names of the NodeDef attributes containing 507 references to functions. 508 """ 509 super(_FunctionCaller, self).__init__(node, function, enclosing_graph) 510 self._first_function_input = first_function_input 511 self._type_attribute = type_attribute 512 self._function_attributes = function_attributes 513 514 def converted_self(self): 515 if self._converted_self is None: 516 node = super(_FunctionCaller, self).converted_self().node 517 converted_names = self._enclosing_graph.converted_function_names 518 for attr_name in self._function_attributes: 519 attr = node.attr[attr_name] 520 if attr.HasField("func"): 521 attr.func.name = converted_names[attr.func.name] 522 elif attr.HasField("list"): 523 for func in attr.list.func: 524 func.name = converted_names[func.name] 525 return self._converted_self 526 527 def convert_variable_to_constant(self, incoming_edge, tensor_data): 528 node = self.converted_self() 529 index = incoming_edge.destination.index 530 if index >= self._first_function_input: 531 node.update_dtype(self._type_attribute, 532 index - self._first_function_input, tensor_data.dtype) 533 534 # The loop below is reasonable but not correct in general: 535 # The outgoing edges going into the functions are correct, because the 536 # inputs map to the function inputs. But the edges going into other nodes do 537 # not take into account the logic of the body function, which may do 538 # arbitrary things to the node's output: 539 # 540 # while x < 0: 541 # return y 542 # 543 # In this case, the node's ":0" output may map to its ":1 input". For the 544 # time being, then, we only process edges into functions. 545 for edge in self.outgoing_edges: 546 dest = edge.destination.convertible 547 if edge.source.index == index and isinstance(dest, _Function): 548 dest.convert_variable_to_constant(edge, tensor_data) 549 550 def create_edges(self): 551 """Creates edges related to a function caller. 552 553 Edges from a function caller to its called functions are always edges from 554 _inputs_ to _inputs_: a FunctionDef input is given by the caller, based on 555 its own inputs. 556 """ 557 super(_FunctionCaller, self).create_edges() 558 for attr_name in self._function_attributes: 559 attr = self._node.attr[attr_name] 560 if attr.HasField("func"): 561 function = self._enclosing_graph.functions[attr.func.name] 562 for index in range(len(self._node.input) - self._first_function_input): 563 self.add_outgoing_edge( 564 _Edge( 565 _EndPoint(self, index + self._first_function_input), 566 _EndPoint(function, index))) 567 elif attr.HasField("list"): 568 for func in attr.list.func: 569 function = self._enclosing_graph.functions[func.name] 570 for index in range( 571 len(self._node.input) - self._first_function_input): 572 self.add_outgoing_edge( 573 _Edge( 574 _EndPoint(self, index + self._first_function_input), 575 _EndPoint(function, index))) 576 577 578class _If(_FunctionCaller): 579 """Specialization of _Node to If-like operations.""" 580 581 def __init__(self, node, function, enclosing_graph): 582 super(_If, self).__init__( 583 node, 584 function, 585 enclosing_graph, 586 first_function_input=1, 587 type_attribute="Tin", 588 function_attributes=["then_branch", "else_branch"]) 589 590 591class _Case(_FunctionCaller): 592 """Specialization of _Node to Case-like operations.""" 593 594 def __init__(self, node, function, enclosing_graph): 595 super(_Case, self).__init__( 596 node, 597 function, 598 enclosing_graph, 599 first_function_input=1, 600 type_attribute="Tin", 601 function_attributes=["branches"]) 602 603 604class _PartitionedCall(_FunctionCaller): 605 """Specialization of _Node to PartitionedCall-like operations.""" 606 607 def __init__(self, node, function, enclosing_graph): 608 super(_PartitionedCall, self).__init__( 609 node, 610 function, 611 enclosing_graph, 612 first_function_input=0, 613 type_attribute="Tin", 614 function_attributes=["f"]) 615 616 617class _While(_FunctionCaller): 618 """Specialization of _Node to While-like operations.""" 619 620 def __init__(self, node, function, enclosing_graph): 621 super(_While, self).__init__( 622 node, 623 function, 624 enclosing_graph, 625 first_function_input=0, 626 type_attribute="T", 627 function_attributes=["body", "cond"]) 628 629 def convert_variable_to_constant(self, incoming_edge, tensor_data): 630 super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data) 631 node = self.converted_self() 632 if node.node.attr["output_shapes"].list.shape: 633 node.node.attr["output_shapes"].list.shape[ 634 incoming_edge.destination.index].CopyFrom( 635 tensor_shape_pb2.TensorShapeProto(dim=[ 636 tensor_shape_pb2.TensorShapeProto.Dim(size=dim) 637 for dim in tensor_data.numpy.shape 638 ])) 639 640 # The while's body inputs and outputs have the same type, so here we can go 641 # ahead and change that function's output type. 642 body_name = self._node.attr["body"].func.name 643 body = self._enclosing_graph.functions[body_name].converted_self().function 644 body.signature.output_arg[ 645 incoming_edge.destination.index].type = tensor_data.dtype 646 647 648class _GraphDef(_Convertible): 649 """A convertible GraphDef.""" 650 651 def __init__(self, graph_def): 652 super(_GraphDef, self).__init__(enclosing_graph=None) 653 self._graph_def = graph_def 654 self._nodes = { 655 n.name: _Node.new(node=n, function=None, enclosing_graph=self) 656 for n in graph_def.node 657 } 658 self._functions = { 659 f.signature.name: _Function(f, enclosing_graph=self) 660 for f in graph_def.library.function 661 } 662 self.create_edges() 663 self._converted_function_names = None 664 665 @property 666 def graph_def(self): 667 return self._graph_def 668 669 @property 670 def nodes(self): 671 return self._nodes 672 673 @property 674 def functions(self): 675 return self._functions 676 677 @property 678 def converted_function_names(self): 679 """Map from original to new function names. 680 681 In order to avoid conflicts (two functions with the same name, one converted 682 and one not), we need to change the name of every converted function to 683 something that is hopefully unique. 684 685 Returns: 686 Map from original to new suggested function names. 687 """ 688 if self._converted_function_names is None: 689 parsed_names = [] # List of (id, base_name, original_name) 690 for name in self.functions: 691 elements = name.rsplit("_", 1) 692 if len(elements) == 2 and elements[1].isnumeric(): 693 parsed_names.append((int(elements[1]), elements[0], name)) 694 else: 695 parsed_names.append((-1, name, name)) 696 self._converted_function_names = { 697 name: "{}_frozen_{}".format(base_name, ops.uid()) 698 for (_, base_name, name) in sorted(parsed_names) 699 } 700 701 return self._converted_function_names 702 703 def rename_function(self, old_name, new_name): 704 func = self.functions.pop(old_name) 705 func.function.signature.name = new_name 706 self.functions[new_name] = func 707 708 def converted_self(self): 709 if self._converted_self is None: 710 copied_graph = graph_pb2.GraphDef() 711 copied_graph.CopyFrom(self._graph_def) 712 self._converted_self = _GraphDef(copied_graph) 713 return self._converted_self 714 715 def create_edges(self): 716 for n in self._nodes.values(): 717 n.create_edges() 718 for f in self._functions.values(): 719 f.create_edges() 720 721 722class _ConverterData(object): 723 """Container for constant conversion supporting data. 724 725 The data includes the graph being converted, and the pre-converted 726 tensors. This class will be specialized for ConcreteFunction and Session-based 727 conversions, as the means to obtain that data is different for each case. 728 """ 729 730 def __init__(self, 731 graph_def, 732 variable_names_allowlist=None, 733 variable_names_denylist=None): 734 self._graph_def = graph_def 735 self._tensor_data = {} 736 self._build_node_defs_list() 737 self._variable_names_allowlist = variable_names_allowlist 738 self._variable_names_denylist = variable_names_denylist 739 740 @property 741 def graph_def(self): 742 """The graph to be converted.""" 743 return self._graph_def 744 745 @property 746 def node_defs(self): 747 """All the node defs in the graph to be converted. 748 749 Returns: 750 A map from node name to the NodeDef for all NodeDefs in the graph, as well 751 as all control flow NodeDefs in the functions. 752 """ 753 return self._node_defs 754 755 @property 756 def tensor_data(self): 757 """A map from tensor name to its converted _TensorData.""" 758 return self._tensor_data 759 760 def _should_convert(self, name): 761 """Checks whether to convert the given variable name to a constant.""" 762 return (self._variable_names_allowlist is None or 763 name in self._variable_names_allowlist) and ( 764 self._variable_names_denylist is None or 765 name not in self._variable_names_denylist) 766 767 def _build_node_defs_list(self): 768 """Builds the list of NodeDefs in the GraphDef. 769 770 This list consists of all NodeDefs in the main graph as well as all control 771 flow NodeDefs in the functions. 772 773 The remaining NodeDefs in the functions are not included because the op 774 names 775 are not unique and the variables are handled differently than the main 776 graph. 777 The control flow ops need to be extracted because they are need their 778 attributes to be updated similar to the control flow ops in the main graph. 779 """ 780 self._node_defs = {node.name: node for node in self._graph_def.node} 781 782 if self._graph_def.library: 783 for func in self._graph_def.library.function: 784 self._node_defs.update({ 785 node.name: node 786 for node in func.node_def 787 if node.op in _CONTROL_FLOW_OPS 788 }) 789 790 791class _FunctionConverterData(_ConverterData): 792 """Container for ConcreteFunction-based conversion data.""" 793 794 def __init__(self, 795 func, 796 lower_control_flow, 797 aggressive_inlining, 798 variable_names_allowlist=None, 799 variable_names_denylist=None): 800 """Creates the conversion data for the given function. 801 802 Args: 803 func: ConcreteFunction. 804 lower_control_flow: Boolean indicating whether or not to lower control 805 flow ops such as If and While. 806 aggressive_inlining: Boolean indicating whether or not to do aggressive 807 function inlining (might be unsafe if function has stateful ops, not 808 properly connected to control outputs). 809 variable_names_allowlist: The set of variable names to convert (by 810 default, all variables are converted). 811 variable_names_denylist: The set of variable names to omit converting to 812 constants. 813 """ 814 815 self._func = func 816 # Inline the graph in order to remove functions when possible. 817 graph_def = _run_inline_graph_optimization(func, lower_control_flow, 818 aggressive_inlining) 819 super(_FunctionConverterData, self).__init__( 820 graph_def, 821 variable_names_allowlist=variable_names_allowlist, 822 variable_names_denylist=variable_names_denylist) 823 824 self._build_tensor_data() 825 826 def _eval(self, tensor): 827 """Returns the value in the tensor. Must be implemented in sub-classes.""" 828 raise errors.UnimplementedError( 829 "The evaluation method should be implemented in sub-classes.") 830 831 def _build_tensor_data(self): 832 """Caches the tensor data for all Placeholders in the given function.""" 833 map_index_to_variable = {} 834 for var in self._func.graph.variables: 835 for idx, captured_input in enumerate(self._func.captured_inputs): 836 if var.handle is captured_input: # pylint: disable=protected-access 837 map_index_to_variable[idx] = var 838 break 839 840 # Iterates through all captures which are represented as Placeholders. 841 for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures): 842 tensor_name = name_tensor.name.split(":")[0] 843 if not self._should_convert(tensor_name): 844 continue 845 if idx in map_index_to_variable: 846 data = self._eval(map_index_to_variable[idx]) 847 else: 848 if val_tensor.dtype == dtypes.resource: 849 logging.vlog(1, "Skip converting resource tensor %s" % tensor_name) 850 continue 851 data = np.array(self._eval(val_tensor)) 852 853 self._tensor_data[tensor_name] = _TensorData( 854 numpy=data, 855 dtype=dtypes.as_dtype(data.dtype).as_datatype_enum, 856 index=idx) 857 858 # Get data for VariableV2 ops (reference variables) that cannot be lifted. 859 for node in self.node_defs.values(): 860 if node.op == "VariableV2": 861 if not self._should_convert(node.name): 862 continue 863 if node.name not in self.tensor_data: 864 with self._func.graph.as_default(): 865 identity_node = array_ops.identity( 866 self._func.graph.as_graph_element(node.name + ":0")) 867 pruned_graph = self._func.prune([], [identity_node.name])()[0] 868 self._tensor_data[node.name] = _TensorData( 869 numpy=pruned_graph.numpy(), 870 dtype=node.attr["dtype"].type, 871 index=None) 872 873 874class _FunctionConverterDataInEager(_FunctionConverterData): 875 """Container for ConcreteFunction-based conversion data in Eager mode.""" 876 877 def _eval(self, tensor): 878 """Returns the value in the tensor. Must be implemented in sub-classes.""" 879 return tensor.numpy() 880 881 882class _FunctionConverterDataInGraph(_FunctionConverterData): 883 """Container for ConcreteFunction-based conversion data in Graph mode.""" 884 885 def __init__(self, 886 func, 887 lower_control_flow, 888 aggressive_inlining, 889 variable_names_allowlist=None, 890 variable_names_denylist=None, 891 session=None): 892 """Creates the conversion data for the given function. 893 894 Args: 895 func: ConcreteFunction. 896 lower_control_flow: Boolean indicating whether or not to lower control 897 flow ops such as If and While. 898 aggressive_inlining: Boolean indicating whether or not to do aggressive 899 function inlining (might be unsafe if function has stateful ops, not 900 properly connected to control outputs). 901 variable_names_allowlist: The set of variable names to convert (by 902 default, all variables are converted). 903 variable_names_denylist: The set of variable names to omit converting to 904 constants. 905 session: Session object. 906 """ 907 self._session = session 908 909 session.run(variables.global_variables_initializer()) 910 # Run extra assignment ops if needed. 911 # These assignments are run sequentially to ensure order. 912 for op in ops.get_default_graph().get_collection(VAR_ASSIGN_COLLECTION): 913 session.run(op) 914 915 super(_FunctionConverterDataInGraph, self).__init__( 916 func, 917 lower_control_flow, 918 aggressive_inlining, 919 variable_names_allowlist, 920 variable_names_denylist) 921 922 def _eval(self, tensor): 923 """Returns the value in the tensor. Must be implemented in sub-classes.""" 924 return self._session.run(tensor) 925 926 927class _SessionConverterData(_ConverterData): 928 """Container for Session-based conversion data.""" 929 930 def __init__(self, 931 session, 932 graph_def, 933 output_node_names, 934 variable_names_allowlist=None, 935 variable_names_denylist=None): 936 graph_def = graph_util.extract_sub_graph(graph_def, output_node_names) 937 super(_SessionConverterData, self).__init__( 938 graph_def, 939 variable_names_allowlist=variable_names_allowlist, 940 variable_names_denylist=variable_names_denylist) 941 942 nodes_to_convert = [] 943 tensor_names_to_convert = [] 944 for node in self.graph_def.node: 945 if node.op in ["Variable", "VariableV2", "VarHandleOp"]: 946 tensor_name = node.name 947 if not self._should_convert(tensor_name): 948 continue 949 if node.op == "VarHandleOp": 950 tensor_name = tensor_name + "/Read/ReadVariableOp" 951 nodes_to_convert.append(node) 952 tensor_names_to_convert.append(tensor_name + ":0") 953 954 if tensor_names_to_convert: 955 converted_tensors = session.run(tensor_names_to_convert) 956 for node, tensor_value in zip(nodes_to_convert, converted_tensors): 957 self._tensor_data[node.name] = _TensorData( 958 numpy=tensor_value, dtype=node.attr["dtype"].type, index=None) 959 960 961def disable_lower_using_switch_merge(graph_def): 962 """Set '_lower_using_switch_merge' attributes to False. 963 964 Sets the attribute to False in the NodeDefs in the main graph and the NodeDefs 965 in each function's graph. 966 967 Args: 968 graph_def: GraphDef proto. 969 970 Returns: 971 GraphDef 972 """ 973 output_graph_def = graph_pb2.GraphDef() 974 output_graph_def.CopyFrom(graph_def) 975 976 def disable_control_flow_lowering(node): 977 if node.op in _CONTROL_FLOW_OPS: 978 node.attr["_lower_using_switch_merge"].b = False 979 980 for node in output_graph_def.node: 981 disable_control_flow_lowering(node) 982 983 if output_graph_def.library: 984 for func in output_graph_def.library.function: 985 for node in func.node_def: 986 disable_control_flow_lowering(node) 987 return output_graph_def 988 989 990def _run_inline_graph_optimization(func, lower_control_flow, 991 aggressive_inlining): 992 """Apply function inline optimization to the graph. 993 994 Returns the GraphDef after Grappler's function inlining optimization is 995 applied. This optimization does not work on models with control flow. 996 997 Args: 998 func: ConcreteFunction. 999 lower_control_flow: Boolean indicating whether or not to lower control flow 1000 ops such as If and While. (default True) 1001 aggressive_inlining: Boolean indicating whether or not to do aggressive 1002 function inlining (might be unsafe if function has stateful ops not 1003 properly connected to control outputs). 1004 1005 Returns: 1006 GraphDef 1007 """ 1008 graph_def = func.graph.as_graph_def() 1009 if not lower_control_flow: 1010 graph_def = disable_lower_using_switch_merge(graph_def) 1011 1012 # In some cases, a secondary implementation of the function (e.g. for GPU) is 1013 # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in 1014 # TF2 produces a CuDNN-based RNN for GPU). 1015 # This function suppose to inline all functions calls, but "api_implements" 1016 # prevents this from happening. Removing the attribute solves the problem. 1017 # To learn more about "api_implements", see: 1018 # tensorflow/core/grappler/optimizers/implementation_selector.h 1019 for function in graph_def.library.function: 1020 if "api_implements" in function.attr: 1021 del function.attr["api_implements"] 1022 1023 meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph) 1024 1025 # Clear the initializer_name for the variables collections, since they are not 1026 # needed after saved to saved_model. 1027 for name in [ 1028 "variables", "model_variables", "trainable_variables", "local_variables" 1029 ]: 1030 raw_list = [] 1031 for raw in meta_graph.collection_def["variables"].bytes_list.value: 1032 variable = variable_pb2.VariableDef() 1033 variable.ParseFromString(raw) 1034 variable.ClearField("initializer_name") 1035 raw_list.append(variable.SerializeToString()) 1036 meta_graph.collection_def[name].bytes_list.value[:] = raw_list 1037 1038 # Add a collection 'train_op' so that Grappler knows the outputs. 1039 fetch_collection = meta_graph_pb2.CollectionDef() 1040 for array in func.inputs + func.outputs: 1041 fetch_collection.node_list.value.append(array.name) 1042 meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) 1043 1044 # Initialize RewriterConfig with everything disabled except function inlining. 1045 config = config_pb2.ConfigProto() 1046 rewrite_options = config.graph_options.rewrite_options 1047 rewrite_options.min_graph_nodes = -1 # do not skip small graphs 1048 rewrite_options.optimizers.append("function") 1049 if aggressive_inlining: 1050 rewrite_options.function_optimization =\ 1051 rewriter_config_pb2.RewriterConfig.AGGRESSIVE 1052 return tf_optimizer.OptimizeGraph(config, meta_graph) 1053 1054 1055def _construct_concrete_function(func, output_graph_def, 1056 converted_input_indices): 1057 """Constructs a concrete function from the `output_graph_def`. 1058 1059 Args: 1060 func: ConcreteFunction 1061 output_graph_def: GraphDef proto. 1062 converted_input_indices: Set of integers of input indices that were 1063 converted to constants. 1064 1065 Returns: 1066 ConcreteFunction. 1067 """ 1068 # Create a ConcreteFunction from the new GraphDef. 1069 input_tensors = func.graph.internal_captures 1070 converted_inputs = object_identity.ObjectIdentitySet( 1071 [input_tensors[index] for index in converted_input_indices]) 1072 not_converted_inputs = [ 1073 tensor for tensor in func.inputs if tensor not in converted_inputs 1074 ] 1075 not_converted_inputs_map = { 1076 tensor.name: tensor for tensor in not_converted_inputs 1077 } 1078 1079 new_input_names = [tensor.name for tensor in not_converted_inputs] 1080 new_output_names = [tensor.name for tensor in func.outputs] 1081 1082 # Remove old functions to use updated functions from graph def. 1083 for f in output_graph_def.library.function: 1084 if context.context().has_function(f.signature.name): 1085 context.context().remove_function(f.signature.name) 1086 1087 new_func = wrap_function.function_from_graph_def(output_graph_def, 1088 new_input_names, 1089 new_output_names) 1090 1091 # Manually propagate shape for input tensors where the shape is not correctly 1092 # propagated. Scalars shapes are lost when wrapping the function. 1093 for input_tensor in new_func.inputs: 1094 input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape) 1095 return new_func 1096 1097 1098def _replace_variables_by_constants(converter_data): 1099 """Replaces variables by constants on a given graph. 1100 1101 Given a _ConverterData instance with converted variables in its tensor_data 1102 field, create a new graph where the respective variables are replaced with the 1103 converted constants. 1104 1105 Args: 1106 converter_data: A pre-populated _ConverterData instance. 1107 1108 Returns: 1109 The converted graph. 1110 """ 1111 input_graph = _GraphDef(converter_data.graph_def) 1112 1113 for tensor_name, tensor_data in converter_data.tensor_data.items(): 1114 input_graph.nodes[tensor_name].convert_variable_to_constant( 1115 None, tensor_data) 1116 1117 converted_graph = input_graph.converted_self().graph_def 1118 1119 converted_input_indices = { 1120 t.index 1121 for t in converter_data.tensor_data.values() 1122 if t.index is not None 1123 } 1124 1125 return converted_graph, converted_input_indices 1126 1127 1128def convert_variables_to_constants_v2(func, 1129 lower_control_flow=True, 1130 aggressive_inlining=False): 1131 """Replaces all the variables in a graph with constants of the same values. 1132 1133 TensorFlow 2.0 function for converting all Variable ops into Const ops holding 1134 the same values. This makes it possible to describe the network fully with a 1135 single GraphDef file, and allows the removal of a lot of ops related to 1136 loading and saving the variables. This function runs Grappler's function 1137 inlining optimization in order to return a single subgraph. 1138 1139 The current implementation only works for graphs that do not contain any 1140 control flow or embedding related ops. 1141 1142 Args: 1143 func: ConcreteFunction. 1144 lower_control_flow: Boolean indicating whether or not to lower control flow 1145 ops such as If and While. (default True) 1146 aggressive_inlining: Boolean indicating whether or not to do aggressive 1147 function inlining (might be unsafe if function has stateful ops, not 1148 properly connected to control outputs). (default False) 1149 1150 Returns: 1151 ConcreteFunction containing a simplified version of the original. 1152 """ 1153 1154 converter_data = _FunctionConverterDataInEager( 1155 func=func, 1156 lower_control_flow=lower_control_flow, 1157 aggressive_inlining=aggressive_inlining) 1158 1159 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1160 converter_data=converter_data) 1161 1162 return _construct_concrete_function(func, output_graph_def, 1163 converted_input_indices) 1164 1165 1166def convert_var_to_const_function_in_v1(func, 1167 lower_control_flow=True, 1168 aggressive_inlining=False): 1169 """Replaces all the variables in a graph with constants of the same values. 1170 1171 This function works as same as convert_variables_to_constants_v2, but it 1172 should be used in Graph mode. It is a temporary solution when users want to 1173 integrate their models written in TF2 with infra that requires TF1 mode. 1174 1175 The current implementation only works for graphs that do not contain any 1176 control flow or embedding related ops. 1177 1178 The function must be called in a Session context. 1179 1180 Args: 1181 func: ConcreteFunction. 1182 lower_control_flow: Boolean indicating whether or not to lower control flow 1183 ops such as If and While. (default True) 1184 aggressive_inlining: Boolean indicating whether or not to do aggressive 1185 function inlining (might be unsafe if function has stateful ops, not 1186 properly connected to control outputs). (default False) 1187 1188 Raises: 1189 RuntimeError: If no Session context is present. 1190 1191 Returns: 1192 ConcreteFunction containing a simplified version of the original. 1193 """ 1194 1195 session = ops.get_default_session() 1196 if session is None: 1197 raise RuntimeError( 1198 "The conversion must be carried out in a Session context.") 1199 1200 converter_data = _FunctionConverterDataInGraph( 1201 func=func, 1202 lower_control_flow=lower_control_flow, 1203 aggressive_inlining=aggressive_inlining, 1204 session=session) 1205 1206 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1207 converter_data=converter_data) 1208 1209 return _construct_concrete_function(func, output_graph_def, 1210 converted_input_indices) 1211 1212 1213def convert_variables_to_constants_v2_as_graph(func, 1214 lower_control_flow=True, 1215 aggressive_inlining=False): 1216 """Replaces all the variables in a graph with constants of the same values. 1217 1218 This function works as same as convert_variables_to_constants_v2, but it 1219 returns the intermediate `GraphDef` as well. This `GraphDef` contains all the 1220 debug information after all the transformations in the frozen phase. 1221 1222 Args: 1223 func: ConcreteFunction. 1224 lower_control_flow: Boolean indicating whether or not to lower control flow 1225 ops such as If and While. (default True) 1226 aggressive_inlining: Boolean indicating whether or not to do aggressive 1227 function inlining (might be unsafe if function has stateful ops, not 1228 properly connected to control outputs). 1229 1230 Returns: 1231 ConcreteFunction containing a simplified version of the original, and also 1232 the intermediate GraphDef containing the node debug information for the 1233 transformations in the frozen phase. 1234 """ 1235 converter_data = _FunctionConverterDataInEager( 1236 func=func, 1237 lower_control_flow=lower_control_flow, 1238 aggressive_inlining=aggressive_inlining) 1239 1240 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1241 converter_data=converter_data) 1242 1243 frozen_func = _construct_concrete_function(func, output_graph_def, 1244 converted_input_indices) 1245 return frozen_func, output_graph_def 1246 1247 1248def convert_variables_to_constants_from_session_graph( 1249 session, 1250 graph_def, 1251 output_node_names, 1252 variable_names_allowlist=None, 1253 variable_names_denylist=None): 1254 """Replaces all the variables in a graph with constants of the same values. 1255 1256 This function works similarly to convert_variables_to_constants_v2, but it 1257 retrieves the constant values from a Session instead of from a 1258 ConcreteFunction. This is useful when converting graphs generated from 1259 TensorFlow V1, where ConcreteFunctions are not available. This also differs 1260 from graph_util.convert_variables_to_constants in that it supports resource 1261 variables when V2 control flow constructions are present. 1262 1263 Args: 1264 session: Active TensorFlow session containing the variables. 1265 graph_def: A GraphDef to convert. 1266 output_node_names: List of name strings for the result nodes of the graph. 1267 variable_names_allowlist: The set of variable names to convert (by default, 1268 all variables are converted). 1269 variable_names_denylist: The set of variable names to omit converting to 1270 constants. 1271 1272 Returns: 1273 An optimized GraphDef. 1274 """ 1275 graph_def, _ = _replace_variables_by_constants( 1276 converter_data=_SessionConverterData( 1277 session=session, 1278 graph_def=graph_def, 1279 output_node_names=output_node_names, 1280 variable_names_allowlist=variable_names_allowlist, 1281 variable_names_denylist=variable_names_denylist)) 1282 return graph_def 1283