# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to convert variables to constants in TensorFlow 2.0."""

import collections
import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import object_identity

# Lazy load the single eager module to avoid introducing new dependencies for
# graph_util:convert_variables_to_constants (e.g., in
# tensorflow/contrib/session_bundle:session_bundle_py_test).
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")

# Used in _FunctionConverterDataInGraph().
VAR_ASSIGN_COLLECTION = "extra_var_assign_ops"
_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)


class _TensorData(
    collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])):
  """Data about a tensor that was converted to a constant."""
  __slots__ = ()

  @property
  def dtype_attr(self):
    return attr_value_pb2.AttrValue(type=self.dtype)


class _EndPoint(collections.namedtuple("_EndPoint", ["convertible", "index"])):
  """An endpoint in a graph."""
  __slots__ = ()

  def __str__(self):
    return "{}[{}]".format(self.convertible, self.index)


class _Edge(collections.namedtuple("_Edge", ["source", "destination"])):
  """A directed graph edge."""
  __slots__ = ()

  def __str__(self):
    return "{} -> {}".format(self.source, self.destination)


class _Convertible(object):
  """An entity that can have variables converted to constants."""

  def __init__(self, enclosing_graph):
    self._enclosing_graph = enclosing_graph
    self._outgoing_edges = []
    self._converted_self = None

  def converted_self(self):
    """A copy of this Convertible to be modified during conversion.

    Returns:
      Implementations should return the copied instance, which in turn should
      be contained in converted_enclosing_graph(). This instance is the one that
      will be modified during conversion. Its main use will be in the
      implementations of convert_variable_to_constant().
    """
    raise NotImplementedError

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts a variable in this Convertible and its dependencies.

    This method should make sure that a converted copy of itself is present in
    the converted graph, and that all Convertibles depending on this one also go
    through the same process.

    Args:
      incoming_edge: The graph edge into this Convertible that is being
        converted to a constant.
      tensor_data: The tensor representing the constant.
    """
    raise NotImplementedError

  def create_edges(self):
    """Calls add_outgoing_edge for all edges known to this Convertible.

    This is used to build the graph dependencies, so that conversion of
    variables to constants can be properly propagated through the graph.
    Usually this method will call add_outgoing_edge() for all the Convertible
    inputs.
    """
    raise NotImplementedError

  def add_outgoing_edge(self, edge):
    """Adds an outgoing edge to the Convertible's list of edges.

    Args:
      edge: The outgoing edge (its source should be 'self').
    """
    self._outgoing_edges.append(edge)

  @property
  def converted_enclosing_graph(self):
    """The graph being converted."""
    return self._enclosing_graph.converted_self()

  @property
  def outgoing_edges(self):
    """The list of edges starting at this Convertible."""
    return self._outgoing_edges


class _Function(_Convertible):
  """A library function Convertible.

  Edges into functions are edges from node _inputs_ into function _inputs_:
  Functions get their input from their callers, not from node outputs, and the
  callers in turn get those values as inputs.
  """

  def __init__(self, function, enclosing_graph):
    super(_Function, self).__init__(enclosing_graph)
    self._function = function
    self._nodes = {
        n.name:
        _Node.new(node=n, function=self, enclosing_graph=enclosing_graph)
        for n in function.node_def
    }

  def __str__(self):
    return self.function.signature.name

  @property
  def function(self):
    return self._function

  @property
  def nodes(self):
    return self._nodes

  def converted_self(self):
    """The Function copy to be converted.

    The copy will be renamed according to the graph's converted_function_names
    map, to ensure the name does not match anything currently in TensorFlow's
    function cache.

    Returns:
      The function instance to be converted.
    """
180    if self._converted_self is None:
181      old_name = self.function.signature.name
182      new_name = self._enclosing_graph.converted_function_names[old_name]
183      self.converted_enclosing_graph.rename_function(old_name, new_name)
184      self._converted_self = self.converted_enclosing_graph.functions[new_name]
185    return self._converted_self
186
187  def convert_variable_to_constant(self, incoming_edge, tensor_data):
188    """Converts one function argument into a constant.
189
190    Args:
191      incoming_edge: The edge into the argument to be converted.
192      tensor_data: The constant value.
193    """
194    function = self.converted_self().function
195    index = incoming_edge.destination.index
196    function.signature.input_arg[index].type = tensor_data.dtype
197
198    # TODO(b/176982859): Find a more satisfying way to update shape information
199    # than clearing it, or migrate users to a workflow that does not require
200    # freezing.
201    if "_input_shapes" in function.attr:
202      function.attr["_input_shapes"].list.shape[index].unknown_rank = True
203      del function.attr["_input_shapes"].list.shape[index].dim[:]
204    arg_attrs = function.arg_attr[index].attr
205    if "_output_shapes" in arg_attrs:
206      arg_attrs["_output_shapes"].list.shape[0].unknown_rank = True
207      del arg_attrs["_output_shapes"].list.shape[0].dim[:]
208
209    for edge in self.outgoing_edges:
210      if edge.source.index == index:
211        edge.destination.convertible.convert_variable_to_constant(
212            edge, tensor_data)
213
214  def create_edges(self):
215    for n in self._nodes.values():
216      n.create_edges()
217
218
219class _Node(_Convertible):
220  """A Convertible NodeDef."""
221
222  def __init__(self, node, function, enclosing_graph):
223    super(_Node, self).__init__(enclosing_graph)
224    self._node = node
225    self._function = function
226
227  def __str__(self):
228    return self._node.name
229
230  @staticmethod
231  def new(node, function, enclosing_graph):
232    """Creates a new _Node base on its operation type."""
    if node.op in ["VariableV2", "VarHandleOp", "Placeholder"]:
      return _VarHandle(node, function, enclosing_graph)
    elif node.op == "Case":
      return _Case(node, function, enclosing_graph)
    elif node.op == "Merge":
      return _Merge(node, function, enclosing_graph)
    elif node.op in ["PartitionedCall", "StatefulPartitionedCall"]:
      return _PartitionedCall(node, function, enclosing_graph)
    elif node.op == "ReadVariableOp":
      return _ReadVariable(node, function, enclosing_graph)
    elif node.op == "ResourceGather":
      return _ResourceGather(node, function, enclosing_graph)
    elif node.op == "ResourceGatherNd":
      return _ResourceGatherNd(node, function, enclosing_graph)
    elif node.op in ["If", "StatelessIf"]:
      return _If(node, function, enclosing_graph)
    elif node.op in ["While", "StatelessWhile"]:
      return _While(node, function, enclosing_graph)
    elif node.op in [
        "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]:
      return _Intermediate(node, function, enclosing_graph)
    else:
      return _Node(node, function, enclosing_graph)

  @property
  def node(self):
    return self._node

  @property
  def container(self):
    """The node container (either a graph or a function)."""
    if self._function is not None:
      return self._function.function
    return self._enclosing_graph.graph_def

  def converted_self(self):
271    """The NodeDef to be converted.
272
273    Returns:
274      The NodeDef to be converted, which can come from either a graph for a
275      function. Derived classes should call this (via 'super') to make sure the
276      node is retrieved from the right place.
277    """
    if self._converted_self is None:
      source = self._function or self._enclosing_graph
      self._converted_self = source.converted_self().nodes[self._node.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    pass

  def create_edges(self):
    for index, name in enumerate(self._node.input):
      # Discard edges from control inputs.
      if name[0] == "^":
        continue
      source = self.resolve_input(name)
      source.convertible.add_outgoing_edge(
          _Edge(source, _EndPoint(self, index)))

  def resolve_input(self, input_name):
    """Resolves an input into its _EndPoint.

    A NodeDef's input name can refer to either global NodeDefs (in the
    GraphDef's node list), a NodeDef in a function's node list, or a Function
    (in the GraphDef's function library). The name can also carry semantic
    information, depending on whether it starts with "^". This method handles
    all that logic in order to find the object to which the input name refers.
    For example, "foo/bar:2" refers to output 2 of the node "foo/bar", while
    "^foo/bar" refers to the node "foo/bar" as a control dependency.

    Args:
      input_name: The input name to resolve.

    Returns:
      The object referred to by 'input_name'.
    """

    # The logic below oversimplifies the semantics, but is good enough for the
    # purposes of converting to constants. The introduction of new types of
    # operations may change this, forcing the code to be more generic.
    #
    # In particular, we are assuming that the lack of an index suffix means
    # ":0", when it could mean "all the outputs of a node." This works now
    # because converting to constants relies very little on output types, and
    # when it does it specializes its treatment in dedicated classes.
    name_elts = input_name.split(":")
    source_name = name_elts[0]
    if source_name[0] == "^":
      source_name = source_name[1:]
    source_index = 0
    if len(name_elts) > 1 and name_elts[-1].isnumeric():
      source_index = int(name_elts[-1])

    if self._function is None:
      return _EndPoint(self._enclosing_graph.nodes[source_name], source_index)

    if source_index != 0 or source_name in self._function.nodes:
      return _EndPoint(self._function.nodes[source_name], source_index)

    inputs = [i.name for i in self._function.function.signature.input_arg]
    return _EndPoint(self._function, inputs.index(source_name))

  def update_dtype(self, attr_name, index, dtype):
    """Changes the type of a given input.

    Args:
      attr_name: The NodeDef attribute containing the type to change.
      index: The index of the input type to change.
      dtype: The type to change to.
    """
    attr = self._node.attr[attr_name]
    num_types = 0
    # Check for the various 'oneof' possibilities, and update the type if the
    # index is in range.
    if attr.HasField("list"):
      types = attr.list.type
      num_types = len(types)
      if num_types > index:
        types[index] = dtype
        return
    elif attr.HasField("type"):
      num_types = 1
      if index == 0:
        attr.type = dtype
        return
    raise ValueError(f"`index` {index:d} is out of range for "
                     f"node({self._node.name}).attr({attr_name}), which has "
                     f"{num_types:d} elements.")


class _Intermediate(_Node):
  """Specialization of _Node to intermediate ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    node.update_dtype("T", incoming_edge.destination.index, tensor_data.dtype)
    if "_output_shapes" in node.node.attr:
      del node.node.attr["_output_shapes"]
    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _Merge(_Node):
  """Specialization of _Node to Merge ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # The Merge operation has a single type for all its inputs, the number of
    # which is reflected in the "N" attribute. For the time being, we assume
    # that unilaterally changing all of them at once is ok.
    super(_Merge, self).convert_variable_to_constant(
        _Edge(incoming_edge.source,
              _EndPoint(incoming_edge.destination.convertible, 0)),
        tensor_data)


class _VarHandle(_Node):
  """Specialization of _Node to VarHandleOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    tensor_proto = tensor_util.make_tensor_proto(tensor_data.numpy,
                                                 tensor_data.dtype,
                                                 tensor_data.numpy.shape)

    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Const"
    node.attr["dtype"].CopyFrom(tensor_data.dtype_attr)
    node.attr["value"].tensor.CopyFrom(tensor_proto)

    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _ResourceGather(_Node):
  """Specialization of _Node to ResourceGather."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # We currently skip the conversion if this is inside a function.
    if self._function is not None:
      return
    if self._node.attr["batch_dims"].i != 0:
      raise ValueError("batch_dims must be 0 for freeze_graph, but got "
                       f"node({self._node.name}).attr('batch_dims') = "
                       f"{self._node.attr['batch_dims'].i}.")
    axis_node_name = self._node.name + "/axis"
    axis_dtype = self._node.attr["Tindices"]
    axis_data = np.array(self._node.attr["batch_dims"].i)
    output_axis_node = self.converted_self().container.node.add()
    output_axis_node.name = axis_node_name
    output_axis_node.op = "Const"
    output_axis_node.attr["dtype"].CopyFrom(axis_dtype)
    tensor = tensor_util.make_tensor_proto(
        axis_data, dtype=axis_dtype.type, shape=axis_data.shape)
    output_axis_node.attr["value"].tensor.CopyFrom(tensor)

    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherV2"
    output_node.input.extend(
        [self._node.input[0], self._node.input[1], axis_node_name])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    output_node.attr["Taxis"].CopyFrom(axis_dtype)
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ResourceGatherNd(_Node):
  """Specialization of _Node to ResourceGatherNd."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherNd"
    output_node.input.extend([self._node.input[0], self._node.input[1]])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ReadVariable(_Node):
  """Specialization of _Node to ReadVariableOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Identity"

    node.input.append(self._node.input[0])
    node.attr["T"].CopyFrom(self._node.attr["dtype"])
    if "_class" in self._node.attr:
      node.attr["_class"].CopyFrom(self._node.attr["_class"])

    # If the ReadVariableOp is part of a function, then every node having the
    # ReadVariableOp as one of its inputs will refer to it using a ":value"
    # syntax. We need to change that to ":output".
    if self._function is not None:
      for edge in self.outgoing_edges:
        index = edge.destination.index
        dest = edge.destination.convertible.converted_self()
        if isinstance(dest, _Node):
          input_name_parts = dest.node.input[index].split(":")
          if len(input_name_parts) > 1 and input_name_parts[1] == "value":
            input_name_parts[1] = "output"
            dest.node.input[index] = ":".join(input_name_parts)


class _FunctionCaller(_Node):
  """A base class for Convertibles that reference functions."""

  def __init__(self, node, function, enclosing_graph, first_function_input,
               type_attribute, function_attributes):
    """Initializes a _FunctionCaller.

    Args:
      node: As in _Node.
      function: As in _Node.
      enclosing_graph: As in _Node.
      first_function_input: The index of the first NodeDef input that is tied to
        the function inputs. It is assumed that the rest of the NodeDef inputs
        map one-to-one to function inputs.
      type_attribute: The name of the NodeDef attribute that defines the input
        types. It is assumed that the types listed here map one-to-one with the
        function inputs (that is, they do _not_ specify types for inputs that
        are not passed to functions).
      function_attributes: The names of the NodeDef attributes containing
        references to functions.
    """
    super(_FunctionCaller, self).__init__(node, function, enclosing_graph)
    self._first_function_input = first_function_input
    self._type_attribute = type_attribute
    self._function_attributes = function_attributes

  def converted_self(self):
    if self._converted_self is None:
      node = super(_FunctionCaller, self).converted_self().node
      converted_names = self._enclosing_graph.converted_function_names
      for attr_name in self._function_attributes:
        attr = node.attr[attr_name]
        if attr.HasField("func"):
          attr.func.name = converted_names[attr.func.name]
        elif attr.HasField("list"):
          for func in attr.list.func:
            func.name = converted_names[func.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    index = incoming_edge.destination.index
    if index >= self._first_function_input:
      node.update_dtype(self._type_attribute,
                        index - self._first_function_input, tensor_data.dtype)

    # The loop below is reasonable but not correct in general:
    # The outgoing edges going into the functions are correct, because the
    # inputs map to the function inputs. But the edges going into other nodes do
    # not take into account the logic of the body function, which may do
    # arbitrary things to the node's output:
    #
    #   while x < 0:
    #     return y
    #
    # In this case, the node's ":0" output may map to its ":1" input. For the
    # time being, then, we only process edges into functions.
    for edge in self.outgoing_edges:
      dest = edge.destination.convertible
      if edge.source.index == index and isinstance(dest, _Function):
        dest.convert_variable_to_constant(edge, tensor_data)

  def create_edges(self):
    """Creates edges related to a function caller.

    Edges from a function caller to its called functions are always edges from
    _inputs_ to _inputs_: a FunctionDef input is given by the caller, based on
    its own inputs.
    """
    super(_FunctionCaller, self).create_edges()
    for attr_name in self._function_attributes:
      attr = self._node.attr[attr_name]
      if attr.HasField("func"):
        function = self._enclosing_graph.functions[attr.func.name]
        for index in range(len(self._node.input) - self._first_function_input):
          self.add_outgoing_edge(
              _Edge(
                  _EndPoint(self, index + self._first_function_input),
                  _EndPoint(function, index)))
      elif attr.HasField("list"):
        for func in attr.list.func:
          function = self._enclosing_graph.functions[func.name]
          for index in range(
              len(self._node.input) - self._first_function_input):
            self.add_outgoing_edge(
                _Edge(
                    _EndPoint(self, index + self._first_function_input),
                    _EndPoint(function, index)))


class _If(_FunctionCaller):
  """Specialization of _Node to If-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_If, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["then_branch", "else_branch"])


class _Case(_FunctionCaller):
  """Specialization of _Node to Case-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_Case, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["branches"])


class _PartitionedCall(_FunctionCaller):
  """Specialization of _Node to PartitionedCall-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_PartitionedCall, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="Tin",
        function_attributes=["f"])


class _While(_FunctionCaller):
  """Specialization of _Node to While-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_While, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="T",
        function_attributes=["body", "cond"])

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data)
    node = self.converted_self()
    if node.node.attr["output_shapes"].list.shape:
      node.node.attr["output_shapes"].list.shape[
          incoming_edge.destination.index].CopyFrom(
              tensor_shape_pb2.TensorShapeProto(dim=[
                  tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
                  for dim in tensor_data.numpy.shape
              ]))

    # The while's body inputs and outputs have the same type, so here we can go
    # ahead and change that function's output type.
    body_name = self._node.attr["body"].func.name
    body = self._enclosing_graph.functions[body_name].converted_self().function
    body.signature.output_arg[
        incoming_edge.destination.index].type = tensor_data.dtype


class _GraphDef(_Convertible):
  """A convertible GraphDef."""

  def __init__(self, graph_def):
    super(_GraphDef, self).__init__(enclosing_graph=None)
    self._graph_def = graph_def
    self._nodes = {
        n.name: _Node.new(node=n, function=None, enclosing_graph=self)
        for n in graph_def.node
    }
    self._functions = {
        f.signature.name: _Function(f, enclosing_graph=self)
        for f in graph_def.library.function
    }
    self.create_edges()
    self._converted_function_names = None

  @property
  def graph_def(self):
    return self._graph_def

  @property
  def nodes(self):
    return self._nodes

  @property
  def functions(self):
    return self._functions

  @property
  def converted_function_names(self):
    """Map from original to new function names.

    In order to avoid conflicts (two functions with the same name, one converted
    and one not), we need to change the name of every converted function to
    something that is hopefully unique.

    Returns:
      Map from original to new suggested function names.
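
      For example (illustrative names only; the numeric suffix comes from
      ops.uid(), so the actual value varies per process):
        {"my_func_3": "my_func_frozen_17"}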
687    """
688    if self._converted_function_names is None:
689      parsed_names = []  # List of (id, base_name, original_name)
690      for name in self.functions:
691        elements = name.rsplit("_", 1)
692        if len(elements) == 2 and elements[1].isnumeric():
693          parsed_names.append((int(elements[1]), elements[0], name))
694        else:
695          parsed_names.append((-1, name, name))
696      self._converted_function_names = {
697          name: "{}_frozen_{}".format(base_name, ops.uid())
698          for (_, base_name, name) in sorted(parsed_names)
699      }
700
701    return self._converted_function_names
702
703  def rename_function(self, old_name, new_name):
704    func = self.functions.pop(old_name)
705    func.function.signature.name = new_name
706    self.functions[new_name] = func
707
708  def converted_self(self):
709    if self._converted_self is None:
710      copied_graph = graph_pb2.GraphDef()
711      copied_graph.CopyFrom(self._graph_def)
712      self._converted_self = _GraphDef(copied_graph)
713    return self._converted_self
714
715  def create_edges(self):
716    for n in self._nodes.values():
717      n.create_edges()
718    for f in self._functions.values():
719      f.create_edges()
720
721
722class _ConverterData(object):
723  """Container for constant conversion supporting data.
724
725  The data includes the graph being converted, and the pre-converted
726  tensors. This class will be specialized for ConcreteFunction and Session-based
727  conversions, as the means to obtain that data is different for each case.
728  """
729
730  def __init__(self,
731               graph_def,
732               variable_names_allowlist=None,
733               variable_names_denylist=None):
734    self._graph_def = graph_def
735    self._tensor_data = {}
736    self._build_node_defs_list()
737    self._variable_names_allowlist = variable_names_allowlist
738    self._variable_names_denylist = variable_names_denylist
739
740  @property
741  def graph_def(self):
742    """The graph to be converted."""
743    return self._graph_def
744
745  @property
746  def node_defs(self):
747    """All the node defs in the graph to be converted.
748
749    Returns:
750      A map from node name to the NodeDef for all NodeDefs in the graph, as well
751      as all control flow NodeDefs in the functions.
752    """
753    return self._node_defs
754
755  @property
756  def tensor_data(self):
757    """A map from tensor name to its converted _TensorData."""
758    return self._tensor_data
759
760  def _should_convert(self, name):
761    """Checks whether to convert the given variable name to a constant."""
762    return (self._variable_names_allowlist is None or
763            name in self._variable_names_allowlist) and (
764                self._variable_names_denylist is None or
765                name not in self._variable_names_denylist)
766
  def _build_node_defs_list(self):
    """Builds the list of NodeDefs in the GraphDef.

    This list consists of all NodeDefs in the main graph as well as all control
    flow NodeDefs in the functions.

    The remaining NodeDefs in the functions are not included because the op
    names are not unique and the variables are handled differently than in the
    main graph. The control flow ops need to be extracted because their
    attributes need to be updated in the same way as those of the control flow
    ops in the main graph.
    """
    self._node_defs = {node.name: node for node in self._graph_def.node}

    if self._graph_def.library:
      for func in self._graph_def.library.function:
        self._node_defs.update({
            node.name: node
            for node in func.node_def
            if node.op in _CONTROL_FLOW_OPS
        })


class _FunctionConverterData(_ConverterData):
  """Container for ConcreteFunction-based conversion data."""

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if function has stateful ops, not
        properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
    """

    self._func = func
    # Inline the graph in order to remove functions when possible.
    graph_def = _run_inline_graph_optimization(func, lower_control_flow,
                                               aggressive_inlining)
    super(_FunctionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    self._build_tensor_data()

  def _eval(self, tensor):
    """Returns the value in the tensor. Must be implemented in sub-classes."""
    raise errors.UnimplementedError(
        "The evaluation method should be implemented in sub-classes.")

  def _build_tensor_data(self):
    """Caches the tensor data for all Placeholders in the given function."""
    map_index_to_variable = {}
    for var in self._func.graph.variables:
      for idx, captured_input in enumerate(self._func.captured_inputs):
        if var.handle is captured_input:  # pylint: disable=protected-access
          map_index_to_variable[idx] = var
          break

    # Iterates through all captures which are represented as Placeholders.
    for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures):
      tensor_name = name_tensor.name.split(":")[0]
      if not self._should_convert(tensor_name):
        continue
      if idx in map_index_to_variable:
        data = self._eval(map_index_to_variable[idx])
      else:
        if val_tensor.dtype == dtypes.resource:
          logging.vlog(1, "Skip converting resource tensor %s" % tensor_name)
          continue
        data = np.array(self._eval(val_tensor))

      self._tensor_data[tensor_name] = _TensorData(
          numpy=data,
          dtype=dtypes.as_dtype(data.dtype).as_datatype_enum,
          index=idx)

    # Get data for VariableV2 ops (reference variables) that cannot be lifted.
    for node in self.node_defs.values():
      if node.op == "VariableV2":
        if not self._should_convert(node.name):
          continue
        if node.name not in self.tensor_data:
          with self._func.graph.as_default():
            identity_node = array_ops.identity(
                self._func.graph.as_graph_element(node.name + ":0"))
          pruned_graph = self._func.prune([], [identity_node.name])()[0]
          self._tensor_data[node.name] = _TensorData(
              numpy=pruned_graph.numpy(),
              dtype=node.attr["dtype"].type,
              index=None)


class _FunctionConverterDataInEager(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Eager mode."""

  def _eval(self, tensor):
    """Returns the value in the tensor, evaluated eagerly."""
    return tensor.numpy()


class _FunctionConverterDataInGraph(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Graph mode."""

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None,
               session=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if function has stateful ops, not
        properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
      session: Session object.
    """
    self._session = session

    session.run(variables.global_variables_initializer())
    # Run extra assignment ops if needed.
    # These assignments are run sequentially to ensure order.
    for op in ops.get_default_graph().get_collection(VAR_ASSIGN_COLLECTION):
      session.run(op)

    super(_FunctionConverterDataInGraph, self).__init__(
        func,
        lower_control_flow,
        aggressive_inlining,
        variable_names_allowlist,
        variable_names_denylist)

  def _eval(self, tensor):
    """Returns the value in the tensor, evaluated in the stored session."""
    return self._session.run(tensor)


class _SessionConverterData(_ConverterData):
  """Container for Session-based conversion data."""

  def __init__(self,
               session,
               graph_def,
               output_node_names,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    super(_SessionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    nodes_to_convert = []
    tensor_names_to_convert = []
    for node in self.graph_def.node:
      if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
        tensor_name = node.name
        if not self._should_convert(tensor_name):
          continue
        if node.op == "VarHandleOp":
          tensor_name = tensor_name + "/Read/ReadVariableOp"
        nodes_to_convert.append(node)
        tensor_names_to_convert.append(tensor_name + ":0")

    if tensor_names_to_convert:
      converted_tensors = session.run(tensor_names_to_convert)
      for node, tensor_value in zip(nodes_to_convert, converted_tensors):
        self._tensor_data[node.name] = _TensorData(
            numpy=tensor_value, dtype=node.attr["dtype"].type, index=None)


def disable_lower_using_switch_merge(graph_def):
  """Sets the '_lower_using_switch_merge' attribute to False.

  Sets the attribute to False in the NodeDefs in the main graph and the
  NodeDefs in each function's graph.

  Args:
    graph_def: GraphDef proto.

  Returns:
    GraphDef
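
  Example (a minimal usage sketch; `graph_def` is any GraphDef proto):

  ```python
  # Keep v2 control flow ops (If/While) intact instead of having Grappler
  # lower them to Switch/Merge ops.
  graph_def = disable_lower_using_switch_merge(graph_def)
  ```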
972  """
973  output_graph_def = graph_pb2.GraphDef()
974  output_graph_def.CopyFrom(graph_def)
975
976  def disable_control_flow_lowering(node):
977    if node.op in _CONTROL_FLOW_OPS:
978      node.attr["_lower_using_switch_merge"].b = False
979
980  for node in output_graph_def.node:
981    disable_control_flow_lowering(node)
982
983  if output_graph_def.library:
984    for func in output_graph_def.library.function:
985      for node in func.node_def:
986        disable_control_flow_lowering(node)
987  return output_graph_def
988
989
def _run_inline_graph_optimization(func, lower_control_flow,
                                   aggressive_inlining):
  """Applies function inlining optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops not
      properly connected to control outputs).

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU) is
  # written to the "api_implements" attribute (e.g. `tf.keras.layers.LSTM` in
  # TF2 produces a CuDNN-based RNN for GPU).
  # This function is supposed to inline all function calls, but "api_implements"
  # prevents this from happening. Removing the attribute solves the problem.
  # To learn more about "api_implements", see:
  #   tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variables collections, since it is not
  # needed after the model is saved to a SavedModel.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def[name].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  if aggressive_inlining:
    rewrite_options.function_optimization = (
        rewriter_config_pb2.RewriterConfig.AGGRESSIVE)
  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction.
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  # Create a ConcreteFunction from the new GraphDef.
  input_tensors = func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  not_converted_inputs = [
      tensor for tensor in func.inputs if tensor not in converted_inputs
  ]
  not_converted_inputs_map = {
      tensor.name: tensor for tensor in not_converted_inputs
  }

  new_input_names = [tensor.name for tensor in not_converted_inputs]
  new_output_names = [tensor.name for tensor in func.outputs]

  # Remove old functions to use updated functions from graph def.
  for f in output_graph_def.library.function:
    if context.context().has_function(f.signature.name):
      context.context().remove_function(f.signature.name)

  new_func = wrap_function.function_from_graph_def(output_graph_def,
                                                   new_input_names,
                                                   new_output_names)

  # Manually propagate shape for input tensors where the shape is not correctly
  # propagated. Scalar shapes are lost when wrapping the function.
  for input_tensor in new_func.inputs:
    input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape)
  return new_func


def _replace_variables_by_constants(converter_data):
  """Replaces variables by constants on a given graph.

  Given a _ConverterData instance with converted variables in its tensor_data
  field, create a new graph where the respective variables are replaced with
  the converted constants.

  Args:
    converter_data: A pre-populated _ConverterData instance.

  Returns:
    A tuple of the converted GraphDef and the set of indices of the inputs
    that were converted to constants.
  """
  input_graph = _GraphDef(converter_data.graph_def)

  for tensor_name, tensor_data in converter_data.tensor_data.items():
    input_graph.nodes[tensor_name].convert_variable_to_constant(
        None, tensor_data)

  converted_graph = input_graph.converted_self().graph_def

  converted_input_indices = {
      t.index
      for t in converter_data.tensor_data.values()
      if t.index is not None
  }

  return converted_graph, converted_input_indices


def convert_variables_to_constants_v2(func,
                                      lower_control_flow=True,
                                      aggressive_inlining=False):
1131  """Replaces all the variables in a graph with constants of the same values.
1132
1133  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
1134  the same values. This makes it possible to describe the network fully with a
1135  single GraphDef file, and allows the removal of a lot of ops related to
1136  loading and saving the variables. This function runs Grappler's function
1137  inlining optimization in order to return a single subgraph.
1138
1139  The current implementation only works for graphs that do not contain any
1140  control flow or embedding related ops.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs). (default False)

  Returns:
    ConcreteFunction containing a simplified version of the original.
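
  Example (a minimal sketch; the model below is illustrative and not part of
  this API, and the function is used under its module-level name):

  ```python
  import tensorflow as tf

  w = tf.Variable([[2.0], [3.0]])

  @tf.function(input_signature=[tf.TensorSpec([None, 2], tf.float32)])
  def model(x):
    return tf.matmul(x, w)  # Captures `w` as a variable input.

  frozen = convert_variables_to_constants_v2(model.get_concrete_function())
  # `frozen` embeds the value of `w` as a Const op and captures no variables.
  print(frozen(tf.constant([[1.0, 1.0]])))
  ```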
1152  """
1153
1154  converter_data = _FunctionConverterDataInEager(
1155      func=func,
1156      lower_control_flow=lower_control_flow,
1157      aggressive_inlining=aggressive_inlining)
1158
1159  output_graph_def, converted_input_indices = _replace_variables_by_constants(
1160      converter_data=converter_data)
1161
1162  return _construct_concrete_function(func, output_graph_def,
1163                                      converted_input_indices)
1164
1165
1166def convert_var_to_const_function_in_v1(func,
1167                                        lower_control_flow=True,
1168                                        aggressive_inlining=False):
1169  """Replaces all the variables in a graph with constants of the same values.
1170
  This function works the same as convert_variables_to_constants_v2, but it
  should be used in Graph mode. It is a temporary solution when users want to
  integrate their models written in TF2 with infra that requires TF1 mode.

  The current implementation only works for graphs that do not contain any
  control flow or embedding-related ops.

  The function must be called in a Session context.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs). (default False)

  Raises:
    RuntimeError: If no Session context is present.

  Returns:
    ConcreteFunction containing a simplified version of the original.
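
  Example (a minimal sketch, assuming TF1 graph mode, e.g. after calling
  `tf.compat.v1.disable_eager_execution()`; the model is illustrative):

  ```python
  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  v = tf.Variable(3.0)

  @tf.function(input_signature=[tf.TensorSpec([], tf.float32)])
  def model(x):
    return x + v

  with tf.compat.v1.Session():
    # Variable initialization is handled by the conversion itself.
    frozen = convert_var_to_const_function_in_v1(
        model.get_concrete_function())
  ```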
1193  """
1194
1195  session = ops.get_default_session()
1196  if session is None:
1197    raise RuntimeError(
1198        "The conversion must be carried out in a Session context.")
1199
1200  converter_data = _FunctionConverterDataInGraph(
1201      func=func,
1202      lower_control_flow=lower_control_flow,
1203      aggressive_inlining=aggressive_inlining,
1204      session=session)
1205
1206  output_graph_def, converted_input_indices = _replace_variables_by_constants(
1207      converter_data=converter_data)
1208
1209  return _construct_concrete_function(func, output_graph_def,
1210                                      converted_input_indices)
1211
1212
1213def convert_variables_to_constants_v2_as_graph(func,
1214                                               lower_control_flow=True,
1215                                               aggressive_inlining=False):
1216  """Replaces all the variables in a graph with constants of the same values.
1217
  This function works the same as convert_variables_to_constants_v2, but it
  returns the intermediate `GraphDef` as well. This `GraphDef` contains all the
  debug information after all the transformations in the freezing phase.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs).

  Returns:
    A tuple of a ConcreteFunction containing a simplified version of the
    original and the intermediate GraphDef containing the node debug
    information for the transformations in the freezing phase.
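
  Example (a minimal sketch; `model` is any tf.function producing a
  ConcreteFunction, as in the `convert_variables_to_constants_v2` example):

  ```python
  frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(
      model.get_concrete_function())
  # The intermediate GraphDef can be inspected or serialized, e.g. with
  # tf.io.write_graph(graph_def, "/tmp", "frozen.pbtxt").
  ```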
1234  """
1235  converter_data = _FunctionConverterDataInEager(
1236      func=func,
1237      lower_control_flow=lower_control_flow,
1238      aggressive_inlining=aggressive_inlining)
1239
1240  output_graph_def, converted_input_indices = _replace_variables_by_constants(
1241      converter_data=converter_data)
1242
1243  frozen_func = _construct_concrete_function(func, output_graph_def,
1244                                             converted_input_indices)
1245  return frozen_func, output_graph_def
1246
1247
1248def convert_variables_to_constants_from_session_graph(
1249    session,
1250    graph_def,
1251    output_node_names,
1252    variable_names_allowlist=None,
1253    variable_names_denylist=None):
1254  """Replaces all the variables in a graph with constants of the same values.
1255
  This function works similarly to convert_variables_to_constants_v2, but it
  retrieves the constant values from a Session instead of from a
  ConcreteFunction. This is useful when converting graphs generated from
  TensorFlow V1, where ConcreteFunctions are not available. This also differs
  from graph_util.convert_variables_to_constants in that it supports resource
  variables when V2 control flow constructs are present.

  Args:
    session: Active TensorFlow session containing the variables.
    graph_def: A GraphDef to convert.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_allowlist: The set of variable names to convert (by default,
      all variables are converted).
    variable_names_denylist: The set of variable names to omit converting to
      constants.

  Returns:
    An optimized GraphDef.
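
  Example (a minimal sketch, assuming TF1 graph mode; the tiny graph is
  illustrative):

  ```python
  import tensorflow as tf

  tf.compat.v1.disable_eager_execution()
  v = tf.Variable(3.0, name="v")
  out = tf.identity(v * 2.0, name="out")

  with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    frozen_graph_def = convert_variables_to_constants_from_session_graph(
        sess, sess.graph_def, output_node_names=["out"])
  ```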
1274  """
1275  graph_def, _ = _replace_variables_by_constants(
1276      converter_data=_SessionConverterData(
1277          session=session,
1278          graph_def=graph_def,
1279          output_node_names=output_node_names,
1280          variable_names_allowlist=variable_names_allowlist,
1281          variable_names_denylist=variable_names_denylist))
1282  return graph_def
1283