# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to convert variables to constants in TensorFlow 2.0."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.saver import export_meta_graph
from tensorflow.python.util import lazy_loader
from tensorflow.python.util import object_identity

# Lazy load the single eager module to avoid introducing new dependencies for
# graph_util:convert_variables_to_constants (e.g. in
# tensorflow/contrib/session_bundle:session_bundle_py_test).
wrap_function = lazy_loader.LazyLoader(
    "wrap_function", globals(),
    "tensorflow.python.eager.wrap_function")

# Used in _FunctionConverterDataInGraph().
VAR_ASSIGN_COLLECTION = "extra_var_assign_ops"
_CONDITIONAL_OPS = set(["If", "StatelessIf"])
_LOOP_OPS = set(["While", "StatelessWhile"])
_CONTROL_FLOW_OPS = _CONDITIONAL_OPS.union(_LOOP_OPS)


class _TensorData(
    collections.namedtuple("_TensorData", ["numpy", "dtype", "index"])):
  """Data about a tensor that was converted to a constant."""
  __slots__ = ()

  @property
  def dtype_attr(self):
    return attr_value_pb2.AttrValue(type=self.dtype)


class _EndPoint(collections.namedtuple("_EndPoint", ["convertible", "index"])):
  """An endpoint in a graph."""
  __slots__ = ()

  def __str__(self):
    return "{}[{}]".format(self.convertible, self.index)


class _Edge(collections.namedtuple("_Edge", ["source", "destination"])):
  """A directed graph edge."""
  __slots__ = ()

  def __str__(self):
    return "{} -> {}".format(self.source, self.destination)


class _Convertible(object):
  """An entity that can have variables converted to constants."""

  def __init__(self, enclosing_graph):
    self._enclosing_graph = enclosing_graph
    self._outgoing_edges = []
    self._converted_self = None

  def converted_self(self):
    """A copy of this Convertible to be modified during conversion.

    Returns:
      Implementations should return the copied instance, which in turn should
      be contained in converted_enclosing_graph(). This instance is the one that
      will be modified during conversion. Its main use will be in the
      implementations of convert_variable_to_constant().
    """
    raise NotImplementedError()

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts a variable in this Convertible and its dependencies.

    This method should make sure that a converted copy of itself is present in
    the converted graph, and that all Convertibles depending on this one also go
    through the same process.

    Args:
      incoming_edge: The graph edge into this Convertible that is being
        converted to a constant.
      tensor_data: The tensor representing the constant.
    """
    raise NotImplementedError()

  def create_edges(self):
    """Calls add_outgoing_edge for all edges known to this Convertible.

    This is used to build the graph dependencies, so that conversion of
    variables to constants can be properly propagated through the graph.
    Usually this method will call add_outgoing_edge() for each of the
    Convertible's inputs.
    """
    raise NotImplementedError()

  def add_outgoing_edge(self, edge):
    """Adds an outgoing edge to the Convertible's list of edges.

    Args:
      edge: The outgoing edge (its source should be 'self').
    """
    self._outgoing_edges.append(edge)

  @property
  def converted_enclosing_graph(self):
    """The graph being converted."""
    return self._enclosing_graph.converted_self()

  @property
  def outgoing_edges(self):
    """The list of edges starting at this Convertible."""
    return self._outgoing_edges


class _Function(_Convertible):
  """A library function Convertible.

  Edges into functions are edges from node _inputs_ into function _inputs_:
  Functions get their input from their callers, not from node outputs, and the
  callers in turn get those values as inputs.
  """

  def __init__(self, function, enclosing_graph):
    super(_Function, self).__init__(enclosing_graph)
    self._function = function
    self._nodes = {
        n.name:
        _Node.new(node=n, function=self, enclosing_graph=enclosing_graph)
        for n in function.node_def
    }

  def __str__(self):
    return self.function.signature.name

  @property
  def function(self):
    return self._function

  @property
  def nodes(self):
    return self._nodes

  def converted_self(self):
    """The Function copy to be converted.

    The copy will be renamed according to the graph's converted_function_names
    map, to ensure the name does not match anything currently in TensorFlow's
    function cache.

    Returns:
      The function instance to be converted.
    """
    if self._converted_self is None:
      old_name = self.function.signature.name
      new_name = self._enclosing_graph.converted_function_names[old_name]
      self.converted_enclosing_graph.rename_function(old_name, new_name)
      self._converted_self = self.converted_enclosing_graph.functions[new_name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    """Converts one function argument into a constant.

    Args:
      incoming_edge: The edge into the argument to be converted.
      tensor_data: The constant value.
    """
    function = self.converted_self().function
    index = incoming_edge.destination.index
    function.signature.input_arg[index].type = tensor_data.dtype

    for edge in self.outgoing_edges:
      if edge.source.index == index:
        edge.destination.convertible.convert_variable_to_constant(
            edge, tensor_data)

  def create_edges(self):
    for n in self._nodes.values():
      n.create_edges()


class _Node(_Convertible):
  """A Convertible NodeDef."""

  def __init__(self, node, function, enclosing_graph):
    super(_Node, self).__init__(enclosing_graph)
    self._node = node
    self._function = function

  def __str__(self):
    return self._node.name

  @staticmethod
  def new(node, function, enclosing_graph):
    """Creates a new _Node based on its operation type."""
    if node.op in ["VariableV2", "VarHandleOp", "Placeholder"]:
      return _VarHandle(node, function, enclosing_graph)
    elif node.op == "Case":
      return _Case(node, function, enclosing_graph)
    elif node.op == "Merge":
      return _Merge(node, function, enclosing_graph)
    elif node.op == "PartitionedCall":
      return _PartitionedCall(node, function, enclosing_graph)
    elif node.op == "StatefulPartitionedCall":
      return _PartitionedCall(node, function, enclosing_graph)
    elif node.op == "ReadVariableOp":
      return _ReadVariable(node, function, enclosing_graph)
    elif node.op == "ResourceGather":
      return _ResourceGather(node, function, enclosing_graph)
    elif node.op == "ResourceGatherNd":
      return _ResourceGatherNd(node, function, enclosing_graph)
    elif node.op in ["If", "StatelessIf"]:
      return _If(node, function, enclosing_graph)
    elif node.op in ["While", "StatelessWhile"]:
      return _While(node, function, enclosing_graph)
    elif node.op in [
        "Enter", "Exit", "Identity", "NextIteration", "Switch", "_SwitchN"]:
      return _Intermediate(node, function, enclosing_graph)
    else:
      return _Node(node, function, enclosing_graph)

  @property
  def node(self):
    return self._node

  @property
  def container(self):
    """The node container (either a graph or a function)."""
    if self._function is not None:
      return self._function.function
    return self._enclosing_graph.graph_def

  def converted_self(self):
    """The NodeDef to be converted.

    Returns:
      The NodeDef to be converted, which can come from either a graph or a
      function. Derived classes should call this (via 'super') to make sure the
      node is retrieved from the right place.
    """
    if self._converted_self is None:
      source = self._function or self._enclosing_graph
      self._converted_self = source.converted_self().nodes[self._node.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    pass

  def create_edges(self):
    for index, name in enumerate(self._node.input):
      # Discard edges from control inputs.
      if name[0] == "^":
        continue
      source = self.resolve_input(name)
      source.convertible.add_outgoing_edge(
          _Edge(source, _EndPoint(self, index)))

  def resolve_input(self, input_name):
    """Resolves an input into its _EndPoint.

    A NodeDef's input name can refer to either global NodeDefs (in the
    GraphDef's node list), a NodeDef in a function's node list, or a Function
    (in the GraphDef's function library). The name can also carry semantic
    information, depending on whether it starts with "^". This method handles
    all that logic to find the object to which the input name refers. For
    example, "foo:1" refers to output 1 of the node "foo", while "^foo"
    denotes a control dependency on "foo".

    Args:
      input_name: The input name to resolve.

    Returns:
      The object referred to by 'input_name'.
    """

    # The logic below oversimplifies the semantics, but is good enough for the
    # purposes of converting to constants. The introduction of new types of
    # operations may change this, forcing the code to be more generic.
    #
    # In particular, we are assuming that the lack of an index suffix means
    # ":0", when it could mean "all the outputs of a node." This works now
    # because converting to constants relies very little on output types, and
    # when it does it specializes its treatment in dedicated classes.
    name_elts = input_name.split(":")
    source_name = name_elts[0]
    if source_name[0] == "^":
      source_name = source_name[1:]
    source_index = 0
    if len(name_elts) > 1 and name_elts[-1].isnumeric():
      source_index = int(name_elts[-1])

    if self._function is None:
      return _EndPoint(self._enclosing_graph.nodes[source_name], source_index)

    if source_index != 0 or source_name in self._function.nodes:
      return _EndPoint(self._function.nodes[source_name], source_index)

    inputs = [i.name for i in self._function.function.signature.input_arg]
    return _EndPoint(self._function, inputs.index(source_name))

  def update_dtype(self, attr_name, index, dtype):
    """Changes the type of a given input.

    Args:
      attr_name: The NodeDef attribute containing the type to change.
      index: The index of the input type to change.
      dtype: The type to change to.
    """
    attr = self._node.attr[attr_name]
    num_types = 0
    # Check for various 'oneof' possibilities, and update the type if
    # index in range.
    if attr.HasField("list"):
      types = attr.list.type
      num_types = len(types)
      if num_types > index:
        types[index] = dtype
        return
    elif attr.HasField("type"):
      num_types = 1
      if index == 0:
        attr.type = dtype
        return
    raise ValueError(
        "Index %d out of range for node(%s).attr(%s), which has %d elements." %
        (index, self._node.name, attr_name, num_types))


class _Intermediate(_Node):
  """Specialization of _Node to intermediate ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    node.update_dtype("T", incoming_edge.destination.index, tensor_data.dtype)
    if "_output_shapes" in node.node.attr:
      del node.node.attr["_output_shapes"]
    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _Merge(_Node):
  """Specialization of _Node to Merge ops."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # The Merge operation has a single type for all its inputs, the number of
    # which is reflected in the "N" attribute. For the time being, we assume
    # that unilaterally changing all of them at once is ok.
    super(_Merge, self).convert_variable_to_constant(
        _Edge(incoming_edge.source,
              _EndPoint(incoming_edge.destination.convertible, 0)),
        tensor_data)


class _VarHandle(_Node):
  """Specialization of _Node to VarHandleOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    tensor_proto = tensor_util.make_tensor_proto(tensor_data.numpy,
                                                 tensor_data.dtype,
                                                 tensor_data.numpy.shape)

    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Const"
    node.attr["dtype"].CopyFrom(tensor_data.dtype_attr)
    node.attr["value"].tensor.CopyFrom(tensor_proto)

    for edge in self.outgoing_edges:
      edge.destination.convertible.convert_variable_to_constant(
          edge, tensor_data)


class _ResourceGather(_Node):
  """Specialization of _Node to ResourceGather."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    # We currently skip the conversion if this is inside a function.
    if self._function is not None:
      return
    if self._node.attr["batch_dims"].i != 0:
      raise ValueError("batch_dims != 0 is not supported by freeze_graph.")
    axis_node_name = self._node.name + "/axis"
    axis_dtype = self._node.attr["Tindices"]
    axis_data = np.array(self._node.attr["batch_dims"].i)
    output_axis_node = self.converted_self().container.node.add()
    output_axis_node.name = axis_node_name
    output_axis_node.op = "Const"
    output_axis_node.attr["dtype"].CopyFrom(axis_dtype)
    tensor = tensor_util.make_tensor_proto(
        axis_data, dtype=axis_dtype.type, shape=axis_data.shape)
    output_axis_node.attr["value"].tensor.CopyFrom(tensor)

    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherV2"
    output_node.input.extend(
        [self._node.input[0], self._node.input[1], axis_node_name])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    output_node.attr["Taxis"].CopyFrom(axis_dtype)
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ResourceGatherNd(_Node):
  """Specialization of _Node to ResourceGatherNd."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    output_node = self.converted_self().node
    output_node.Clear()
    output_node.name = self._node.name
    output_node.op = "GatherNd"
    output_node.input.extend([self._node.input[0], self._node.input[1]])
    output_node.attr["Tparams"].CopyFrom(self._node.attr["dtype"])
    output_node.attr["Tindices"].CopyFrom(self._node.attr["Tindices"])
    if "_class" in self._node.attr:
      output_node.attr["_class"].CopyFrom(self._node.attr["_class"])


class _ReadVariable(_Node):
  """Specialization of _Node to ReadVariableOp."""

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self().node
    node.Clear()
    node.name = self._node.name
    node.op = "Identity"

    node.input.append(self._node.input[0])
    node.attr["T"].CopyFrom(self._node.attr["dtype"])
    if "_class" in self._node.attr:
      node.attr["_class"].CopyFrom(self._node.attr["_class"])

    # If the ReadVariableOp is part of a function, then every node that has it
    # as an input will refer to it using a ":value" suffix. We need to change
    # that to ":output".
    if self._function is not None:
      for edge in self.outgoing_edges:
        index = edge.destination.index
        dest = edge.destination.convertible.converted_self()
        if isinstance(dest, _Node):
          input_name_parts = dest.node.input[index].split(":")
          if len(input_name_parts) > 1 and input_name_parts[1] == "value":
            input_name_parts[1] = "output"
            dest.node.input[index] = ":".join(input_name_parts)


class _FunctionCaller(_Node):
  """A base class for Convertibles that reference functions."""

  def __init__(self, node, function, enclosing_graph, first_function_input,
               type_attribute, function_attributes):
    """Initializes a _FunctionCaller.

    Args:
      node: As in _Node.
      function: As in _Node.
      enclosing_graph: As in _Node.
      first_function_input: The index of the first NodeDef input that is tied to
        the function inputs. It is assumed that the rest of the NodeDef inputs
        map one-to-one to function inputs.
      type_attribute: The name of the NodeDef attribute that defines the input
        types. It is assumed that the types listed here map one-to-one with the
        function inputs (that is, they do _not_ specify types for inputs that
        are not passed to functions).
      function_attributes: The names of the NodeDef attributes containing
        references to functions.
    """
    super(_FunctionCaller, self).__init__(node, function, enclosing_graph)
    self._first_function_input = first_function_input
    self._type_attribute = type_attribute
    self._function_attributes = function_attributes

  def converted_self(self):
    if self._converted_self is None:
      node = super(_FunctionCaller, self).converted_self().node
      converted_names = self._enclosing_graph.converted_function_names
      for attr_name in self._function_attributes:
        attr = node.attr[attr_name]
        if attr.HasField("func"):
          attr.func.name = converted_names[attr.func.name]
        elif attr.HasField("list"):
          for func in attr.list.func:
            func.name = converted_names[func.name]
    return self._converted_self

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    node = self.converted_self()
    index = incoming_edge.destination.index
    if index >= self._first_function_input:
      node.update_dtype(self._type_attribute,
                        index - self._first_function_input, tensor_data.dtype)

    # The loop below is reasonable but not correct in general:
    # The outgoing edges going into the functions are correct, because the
    # inputs map to the function inputs. But the edges going into other nodes do
    # not take into account the logic of the body function, which may do
    # arbitrary things to the node's output:
    #
    #   while x < 0:
    #     return y
    #
    # In this case, the node's ":0" output may map to its ":1" input. For the
    # time being, then, we only process edges into functions.
    for edge in self.outgoing_edges:
      dest = edge.destination.convertible
      if edge.source.index == index and isinstance(dest, _Function):
        dest.convert_variable_to_constant(edge, tensor_data)

  def create_edges(self):
    """Creates edges related to a function caller.

    Edges from a function caller to its called functions are always edges from
    _inputs_ to _inputs_: a FunctionDef input is given by the caller, based on
    its own inputs.
    """
    super(_FunctionCaller, self).create_edges()
    for attr_name in self._function_attributes:
      attr = self._node.attr[attr_name]
      if attr.HasField("func"):
        function = self._enclosing_graph.functions[attr.func.name]
        for index in range(len(self._node.input) - self._first_function_input):
          self.add_outgoing_edge(
              _Edge(
                  _EndPoint(self, index + self._first_function_input),
                  _EndPoint(function, index)))
      elif attr.HasField("list"):
        for func in attr.list.func:
          function = self._enclosing_graph.functions[func.name]
          for index in range(
              len(self._node.input) - self._first_function_input):
            self.add_outgoing_edge(
                _Edge(
                    _EndPoint(self, index + self._first_function_input),
                    _EndPoint(function, index)))


class _If(_FunctionCaller):
  """Specialization of _Node to If-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_If, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["then_branch", "else_branch"])


class _Case(_FunctionCaller):
  """Specialization of _Node to Case-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_Case, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=1,
        type_attribute="Tin",
        function_attributes=["branches"])


class _PartitionedCall(_FunctionCaller):
  """Specialization of _Node to PartitionedCall-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_PartitionedCall, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="Tin",
        function_attributes=["f"])


class _While(_FunctionCaller):
  """Specialization of _Node to While-like operations."""

  def __init__(self, node, function, enclosing_graph):
    super(_While, self).__init__(
        node,
        function,
        enclosing_graph,
        first_function_input=0,
        type_attribute="T",
        function_attributes=["body", "cond"])

  def convert_variable_to_constant(self, incoming_edge, tensor_data):
    super(_While, self).convert_variable_to_constant(incoming_edge, tensor_data)
    node = self.converted_self()
    if node.node.attr["output_shapes"].list.shape:
      node.node.attr["output_shapes"].list.shape[
          incoming_edge.destination.index].CopyFrom(
              tensor_shape_pb2.TensorShapeProto(dim=[
                  tensor_shape_pb2.TensorShapeProto.Dim(size=dim)
                  for dim in tensor_data.numpy.shape
              ]))

    # The while's body inputs and outputs have the same type, so here we can go
    # ahead and change that function's output type.
    body_name = self._node.attr["body"].func.name
    body = self._enclosing_graph.functions[body_name].converted_self().function
    body.signature.output_arg[
        incoming_edge.destination.index].type = tensor_data.dtype


class _GraphDef(_Convertible):
  """A convertible GraphDef."""

  def __init__(self, graph_def):
    super(_GraphDef, self).__init__(enclosing_graph=None)
    self._graph_def = graph_def
    self._nodes = {
        n.name: _Node.new(node=n, function=None, enclosing_graph=self)
        for n in graph_def.node
    }
    self._functions = {
        f.signature.name: _Function(f, enclosing_graph=self)
        for f in graph_def.library.function
    }
    self.create_edges()
    self._converted_function_names = None

  @property
  def graph_def(self):
    return self._graph_def

  @property
  def nodes(self):
    return self._nodes

  @property
  def functions(self):
    return self._functions

  @property
  def converted_function_names(self):
    """Map from original to new function names.

    In order to avoid conflicts (two functions with the same name, one converted
    and one not), we need to change the name of every converted function to
    something that is hopefully unique.

    Returns:
      Map from original to new suggested function names.
    """
    if self._converted_function_names is None:
      parsed_names = []  # List of (id, base_name, original_name)
      for name in self.functions:
        elements = name.rsplit("_", 1)
        if len(elements) == 2 and elements[1].isnumeric():
          parsed_names.append((int(elements[1]), elements[0], name))
        else:
          parsed_names.append((-1, name, name))
      self._converted_function_names = {
          name: "{}_frozen_{}".format(base_name, ops.uid())
          for (_, base_name, name) in sorted(parsed_names)
      }

    return self._converted_function_names

  def rename_function(self, old_name, new_name):
    func = self.functions.pop(old_name)
    func.function.signature.name = new_name
    self.functions[new_name] = func

  def converted_self(self):
    if self._converted_self is None:
      copied_graph = graph_pb2.GraphDef()
      copied_graph.CopyFrom(self._graph_def)
      self._converted_self = _GraphDef(copied_graph)
    return self._converted_self

  def create_edges(self):
    for n in self._nodes.values():
      n.create_edges()
    for f in self._functions.values():
      f.create_edges()


class _ConverterData(object):
  """Container for constant conversion supporting data.

  The data includes the graph being converted, and the pre-converted
  tensors. This class will be specialized for ConcreteFunction and Session-based
  conversions, as the means to obtain that data is different for each case.
  """

  def __init__(self,
               graph_def,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    self._graph_def = graph_def
    self._tensor_data = {}
    self._build_node_defs_list()
    self._variable_names_allowlist = variable_names_allowlist
    self._variable_names_denylist = variable_names_denylist

  @property
  def graph_def(self):
    """The graph to be converted."""
    return self._graph_def

  @property
  def node_defs(self):
    """All the node defs in the graph to be converted.

    Returns:
      A map from node name to the NodeDef for all NodeDefs in the graph, as well
      as all control flow NodeDefs in the functions.
    """
    return self._node_defs

  @property
  def tensor_data(self):
    """A map from tensor name to its converted _TensorData."""
    return self._tensor_data

  def _should_convert(self, name):
    """Checks whether to convert the given variable name to a constant."""
    return (self._variable_names_allowlist is None or
            name in self._variable_names_allowlist) and (
                self._variable_names_denylist is None or
                name not in self._variable_names_denylist)

  def _build_node_defs_list(self):
    """Builds the list of NodeDefs in the GraphDef.

    This list consists of all NodeDefs in the main graph as well as all control
    flow NodeDefs in the functions.

    The remaining NodeDefs in the functions are not included because the op
    names are not unique, and the variables are handled differently than in the
    main graph. The control flow ops need to be extracted because their
    attributes must be updated in the same way as those of the control flow ops
    in the main graph.
    """
    self._node_defs = {node.name: node for node in self._graph_def.node}

    if self._graph_def.library:
      for func in self._graph_def.library.function:
        self._node_defs.update({
            node.name: node
            for node in func.node_def
            if node.op in _CONTROL_FLOW_OPS
        })


class _FunctionConverterData(_ConverterData):
  """Container for ConcreteFunction-based conversion data."""

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if function has stateful ops, not
        properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
    """

    self._func = func
    # Inline the graph in order to remove functions when possible.
    graph_def = _run_inline_graph_optimization(func, lower_control_flow,
                                               aggressive_inlining)
    super(_FunctionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    self._build_tensor_data()

  def _eval(self, tensor):
    """Returns the value in the tensor. Must be implemented in sub-classes."""
    raise errors.UnimplementedError(
        None, None,
        "The evaluation method should be implemented in sub-classes.")

  def _build_tensor_data(self):
    """Caches the tensor data for all Placeholders in the given function."""
    map_index_to_variable = {}
    for var in self._func.graph.variables:
      for idx, captured_input in enumerate(self._func.captured_inputs):
        if var.handle is captured_input:  # pylint: disable=protected-access
          map_index_to_variable[idx] = var
          break

    # Iterates through all captures which are represented as Placeholders.
    for idx, (val_tensor, name_tensor) in enumerate(self._func.graph.captures):
      tensor_name = name_tensor.name.split(":")[0]
      if not self._should_convert(tensor_name):
        continue
      if idx in map_index_to_variable:
        data = self._eval(map_index_to_variable[idx])
      else:
        if val_tensor.dtype == dtypes.resource:
          logging.vlog(1, "Skip converting resource tensor %s" % tensor_name)
          continue
        data = np.array(self._eval(val_tensor))

      self._tensor_data[tensor_name] = _TensorData(
          numpy=data,
          dtype=dtypes.as_dtype(data.dtype).as_datatype_enum,
          index=idx)

    # Get data for VariableV2 ops (reference variables) that cannot be lifted.
    for node in self.node_defs.values():
      if node.op == "VariableV2":
        if not self._should_convert(node.name):
          continue
        if node.name not in self.tensor_data:
          with self._func.graph.as_default():
            identity_node = array_ops.identity(
                self._func.graph.as_graph_element(node.name + ":0"))
          pruned_graph = self._func.prune([], [identity_node.name])()[0]
          self._tensor_data[node.name] = _TensorData(
              numpy=pruned_graph.numpy(),
              dtype=node.attr["dtype"].type,
              index=None)


class _FunctionConverterDataInEager(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Eager mode."""

  def _eval(self, tensor):
    """Returns the value in the tensor, evaluated eagerly."""
    return tensor.numpy()


class _FunctionConverterDataInGraph(_FunctionConverterData):
  """Container for ConcreteFunction-based conversion data in Graph mode."""

  def __init__(self,
               func,
               lower_control_flow,
               aggressive_inlining,
               variable_names_allowlist=None,
               variable_names_denylist=None,
               session=None):
    """Creates the conversion data for the given function.

    Args:
      func: ConcreteFunction.
      lower_control_flow: Boolean indicating whether or not to lower control
        flow ops such as If and While.
      aggressive_inlining: Boolean indicating whether or not to do aggressive
        function inlining (might be unsafe if function has stateful ops, not
        properly connected to control outputs).
      variable_names_allowlist: The set of variable names to convert (by
        default, all variables are converted).
      variable_names_denylist: The set of variable names to omit converting to
        constants.
      session: Session object.
    """
    self._session = session

    session.run(variables.global_variables_initializer())
    # Run extra assignment ops if needed.
    # These assignments are run sequentially to ensure order.
    for op in ops.get_default_graph().get_collection(VAR_ASSIGN_COLLECTION):
      session.run(op)

    super(_FunctionConverterDataInGraph, self).__init__(
        func,
        lower_control_flow,
        aggressive_inlining,
        variable_names_allowlist,
        variable_names_denylist)

  def _eval(self, tensor):
    """Returns the value in the tensor, evaluated via the session."""
    return self._session.run(tensor)


class _SessionConverterData(_ConverterData):
  """Container for Session-based conversion data."""

  def __init__(self,
               session,
               graph_def,
               output_node_names,
               variable_names_allowlist=None,
               variable_names_denylist=None):
    graph_def = graph_util.extract_sub_graph(graph_def, output_node_names)
    super(_SessionConverterData, self).__init__(
        graph_def,
        variable_names_allowlist=variable_names_allowlist,
        variable_names_denylist=variable_names_denylist)

    nodes_to_convert = []
    tensor_names_to_convert = []
    for node in self.graph_def.node:
      if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
        tensor_name = node.name
        if not self._should_convert(tensor_name):
          continue
        if node.op == "VarHandleOp":
          tensor_name = tensor_name + "/Read/ReadVariableOp"
        nodes_to_convert.append(node)
        tensor_names_to_convert.append(tensor_name + ":0")

    if tensor_names_to_convert:
      converted_tensors = session.run(tensor_names_to_convert)
      for node, tensor_value in zip(nodes_to_convert, converted_tensors):
        self._tensor_data[node.name] = _TensorData(
            numpy=tensor_value, dtype=node.attr["dtype"].type, index=None)


def disable_lower_using_switch_merge(graph_def):
  """Set '_lower_using_switch_merge' attributes to False.

  Sets the attribute to False in the NodeDefs in the main graph and the NodeDefs
  in each function's graph.

  Args:
    graph_def: GraphDef proto.

  Returns:
    GraphDef
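
  Example (a minimal sketch; `func` here is an assumed ConcreteFunction, not
  part of this module):

    graph_def = func.graph.as_graph_def()
    # Control flow ops in the returned copy keep their functional form
    # instead of being lowered to Switch/Merge ops.
    graph_def = disable_lower_using_switch_merge(graph_def)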
  """
  output_graph_def = graph_pb2.GraphDef()
  output_graph_def.CopyFrom(graph_def)

  def disable_control_flow_lowering(node):
    if node.op in _CONTROL_FLOW_OPS:
      node.attr["_lower_using_switch_merge"].b = False

  for node in output_graph_def.node:
    disable_control_flow_lowering(node)

  if output_graph_def.library:
    for func in output_graph_def.library.function:
      for node in func.node_def:
        disable_control_flow_lowering(node)
  return output_graph_def


def _run_inline_graph_optimization(func, lower_control_flow,
                                   aggressive_inlining):
  """Apply function inline optimization to the graph.

  Returns the GraphDef after Grappler's function inlining optimization is
  applied. This optimization does not work on models with control flow.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops not
      properly connected to control outputs).

  Returns:
    GraphDef
  """
  graph_def = func.graph.as_graph_def()
  if not lower_control_flow:
    graph_def = disable_lower_using_switch_merge(graph_def)

  # In some cases, a secondary implementation of the function (e.g. for GPU) is
  # written to the "api_implements" attribute (e.g. `tf.keras.layers.LSTM` in
  # TF2 produces a CuDNN-based RNN for GPU). This function is supposed to
  # inline all function calls, but the "api_implements" attribute prevents
  # that from happening. Removing the attribute solves the problem. To learn
  # more about "api_implements", see:
  #   tensorflow/core/grappler/optimizers/implementation_selector.h
  for function in graph_def.library.function:
    if "api_implements" in function.attr:
      del function.attr["api_implements"]

  meta_graph = export_meta_graph(graph_def=graph_def, graph=func.graph)

  # Clear the initializer_name for the variable collections, since it is not
  # needed once the model is saved.
  for name in [
      "variables", "model_variables", "trainable_variables", "local_variables"
  ]:
    raw_list = []
    for raw in meta_graph.collection_def[name].bytes_list.value:
      variable = variable_pb2.VariableDef()
      variable.ParseFromString(raw)
      variable.ClearField("initializer_name")
      raw_list.append(variable.SerializeToString())
    meta_graph.collection_def[name].bytes_list.value[:] = raw_list

  # Add a collection 'train_op' so that Grappler knows the outputs.
  fetch_collection = meta_graph_pb2.CollectionDef()
  for array in func.inputs + func.outputs:
    fetch_collection.node_list.value.append(array.name)
  meta_graph.collection_def["train_op"].CopyFrom(fetch_collection)

  # Initialize RewriterConfig with everything disabled except function inlining.
  config = config_pb2.ConfigProto()
  rewrite_options = config.graph_options.rewrite_options
  rewrite_options.min_graph_nodes = -1  # do not skip small graphs
  rewrite_options.optimizers.append("function")
  if aggressive_inlining:
    rewrite_options.function_optimization = (
        rewriter_config_pb2.RewriterConfig.AGGRESSIVE)
  return tf_optimizer.OptimizeGraph(config, meta_graph)


def _construct_concrete_function(func, output_graph_def,
                                 converted_input_indices):
  """Constructs a concrete function from the `output_graph_def`.

  Args:
    func: ConcreteFunction
    output_graph_def: GraphDef proto.
    converted_input_indices: Set of integers of input indices that were
      converted to constants.

  Returns:
    ConcreteFunction.
  """
  # Create a ConcreteFunction from the new GraphDef.
  input_tensors = func.graph.internal_captures
  converted_inputs = object_identity.ObjectIdentitySet(
      [input_tensors[index] for index in converted_input_indices])
  not_converted_inputs = [
      tensor for tensor in func.inputs if tensor not in converted_inputs
  ]
  not_converted_inputs_map = {
      tensor.name: tensor for tensor in not_converted_inputs
  }

  new_input_names = [tensor.name for tensor in not_converted_inputs]
  new_output_names = [tensor.name for tensor in func.outputs]

  # Remove old functions to use updated functions from graph def.
  for f in output_graph_def.library.function:
    if context.context().has_function(f.signature.name):
      context.context().remove_function(f.signature.name)

  new_func = wrap_function.function_from_graph_def(output_graph_def,
                                                   new_input_names,
                                                   new_output_names)

  # Manually propagate shapes for input tensors where the shape is not
  # correctly propagated. Scalar shapes are lost when wrapping the function.
  for input_tensor in new_func.inputs:
    input_tensor.set_shape(not_converted_inputs_map[input_tensor.name].shape)
  return new_func


def _replace_variables_by_constants(converter_data):
  """Replaces variables with constants in a given graph.

  Given a _ConverterData instance with converted variables in its tensor_data
  field, create a new graph where the respective variables are replaced with the
  converted constants.

  Args:
    converter_data: A pre-populated _ConverterData instance.

  Returns:
    The converted GraphDef, and the set of indices of the inputs that were
    converted to constants.
  """
  input_graph = _GraphDef(converter_data.graph_def)

  for tensor_name, tensor_data in converter_data.tensor_data.items():
    input_graph.nodes[tensor_name].convert_variable_to_constant(
        None, tensor_data)

  converted_graph = input_graph.converted_self().graph_def

  converted_input_indices = {
      t.index
      for t in converter_data.tensor_data.values()
      if t.index is not None
  }

  return converted_graph, converted_input_indices


def convert_variables_to_constants_v2(func,
                                      lower_control_flow=True,
                                      aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  TensorFlow 2.0 function for converting all Variable ops into Const ops holding
  the same values. This makes it possible to describe the network fully with a
  single GraphDef file, and allows the removal of a lot of ops related to
  loading and saving the variables. This function runs Grappler's function
  inlining optimization in order to return a single subgraph.

  The current implementation only works for graphs that do not contain any
  control flow or embedding-related ops.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs). (default False)

  Returns:
    ConcreteFunction containing a simplified version of the original.
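
  Example (a minimal sketch of typical usage; the module and variable below
  are illustrative assumptions, not part of this API):

    import tensorflow as tf

    class Model(tf.Module):

      def __init__(self):
        self.v = tf.Variable([1.0, 2.0])

      @tf.function(input_signature=[tf.TensorSpec([2], tf.float32)])
      def __call__(self, x):
        return x + self.v

    m = Model()
    concrete_func = m.__call__.get_concrete_function()
    frozen_func = convert_variables_to_constants_v2(concrete_func)
    # The variable reads are now Const ops; the frozen function computes the
    # same values as the original.
    print(frozen_func(tf.constant([3.0, 4.0])))  # => [4.0, 6.0]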
  """

  converter_data = _FunctionConverterDataInEager(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining)

  output_graph_def, converted_input_indices = _replace_variables_by_constants(
      converter_data=converter_data)

  return _construct_concrete_function(func, output_graph_def,
                                      converted_input_indices)


def convert_var_to_const_function_in_v1(func,
                                        lower_control_flow=True,
                                        aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  This function works the same as convert_variables_to_constants_v2, but it
  is meant to be used in Graph mode. It is a temporary solution for users who
  want to integrate their models written in TF2 with infra that requires TF1
  mode.

  The current implementation only works for graphs that do not contain any
  control flow or embedding-related ops.

  The function must be called in a Session context.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs). (default False)

  Raises:
    RuntimeError: If no Session context is present.

  Returns:
    ConcreteFunction containing a simplified version of the original.
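
  Example (a sketch; `concrete_func` is an assumed ConcreteFunction obtained
  as in convert_variables_to_constants_v2):

    with tf.compat.v1.Session():
      # The enclosing Session becomes the default session required here.
      frozen_func = convert_var_to_const_function_in_v1(concrete_func)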
  """

  session = ops.get_default_session()
  if session is None:
    raise RuntimeError(
        "The conversion must be carried out in a Session context.")

  converter_data = _FunctionConverterDataInGraph(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining,
      session=session)

  output_graph_def, converted_input_indices = _replace_variables_by_constants(
      converter_data=converter_data)

  return _construct_concrete_function(func, output_graph_def,
                                      converted_input_indices)


def convert_variables_to_constants_v2_as_graph(func,
                                               lower_control_flow=True,
                                               aggressive_inlining=False):
  """Replaces all the variables in a graph with constants of the same values.

  This function works the same as convert_variables_to_constants_v2, but it
  returns the intermediate `GraphDef` as well. This `GraphDef` contains all the
  debug information after all the transformations in the frozen phase.

  Args:
    func: ConcreteFunction.
    lower_control_flow: Boolean indicating whether or not to lower control flow
      ops such as If and While. (default True)
    aggressive_inlining: Boolean indicating whether or not to do aggressive
      function inlining (might be unsafe if function has stateful ops, not
      properly connected to control outputs).

  Returns:
    ConcreteFunction containing a simplified version of the original, and also
    the intermediate GraphDef containing the node debug information for the
    transformations in the frozen phase.
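
  Example (a sketch; `concrete_func` is an assumed ConcreteFunction obtained
  as in convert_variables_to_constants_v2):

    frozen_func, graph_def = convert_variables_to_constants_v2_as_graph(
        concrete_func)
    # graph_def is the intermediate GraphDef, useful e.g. for serialization
    # or for inspecting the frozen nodes.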
  """
  converter_data = _FunctionConverterDataInEager(
      func=func,
      lower_control_flow=lower_control_flow,
      aggressive_inlining=aggressive_inlining)

  output_graph_def, converted_input_indices = _replace_variables_by_constants(
      converter_data=converter_data)

  frozen_func = _construct_concrete_function(func, output_graph_def,
                                             converted_input_indices)
  return frozen_func, output_graph_def


def convert_variables_to_constants_from_session_graph(
    session,
    graph_def,
    output_node_names,
    variable_names_allowlist=None,
    variable_names_denylist=None):
  """Replaces all the variables in a graph with constants of the same values.

  This function works similarly to convert_variables_to_constants_v2, but it
  retrieves the constant values from a Session instead of from a
  ConcreteFunction. This is useful when converting graphs generated from
  TensorFlow V1, where ConcreteFunctions are not available. This also differs
  from graph_util.convert_variables_to_constants in that it supports resource
  variables when V2 control flow constructions are present.

  Args:
    session: Active TensorFlow session containing the variables.
    graph_def: A GraphDef to convert.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_allowlist: The set of variable names to convert (by default,
      all variables are converted).
    variable_names_denylist: The set of variable names to omit converting to
      constants.

  Returns:
    An optimized GraphDef.
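
  Example (a minimal TF1-style sketch; the variable and output names are
  illustrative assumptions):

    with tf.compat.v1.Session() as sess:
      v = tf.compat.v1.get_variable("v", initializer=tf.constant(1.0))
      tf.identity(v, name="output")
      sess.run(tf.compat.v1.global_variables_initializer())
      frozen_graph_def = convert_variables_to_constants_from_session_graph(
          session=sess, graph_def=sess.graph_def,
          output_node_names=["output"])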
  """
  # TODO(b/176982859): Find a more satisfying way to update shape information
  # than clearing it, or migrate users to a workflow that does not require
  # freezing.
  for function in graph_def.library.function:
    if "_input_shapes" in function.attr:
      for input_arg, shape_attribute in zip(
          function.signature.input_arg,
          function.attr["_input_shapes"].list.shape):
        if dtypes.as_dtype(input_arg.type) == dtypes.resource:
          shape_attribute.unknown_rank = True
  graph_def, _ = _replace_variables_by_constants(
      converter_data=_SessionConverterData(
          session=session,
          graph_def=graph_def,
          output_node_names=output_node_names,
          variable_names_allowlist=variable_names_allowlist,
          variable_names_denylist=variable_names_denylist))
  return graph_def