1# Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helpers to convert variables to constants in TensorFlow 2.0.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import collections 22import numpy as np 23 24from tensorflow.core.framework import attr_value_pb2 25from tensorflow.core.framework import graph_pb2 26from tensorflow.core.framework import tensor_shape_pb2 27from tensorflow.core.framework import variable_pb2 28from tensorflow.core.protobuf import config_pb2 29from tensorflow.core.protobuf import meta_graph_pb2 30from tensorflow.core.protobuf import rewriter_config_pb2 31from tensorflow.python.eager import context 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import errors 34from tensorflow.python.framework import graph_util 35from tensorflow.python.framework import ops 36from tensorflow.python.framework import tensor_util 37from tensorflow.python.grappler import tf_optimizer 38from tensorflow.python.ops import array_ops 39from tensorflow.python.ops import variables 40from tensorflow.python.platform import tf_logging as logging 41from tensorflow.python.training.saver import export_meta_graph 42from tensorflow.python.util import lazy_loader 43from tensorflow.python.util import object_identity 44 45# 
# Lazy load the single eager module to avoid introducing new dependencies for
# graph_util:convert_variables_to_constants (eg in
# tensorflow/contrib/session_bundle:session_bundle_py_test).
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")

# Used in _FunctionConverterDataInGraph().
VAR_ASSIGN_COLLECTION = "extra_var_assign_ops"
_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)


class _TensorData(
    collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])):
  """Data about a tensor that was converted to a constant.

  Fields:
    numpy: The numpy value of the tensor.
    dtype: The DataType enum value of the tensor.
    index: The index of the capture this tensor came from, or None.
  """
  __slots__ = ()

  @property
  def dtype_attr(self):
    """The dtype wrapped as an AttrValue proto, ready to assign to a NodeDef."""
    return attr_value_pb2.AttrValue(type=self.dtype)


class _EndPoint(collections.namedtuple("_EndPoint", ["convertible", "index"])):
  """An endpoint in a graph: a (_Convertible, input-or-output index) pair."""
  __slots__ = ()

  def __str__(self):
    return "{}[{}]".format(self.convertible, self.index)


class _Edge(collections.namedtuple("_Edge", ["source", "destination"])):
  """A directed graph edge between two _EndPoints."""
  __slots__ = ()

  def __str__(self):
    return "{} -> {}".format(self.source, self.destination)


class _Convertible(object):
  """An entity that can have variables converted to constants."""

  def __init__(self, enclosing_graph):
    self._enclosing_graph = enclosing_graph
    self._outgoing_edges = []
    # Lazily-created copy of this Convertible that conversion mutates; see
    # converted_self().
    self._converted_self = None

  def converted_self(self):
    """A copy of this Convertible to be modified during conversion.

    Returns:
      Implementations should return the copied instance, which in turn should
      be contained in converted_enclosing_graph(). This instance is the one that
      will be modified during conversion. Its main use will be in the
      implementations of convert_variable_to_constant().
    """
    raise NotImplementedError()

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts a variable in this Convertible and its dependencies.

    This method should make sure that a converted copy of itself is present in
    the converted graph, and that all Convertibles depending on this one also go
    through the same process.

    Args:
      incoming_edge: The graph edge into this Convertible that is being
        converted to a constant.
      tensor_data: The tensor representing the constant.
    """
    raise NotImplementedError()

  def create_edges(self):
    """Calls add_outgoing_edge for all edges known to this Convertible.

    This is used to build the graph dependencies, so that conversion of
    variables to constants can be properly propagated through the graph. Usually
    this method will call add_outgoing_edge() to all the Convertible inputs.
    """
    raise NotImplementedError()

  def add_outgoing_edge(self, edge):
    """Adds an outgoing edge to the Convertible's list of edges.

    Args:
      edge: The outgoing edge (its source should be 'self').
    """
    self._outgoing_edges.append(edge)

  @property
  def converted_enclosing_graph(self):
    """The graph being converted."""
    return self._enclosing_graph.converted_self()

  @property
  def outgoing_edges(self):
    """The list of edges starting at this Convertible."""
    return self._outgoing_edges


class _Function(_Convertible):
  """A library function Convertible.

  Edges into functions are edges from node _inputs_ into function _inputs_:
  Functions get their input from their callers, not from node outputs, and the
  callers in turn get those values as inputs.
  """

  def __init__(self, function, enclosing_graph):
    super(_Function, self).__init__(enclosing_graph)
    self._function = function
    # Wrap every NodeDef of the FunctionDef in a _Node specialization.
    self._nodes = {
        n.name:
        _Node.new(node=n, function=self, enclosing_graph=enclosing_graph)
        for n in function.node_def
    }

  def __str__(self):
    return self.function.signature.name

  @property
  def function(self):
    return self._function

  @property
  def nodes(self):
    return self._nodes

  def converted_self(self):
    """The Function copy to be converted.

    The copy will be renamed according to the graph's converted_function_name
    map, to ensure the name does not match anything currently in TensorFlow's
    function cache.

    Returns:
      The function instance to be converted.
    """
    if self._converted_self is None:
      old_name = self.function.signature.name
      new_name = self._enclosing_graph.converted_function_names[old_name]
      self.converted_enclosing_graph.rename_function(old_name, new_name)
      self._converted_self = self.converted_enclosing_graph.functions[new_name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts one function argument into a constant.

    Args:
      incoming_edge: The edge into the argument to be converted.
      tensor_data: The constant value.
    """
    function = self.converted_self().function
    index = incoming_edge.destination.index
    function.signature.input_arg[index].type = tensor_data.dtype

    # Propagate the conversion to all edges that consume this input.
    for edge in self.outgoing_edges:
      if edge.source.index == index:
        edge.destination.convertible.convert_variable_to_constant(
            edge, tensor_data)

  def create_edges(self):
    """Creates edges for every node inside this function."""
    for n in self._nodes.values():
      n.create_edges()


class _Node(_Convertible):
  """A Convertible NodeDef."""

  def __init__(self, node, function, enclosing_graph):
    super(_Node, self).__init__(enclosing_graph)
    self._node = node
    # The enclosing _Function, or None if the node lives in the main graph.
    self._function = function

  def __str__(self):
    return self._node.name

  @staticmethod
  def new(node, function, enclosing_graph):
    """Creates a new _Node base on its operation type."""
    if node.op in ["VariableV2", "VarHandleOp", "Placeholder"]:
      return _VarHandle(node, function, enclosing_graph)
    elif node.op == "Case":
      return _Case(node, function, enclosing_graph)
    elif node.op == "Merge":
      return _Merge(node, function, enclosing_graph)
    elif node.op == "PartitionedCall":
      return _PartitionedCall(node, function, enclosing_graph)
    elif node.op == "StatefulPartitionedCall":
      return _PartitionedCall(node, function, enclosing_graph)
    elif node.op == "ReadVariableOp":
      return _ReadVariable(node, function, enclosing_graph)
    elif node.op == "ResourceGather":
      return _ResourceGather(node, function, enclosing_graph)
    elif node.op == "ResourceGatherNd":
      return _ResourceGatherNd(node, function, enclosing_graph)
    elif node.op in ["If", "StatelessIf"]:
      return _If(node, function, enclosing_graph)
    elif node.op in ["While", "StatelessWhile"]:
      return _While(node, function, enclosing_graph)
    elif node.op in [
        "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]:
      return _Intermediate(node, function, enclosing_graph)
    else:
      return _Node(node, function, enclosing_graph)

  @property
  def node(self):
    return self._node

  @property
  def container(self):
    """The node container (either a graph or a function)."""
    if self._function is not None:
      return self._function.function
    return self._enclosing_graph.graph_def

  def converted_self(self):
    """The NodeDef to be converted.

    Returns:
      The NodeDef to be converted, which can come from either a graph or a
      function. Derived classes should call this (via 'super') to make sure the
      node is retrieved from the right place.
    """
    if self._converted_self is None:
      source = self._function or self._enclosing_graph
      self._converted_self = source.converted_self().nodes[self._node.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # Generic nodes need no rewriting; specializations override this.
    pass

  def create_edges(self):
    for index, name in enumerate(self._node.input):
      # Discard edges from control inputs.
      if name[0] == "^":
        continue
      source = self.resolve_input(name)
      source.convertible.add_outgoing_edge(
          _Edge(source, _EndPoint(self, index)))

  def resolve_input(self, input_name):
    """Resolves an input into its _EndPoint.

    A NodeDef's input name can refer to either global NodeDefs (in the
    GraphDef's node list), a NodeDef in a function's node list, or a Function
    (in the GraphDef's function library). The name can also carry semantic
    information, depending on whether it starts with "^". This method handles
    all that logic in order to find the object to which the input name refers
    to.

    Args:
      input_name: The input name to resolve.

    Returns:
      The object referred to by 'input_name'.
    """

    # The logic below oversimplifies the semantics, but is good enough for the
    # purposes of converting to constants. The introduction of new types of
    # operations may change this, forcing the code to be more generic.
    #
    # In particular, we are assuming that the lack of an index suffix means
    # ":0", when it could mean "all the outputs of a node." This works now
    # because converting to constants relies very little on output types, and
    # when it does it specializes its treatment in dedicated classes.
    name_elts = input_name.split(":")
    source_name = name_elts[0]
    if source_name[0] == "^":
      source_name = source_name[1:]
    source_index = 0
    if len(name_elts) > 1 and name_elts[-1].isnumeric():
      source_index = int(name_elts[-1])

    if self._function is None:
      return _EndPoint(self._enclosing_graph.nodes[source_name], source_index)

    if source_index != 0 or source_name in self._function.nodes:
      return _EndPoint(self._function.nodes[source_name], source_index)

    # The name does not match a node in the function: it must be one of the
    # function's own input arguments.
    inputs = [i.name for i in self._function.function.signature.input_arg]
    return _EndPoint(self._function, inputs.index(source_name))

  def update_dtype(self, attr_name, index, dtype):
    """Changes the type of a given input.

    Args:
      attr_name: The NodeDef attribute containing the type to change.
      index: The index of the input type to change.
      dtype: The type to change to.

    Raises:
      ValueError: If 'index' does not address a type slot in the attribute.
    """
    attr = self._node.attr[attr_name]
    num_types = 0
    # Check for various 'oneof' possibilities, and update the type if
    # index in range.
    if attr.HasField("list"):
      types = attr.list.type
      num_types = len(types)
      if num_types > index:
        types[index] = dtype
        return
    elif attr.HasField("type"):
      num_types = 1
      if index == 0:
        attr.type = dtype
        return
    raise ValueError(
        "Index %d out of range for node(%s).attr(%s), which has %d elements." %
        (index, self._node.name, attr_name, num_types))


class _Intermediate(_Node):
  """Specialization of _Node to intermediate ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    node.update_dtype("T", incoming_edge.destination.index, tensor_data.dtype)
    # The recorded output shapes may no longer be valid after conversion, so
    # drop them rather than keep stale values.
    if "_output_shapes" in node.node.attr:
      del node.node.attr["_output_shapes"]
    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _Merge(_Node):
  """Specialization of _Node to Merge ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # The Merge operation has a single type for all its inputs, the number of
    # which is reflected in the "N" attribute. For the time being, we assume
    # that unilaterally changing all of them at once is ok.
    #
    # NOTE(review): the nested _Edge below looks like it was meant to be an
    # _EndPoint; the mismatch goes unnoticed because the base
    # _Node.convert_variable_to_constant is a no-op and never inspects its
    # incoming_edge — confirm before relying on this value.
    super(_Merge, self).convert_variable_to_constant(
        _Edge(incoming_edge.source,
              _Edge(incoming_edge.destination.convertible, 0)), tensor_data)


class _VarHandle(_Node):
  """Specialization of _Node to VarHandleOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    tensor_proto = tensor_util.make_tensor_proto(tensor_data.numpy,
                                                 tensor_data.dtype,
                                                 tensor_data.numpy.shape)

    # Rewrite the variable handle node in place into a Const node holding the
    # captured value.
    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Const"
    node.attr["dtype"].CopyFrom(tensor_data.dtype_attr)
    node.attr["value"].tensor.CopyFrom(tensor_proto)

    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _ResourceGather(_Node):
  """Specialization of _Node to ResourceGather."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # We currently skip the conversion if this is inside a function.
    if self._function is not None:
      return
    if self._node.attr["batch_dims"].i != 0:
      raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
    # GatherV2 takes an explicit axis input, which ResourceGather lacks, so
    # synthesize a Const node for it (value 0, since batch_dims is 0 here).
    axis_node_name = self._node.name + "/axis"
    axis_dtype = self._node.attr["Tindices"]
    axis_data = np.array(self._node.attr["batch_dims"].i)
    output_axis_node = self.converted_self().container.node.add()
    output_axis_node.name = axis_node_name
    output_axis_node.op = "Const"
    output_axis_node.attr["dtype"].CopyFrom(axis_dtype)
    tensor = tensor_util.make_tensor_proto(
        axis_data, dtype=axis_dtype.type, shape=axis_data.shape)
    output_axis_node.attr["value"].tensor.CopyFrom(tensor)

    # Rewrite the ResourceGather into a plain GatherV2 over the Const params.
    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherV2"
    output_node.input.extend(
        [self._node.input[0], self._node.input[1], axis_node_name])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    output_node.attr["Taxis"].CopyFrom(axis_dtype)
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ResourceGatherNd(_Node):
  """Specialization of _Node to ResourceGatherNd."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # Rewrite the ResourceGatherNd into a plain GatherNd over the Const params.
    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherNd"
    output_node.input.extend([self._node.input[0], self._node.input[1]])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ReadVariable(_Node):
  """Specialization of _Node to ReadVariableOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # The variable handle feeding this read becomes a Const, so the read
    # itself reduces to an Identity of that Const.
    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Identity"

    node.input.append(self._node.input[0])
    node.attr["T"].CopyFrom(self._node.attr["dtype"])
    if "_class" in self._node.attr:
      node.attr["_class"].CopyFrom(self._node.attr["_class"])

    # If the ReadVariableOp is part of a function, then every node having the
    # ReadVariableOp one as its input will refer to it using a ":value"
    # syntax. We need to change that to ":output".
    if self._function is not None:
      for edge in self.outgoing_edges:
        index = edge.destination.index
        dest = edge.destination.convertible.converted_self()
        if isinstance(dest, _Node):
          input_name_parts = dest.node.input[index].split(":")
          if len(input_name_parts) > 1 and input_name_parts[1] == "value":
            input_name_parts[1] = "output"
            dest.node.input[index] = ":".join(input_name_parts)


class _FunctionCaller(_Node):
  """A base class for Convertibles that reference functions."""

  def __init__(self, node, function, enclosing_graph, first_function_input,
               type_attribute, function_attributes):
    """Initializes a _FunctionCaller.

    Args:
      node: As in _Node.
      function: As in _Node.
      enclosing_graph: As in _Node.
      first_function_input: The index of the first NodeDef input that is tied to
        the function inputs. It is assumed that the rest of the NodeDef inputs
        map one to one to function inputs.
      type_attribute: The name of the NodeDef attribute that defines the input
        types. It is assumed that the types listed here map one-to-one with the
        function inputs (that is, they do _not_ specify types for inputs that
        are not passed to functions).
      function_attributes: The names of the NodeDef attributes containing
        references to functions.
    """
    super(_FunctionCaller, self).__init__(node, function, enclosing_graph)
    self._first_function_input = first_function_input
    self._type_attribute = type_attribute
    self._function_attributes = function_attributes

  def converted_self(self):
    if self._converted_self is None:
      # The super call caches the copied node in self._converted_self as a
      # side effect; here we additionally retarget its function references at
      # the renamed (converted) functions.
      node = super(_FunctionCaller, self).converted_self().node
      converted_names = self._enclosing_graph.converted_function_names
      for attr_name in self._function_attributes:
        attr = node.attr[attr_name]
        if attr.HasField("func"):
          attr.func.name = converted_names[attr.func.name]
        elif attr.HasField("list"):
          for func in attr.list.func:
            func.name = converted_names[func.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    index = incoming_edge.destination.index
    if index >= self._first_function_input:
      node.update_dtype(self._type_attribute,
                        index - self._first_function_input, tensor_data.dtype)

    # The loop below is reasonable but not correct in general:
    # The outgoing edges going into the functions are correct, because the
    # inputs map to the function inputs. But the edges going into other nodes do
    # not take into account the logic of the body function, which may do
    # arbitrary things to the node's output:
    #
    #   while x < 0:
    #     return y
    #
    # In this case, the node's ":0" output may map to its ":1 input". For the
    # time being, then, we only process edges into functions.
    for edge in self.outgoing_edges:
      dest = edge.destination.convertible
      if edge.source.index == index and isinstance(dest, _Function):
        dest.convert_variable_to_constant(edge, tensor_data)

  def create_edges(self):
    """Creates edges related to a function caller.

    Edges from a function caller to its called functions are always edges from
    _inputs_ to _inputs_: a FunctionDef input is given by the caller, based on
    its own inputs.
    """
    super(_FunctionCaller, self).create_edges()
    for attr_name in self._function_attributes:
      attr = self._node.attr[attr_name]
      if attr.HasField("func"):
        function = self._enclosing_graph.functions[attr.func.name]
        for index in range(len(self._node.input) - self._first_function_input):
          self.add_outgoing_edge(
              _Edge(
                  _EndPoint(self, index + self._first_function_input),
                  _EndPoint(function, index)))
      elif attr.HasField("list"):
        for func in attr.list.func:
          function = self._enclosing_graph.functions[func.name]
          for index in range(
              len(self._node.input) - self._first_function_input):
            self.add_outgoing_edge(
                _Edge(
                    _EndPoint(self, index + self._first_function_input),
                    _EndPoint(function, index)))


class _If(_FunctionCaller):
  """Specialization of _Node to If-like operations."""

  def __init__(self, node, function, enclosing_graph):
    # Input 0 is the boolean condition, so function inputs start at 1.
    super(_If, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["then_branch", "else_branch"])


class _Case(_FunctionCaller):
  """Specialization of _Node to Case-like operations."""

  def __init__(self, node, function, enclosing_graph):
    # Input 0 is the branch index, so function inputs start at 1.
    super(_Case, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["branches"])


class _PartitionedCall(_FunctionCaller):
  """Specialization of _Node to PartitionedCall-like operations."""

  def __init__(self, node, function, enclosing_graph):
    # All node inputs map directly to function inputs.
    super(_PartitionedCall, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="Tin",
        function_attributes=["f"])


608class _While(_FunctionCaller): 609 """Specialization of _Node to While-like operations.""" 610 611 def __init__(self, node, function, enclosing_graph): 612 super(_While, self).__init__( 613 node, 614 function, 615 enclosing_graph, 616 first_function_input=0, 617 type_attribute="T", 618 function_attributes=["body", "cond"]) 619 620 def convert_variable_to_constant(self, incoming_edge, tensor_data): 621 super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data) 622 node = self.converted_self() 623 if node.node.attr["output_shapes"].list.shape: 624 node.node.attr["output_shapes"].list.shape[ 625 incoming_edge.destination.index].CopyFrom( 626 tensor_shape_pb2.TensorShapeProto(dim=[ 627 tensor_shape_pb2.TensorShapeProto.Dim(size=dim) 628 for dim in tensor_data.numpy.shape 629 ])) 630 631 # The while's body inputs and outputs have the same type, so here we can go 632 # ahead and change that function's output type. 633 body_name = self._node.attr["body"].func.name 634 body = self._enclosing_graph.functions[body_name].converted_self().function 635 body.signature.output_arg[ 636 incoming_edge.destination.index].type = tensor_data.dtype 637 638 639class _GraphDef(_Convertible): 640 """A convertible GraphDef.""" 641 642 def __init__(self, graph_def): 643 super(_GraphDef, self).__init__(enclosing_graph=None) 644 self._graph_def = graph_def 645 self._nodes = { 646 n.name: _Node.new(node=n, function=None, enclosing_graph=self) 647 for n in graph_def.node 648 } 649 self._functions = { 650 f.signature.name: _Function(f, enclosing_graph=self) 651 for f in graph_def.library.function 652 } 653 self.create_edges() 654 self._converted_function_names = None 655 656 @property 657 def graph_def(self): 658 return self._graph_def 659 660 @property 661 def nodes(self): 662 return self._nodes 663 664 @property 665 def functions(self): 666 return self._functions 667 668 @property 669 def converted_function_names(self): 670 """Map from original to new function names. 
671 672 In order to avoid conflicts (two functions with the same name, one converted 673 and one not), we need to change the name of every converted function to 674 something that is hopefully unique. 675 676 Returns: 677 Map from original to new suggested function names. 678 """ 679 if self._converted_function_names is None: 680 parsed_names = [] # List of (id, base_name, original_name) 681 for name in self.functions: 682 elements = name.rsplit("_", 1) 683 if len(elements) == 2 and elements[1].isnumeric(): 684 parsed_names.append((int(elements[1]), elements[0], name)) 685 else: 686 parsed_names.append((-1, name, name)) 687 self._converted_function_names = { 688 name: "{}_frozen_{}".format(base_name, ops.uid()) 689 for (_, base_name, name) in sorted(parsed_names) 690 } 691 692 return self._converted_function_names 693 694 def rename_function(self, old_name, new_name): 695 func = self.functions.pop(old_name) 696 func.function.signature.name = new_name 697 self.functions[new_name] = func 698 699 def converted_self(self): 700 if self._converted_self is None: 701 copied_graph = graph_pb2.GraphDef() 702 copied_graph.CopyFrom(self._graph_def) 703 self._converted_self = _GraphDef(copied_graph) 704 return self._converted_self 705 706 def create_edges(self): 707 for n in self._nodes.values(): 708 n.create_edges() 709 for f in self._functions.values(): 710 f.create_edges() 711 712 713class _ConverterData(object): 714 """Container for constant conversion supporting data. 715 716 The data includes the graph being converted, and the pre-converted 717 tensors. This class will be specialized for ConcreteFunction and Session-based 718 conversions, as the means to obtain that data is different for each case. 
719 """ 720 721 def __init__(self, 722 graph_def, 723 variable_names_allowlist=None, 724 variable_names_denylist=None): 725 self._graph_def = graph_def 726 self._tensor_data = {} 727 self._build_node_defs_list() 728 self._variable_names_allowlist = variable_names_allowlist 729 self._variable_names_denylist = variable_names_denylist 730 731 @property 732 def graph_def(self): 733 """The graph to be converted.""" 734 return self._graph_def 735 736 @property 737 def node_defs(self): 738 """All the node defs in the graph to be converted. 739 740 Returns: 741 A map from node name to the NodeDef for all NodeDefs in the graph, as well 742 as all control flow NodeDefs in the functions. 743 """ 744 return self._node_defs 745 746 @property 747 def tensor_data(self): 748 """A map from tensor name to its converted _TensorData.""" 749 return self._tensor_data 750 751 def _should_convert(self, name): 752 """Checks whether to convert the given variable name to a constant.""" 753 return (self._variable_names_allowlist is None or 754 name in self._variable_names_allowlist) and ( 755 self._variable_names_denylist is None or 756 name not in self._variable_names_denylist) 757 758 def _build_node_defs_list(self): 759 """Builds the list of NodeDefs in the GraphDef. 760 761 This list consists of all NodeDefs in the main graph as well as all control 762 flow NodeDefs in the functions. 763 764 The remaining NodeDefs in the functions are not included because the op 765 names 766 are not unique and the variables are handled differently than the main 767 graph. 768 The control flow ops need to be extracted because they are need their 769 attributes to be updated similar to the control flow ops in the main graph. 
770 """ 771 self._node_defs = {node.name: node for node in self._graph_def.node} 772 773 if self._graph_def.library: 774 for func in self._graph_def.library.function: 775 self._node_defs.update({ 776 node.name: node 777 for node in func.node_def 778 if node.op in _CONTROL_FLOW_OPS 779 }) 780 781 782class _FunctionConverterData(_ConverterData): 783 """Container for ConcreteFunction-based conversion data.""" 784 785 def __init__(self, 786 func, 787 lower_control_flow, 788 aggressive_inlining, 789 variable_names_allowlist=None, 790 variable_names_denylist=None): 791 """Creates the conversion data for the given function. 792 793 Args: 794 func: ConcreteFunction. 795 lower_control_flow: Boolean indicating whether or not to lower control 796 flow ops such as If and While. 797 aggressive_inlining: Boolean indicating whether or not to do aggressive 798 function inlining (might be unsafe if function has stateful ops, not 799 properly connected to control outputs). 800 variable_names_allowlist: The set of variable names to convert (by 801 default, all variables are converted). 802 variable_names_denylist: The set of variable names to omit converting to 803 constants. 804 """ 805 806 self._func = func 807 # Inline the graph in order to remove functions when possible. 808 graph_def = _run_inline_graph_optimization(func, lower_control_flow, 809 aggressive_inlining) 810 super(_FunctionConverterData, self).__init__( 811 graph_def, 812 variable_names_allowlist=variable_names_allowlist, 813 variable_names_denylist=variable_names_denylist) 814 815 self._build_tensor_data() 816 817 def _eval(self, tensor): 818 """Returns the value in the tensor. 
Must be implemented in sub-classes.""" 819 raise errors.UnimplementedError( 820 "The evaluation method should be implemented in sub-classes.") 821 822 def _build_tensor_data(self): 823 """Caches the tensor data for all Placeholders in the given function.""" 824 map_index_to_variable = {} 825 for var in self._func.graph.variables: 826 for idx, captured_input in enumerate(self._func.captured_inputs): 827 if var.handle is captured_input: # pylint: disable=protected-access 828 map_index_to_variable[idx] = var 829 break 830 831 # Iterates through all captures which are represented as Placeholders. 832 for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures): 833 tensor_name = name_tensor.name.split(":")[0] 834 if not self._should_convert(tensor_name): 835 continue 836 if idx in map_index_to_variable: 837 data = self._eval(map_index_to_variable[idx]) 838 else: 839 if val_tensor.dtype == dtypes.resource: 840 logging.vlog(1, "Skip converting resource tensor %s" % tensor_name) 841 continue 842 data = np.array(self._eval(val_tensor)) 843 844 self._tensor_data[tensor_name] = _TensorData( 845 numpy=data, 846 dtype=dtypes.as_dtype(data.dtype).as_datatype_enum, 847 index=idx) 848 849 # Get data for VariableV2 ops (reference variables) that cannot be lifted. 
850 for node in self.node_defs.values(): 851 if node.op == "VariableV2": 852 if not self._should_convert(node.name): 853 continue 854 if node.name not in self.tensor_data: 855 with self._func.graph.as_default(): 856 identity_node = array_ops.identity( 857 self._func.graph.as_graph_element(node.name + ":0")) 858 pruned_graph = self._func.prune([], [identity_node.name])()[0] 859 self._tensor_data[node.name] = _TensorData( 860 numpy=pruned_graph.numpy(), 861 dtype=node.attr["dtype"].type, 862 index=None) 863 864 865class _FunctionConverterDataInEager(_FunctionConverterData): 866 """Container for ConcreteFunction-based conversion data in Eager mode.""" 867 868 def _eval(self, tensor): 869 """Returns the value in the tensor. Must be implemented in sub-classes.""" 870 return tensor.numpy() 871 872 873class _FunctionConverterDataInGraph(_FunctionConverterData): 874 """Container for ConcreteFunction-based conversion data in Graph mode.""" 875 876 def __init__(self, 877 func, 878 lower_control_flow, 879 aggressive_inlining, 880 variable_names_allowlist=None, 881 variable_names_denylist=None, 882 session=None): 883 """Creates the conversion data for the given function. 884 885 Args: 886 func: ConcreteFunction. 887 lower_control_flow: Boolean indicating whether or not to lower control 888 flow ops such as If and While. 889 aggressive_inlining: Boolean indicating whether or not to do aggressive 890 function inlining (might be unsafe if function has stateful ops, not 891 properly connected to control outputs). 892 variable_names_allowlist: The set of variable names to convert (by 893 default, all variables are converted). 894 variable_names_denylist: The set of variable names to omit converting to 895 constants. 896 session: Session object. 897 """ 898 self._session = session 899 900 session.run(variables.global_variables_initializer()) 901 # Run extra assignment ops if needed. 902 # These assignments are run sequentially to ensure order. 
    for op in ops.get_default_graph().get_collection(VAR_ASSIGN_COLLECTION):
      session.run(op)

    super(_FunctionConverterDataInGraph, self).__init__(
        func,
        lower_control_flow,
        aggressive_inlining,
        variable_names_allowlist,
        variable_names_denylist)

  def _eval(self, tensor):
    """Returns the value of the tensor by running it in self._session."""
    return self._session.run(tensor)


class _SessionConverterData(_ConverterData):
  """Container for Session-based conversion data."""

  def __init__(self,
               session,
               graph_def,
               output_node_names,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Creates the conversion data for a GraphDef evaluated via a Session.

    Args:
      session: Active TensorFlow session containing the variables.
      graph_def: A GraphDef to convert.
      output_node_names: List of name strings for the result nodes of the
        graph.
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting
        to constants.
    """
    # Keep only the subgraph that the requested outputs depend on.
    graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    super(_SessionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    nodes_to_convert = []
    tensor_names_to_convert = []
    for node in self.graph_def.node:
      if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
        tensor_name = node.name
        if not self._should_convert(tensor_name):
          continue
        # Resource variables are evaluated through their ReadVariableOp.
        if node.op == "VarHandleOp":
          tensor_name = tensor_name + "/Read/ReadVariableOp"
        nodes_to_convert.append(node)
        tensor_names_to_convert.append(tensor_name + ":0")

    if tensor_names_to_convert:
      # Evaluate all variable values in a single session.run call.
      converted_tensors = session.run(tensor_names_to_convert)
      for node, tensor_value in zip(nodes_to_convert, converted_tensors):
        self._tensor_data[node.name] = _TensorData(
            numpy=tensor_value, dtype=node.attr["dtype"].type, index=None)


def disable_lower_using_switch_merge(graph_def):
  """Set '_lower_using_switch_merge' attributes to False.

  Sets the attribute to False in the NodeDefs in the main graph and the
  NodeDefs in each function's graph.

  Args:
    graph_def: GraphDef proto.

  Returns:
    GraphDef
  """
  # Work on a copy; the input proto is left unmodified.
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.CopyFrom(graph_def)

  def disable_control_flow_lowering(node):
    # Only If/StatelessIf/While/StatelessWhile carry this attribute.
    if node.op in _CONTROL_FLOW_OPS:
      node.attr["_lower_using_switch_merge"].b = False

  for node in output_graph_def.node:
    disable_control_flow_lowering(node)

  if output_graph_def.library:
    for func in output_graph_def.library.function:
      for node in func.node_def:
        disable_control_flow_lowering(node)
  return output_graph_def


def _run_inline_graph_optimization(func, lower_control_flow,
                                   aggressive_inlining):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control
      flow ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops not
      properly connected to control outputs).

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU)
  # is written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM`
  # in TF2 produces a CuDNN-based RNN for GPU).
  # This function is supposed to inline all function calls, but
  # "api_implements" prevents this from happening. Removing the attribute
  # solves the problem.
  # To learn more about "api_implements", see:
  # tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variables collections, since they are
  # not needed after saved to saved_model.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    # NOTE(review): the inner loop always reads the "variables" collection
    # while writing to collection_def[name]; presumably it should read
    # collection_def[name] instead — confirm intent before changing.
    for raw in meta_graph.collection_def["variables"].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function
  # inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  if aggressive_inlining:
    rewrite_options.function_optimization =\
      rewriter_config_pb2.RewriterConfig.AGGRESSIVE
  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  # Create a ConcreteFunction from the new GraphDef.
  input_tensors = func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  # Inputs that were NOT folded into constants remain inputs of the new
  # function, in their original order.
  not_converted_inputs = [
      tensor for tensor in func.inputs if tensor not in converted_inputs
  ]
  not_converted_inputs_map = {
      tensor.name: tensor for tensor in not_converted_inputs
  }

  new_input_names = [tensor.name for tensor in not_converted_inputs]
  new_output_names = [tensor.name for tensor in func.outputs]

  # Remove old functions to use updated functions from graph def.
  for f in output_graph_def.library.function:
    if context.context().has_function(f.signature.name):
      context.context().remove_function(f.signature.name)

  new_func = wrap_function.function_from_graph_def(output_graph_def,
                                                   new_input_names,
                                                   new_output_names)

  # Manually propagate shape for input tensors where the shape is not
  # correctly propagated. Scalars shapes are lost when wrapping the function.
  for input_tensor in new_func.inputs:
    input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape)
  return new_func


def _replace_variables_by_constants(converter_data):
  """Replaces variables by constants on a given graph.

  Given a _ConverterData instance with converted variables in its tensor_data
  field, create a new graph where the respective variables are replaced with
  the converted constants.

  Args:
    converter_data: A pre-populated _ConverterData instance.

  Returns:
    The converted graph.
1101 """ 1102 input_graph = _GraphDef(converter_data.graph_def) 1103 1104 for tensor_name, tensor_data in converter_data.tensor_data.items(): 1105 input_graph.nodes[tensor_name].convert_variable_to_constant( 1106 None, tensor_data) 1107 1108 converted_graph = input_graph.converted_self().graph_def 1109 1110 converted_input_indices = { 1111 t.index 1112 for t in converter_data.tensor_data.values() 1113 if t.index is not None 1114 } 1115 1116 return converted_graph, converted_input_indices 1117 1118 1119def convert_variables_to_constants_v2(func, 1120 lower_control_flow=True, 1121 aggressive_inlining=False): 1122 """Replaces all the variables in a graph with constants of the same values. 1123 1124 TensorFlow 2.0 function for converting all Variable ops into Const ops holding 1125 the same values. This makes it possible to describe the network fully with a 1126 single GraphDef file, and allows the removal of a lot of ops related to 1127 loading and saving the variables. This function runs Grappler's function 1128 inlining optimization in order to return a single subgraph. 1129 1130 The current implementation only works for graphs that do not contain any 1131 control flow or embedding related ops. 1132 1133 Args: 1134 func: ConcreteFunction. 1135 lower_control_flow: Boolean indicating whether or not to lower control flow 1136 ops such as If and While. (default True) 1137 aggressive_inlining: Boolean indicating whether or not to do aggressive 1138 function inlining (might be unsafe if function has stateful ops, not 1139 properly connected to control outputs). (default False) 1140 1141 Returns: 1142 ConcreteFunction containing a simplified version of the original. 
1143 """ 1144 1145 converter_data = _FunctionConverterDataInEager( 1146 func=func, 1147 lower_control_flow=lower_control_flow, 1148 aggressive_inlining=aggressive_inlining) 1149 1150 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1151 converter_data=converter_data) 1152 1153 return _construct_concrete_function(func, output_graph_def, 1154 converted_input_indices) 1155 1156 1157def convert_var_to_const_function_in_v1(func, 1158 lower_control_flow=True, 1159 aggressive_inlining=False): 1160 """Replaces all the variables in a graph with constants of the same values. 1161 1162 This function works as same as convert_variables_to_constants_v2, but it 1163 should be used in Graph mode. It is a temporary solution when users want to 1164 integrate their models written in TF2 with infra that requires TF1 mode. 1165 1166 The current implementation only works for graphs that do not contain any 1167 control flow or embedding related ops. 1168 1169 The function must be called in a Session context. 1170 1171 Args: 1172 func: ConcreteFunction. 1173 lower_control_flow: Boolean indicating whether or not to lower control flow 1174 ops such as If and While. (default True) 1175 aggressive_inlining: Boolean indicating whether or not to do aggressive 1176 function inlining (might be unsafe if function has stateful ops, not 1177 properly connected to control outputs). (default False) 1178 1179 Raises: 1180 RuntimeError: If no Session context is present. 1181 1182 Returns: 1183 ConcreteFunction containing a simplified version of the original. 
1184 """ 1185 1186 session = ops.get_default_session() 1187 if session is None: 1188 raise RuntimeError( 1189 "The conversion must be carried out in a Session context.") 1190 1191 converter_data = _FunctionConverterDataInGraph( 1192 func=func, 1193 lower_control_flow=lower_control_flow, 1194 aggressive_inlining=aggressive_inlining, 1195 session=session) 1196 1197 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1198 converter_data=converter_data) 1199 1200 return _construct_concrete_function(func, output_graph_def, 1201 converted_input_indices) 1202 1203 1204def convert_variables_to_constants_v2_as_graph(func, 1205 lower_control_flow=True, 1206 aggressive_inlining=False): 1207 """Replaces all the variables in a graph with constants of the same values. 1208 1209 This function works as same as convert_variables_to_constants_v2, but it 1210 returns the intermediate `GraphDef` as well. This `GraphDef` contains all the 1211 debug information after all the transformations in the frozen phase. 1212 1213 Args: 1214 func: ConcreteFunction. 1215 lower_control_flow: Boolean indicating whether or not to lower control flow 1216 ops such as If and While. (default True) 1217 aggressive_inlining: Boolean indicating whether or not to do aggressive 1218 function inlining (might be unsafe if function has stateful ops, not 1219 properly connected to control outputs). 1220 1221 Returns: 1222 ConcreteFunction containing a simplified version of the original, and also 1223 the intermediate GraphDef containing the node debug information for the 1224 transformations in the frozen phase. 
1225 """ 1226 converter_data = _FunctionConverterDataInEager( 1227 func=func, 1228 lower_control_flow=lower_control_flow, 1229 aggressive_inlining=aggressive_inlining) 1230 1231 output_graph_def, converted_input_indices = _replace_variables_by_constants( 1232 converter_data=converter_data) 1233 1234 frozen_func = _construct_concrete_function(func, output_graph_def, 1235 converted_input_indices) 1236 return frozen_func, output_graph_def 1237 1238 1239def convert_variables_to_constants_from_session_graph( 1240 session, 1241 graph_def, 1242 output_node_names, 1243 variable_names_allowlist=None, 1244 variable_names_denylist=None): 1245 """Replaces all the variables in a graph with constants of the same values. 1246 1247 This function works similarly to convert_variables_to_constants_v2, but it 1248 retrieves the constant values from a Session instead of from a 1249 ConcreteFunction. This is useful when converting graphs generated from 1250 TensorFlow V1, where ConcreteFunctions are not available. This also differs 1251 from graph_util.convert_variables_to_constants in that it supports resource 1252 variables when V2 control flow constructions are present. 1253 1254 Args: 1255 session: Active TensorFlow session containing the variables. 1256 graph_def: A GraphDef to convert. 1257 output_node_names: List of name strings for the result nodes of the graph. 1258 variable_names_allowlist: The set of variable names to convert (by default, 1259 all variables are converted). 1260 variable_names_denylist: The set of variable names to omit converting to 1261 constants. 1262 1263 Returns: 1264 An optimized GraphDef. 1265 """ 1266 # TODO(b/176982859): Find a more satisfying way to update shape information 1267 # than clearing it, or migrate users to a workflow that does not require 1268 # freezing. 
1269 for function in graph_def.library.function: 1270 if "_input_shapes" in function.attr: 1271 for input_arg, shape_attribute in zip( 1272 function.signature.input_arg, 1273 function.attr["_input_shapes"].list.shape): 1274 if dtypes.as_dtype(input_arg.type) == dtypes.resource: 1275 shape_attribute.unknown_rank = True 1276 graph_def, _ = _replace_variables_by_constants( 1277 converter_data=_SessionConverterData( 1278 session=session, 1279 graph_def=graph_def, 1280 output_node_names=output_node_names, 1281 variable_names_allowlist=variable_names_allowlist, 1282 variable_names_denylist=variable_names_denylist)) 1283 return graph_def 1284