1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Define tflite op hints (intrinsic operations).
16
17This essentially allows defining a TensorFlow API for tflite operations in
18Python with hints on how they are represented in TensorFlow Lite. This basically
19is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution
20graph and is useful for LSTMs and other complicated TensorFlow constructions
21that are difficult to pattern match in TOCO, but are represented by a single
22accelerated tflite op.
23
24Example:
25  def tflite_cool_activation(input):
26    # A cool activation function.
27    custom = tf.lite.OpHint("cool_activation")
28    input, = custom.add_inputs(input)
29    output = tf.sigmoid(input) * input
30    output, = custom.add_outputs(output)
31    return output
32
33  image = tf.compat.v1.placeholder(tf.float32, (1, 16, 16, 1))
34  output = tf.identity(tflite_cool_activation(image))
35
36  session = tf.compat.v1.Session()
37
38  graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session)
39  tflite_graph = tf.compat.v1.lite.toco_convert(
40      graphdef_to_convert, [image], [output], allow_custom_ops=True)
41  with open("/tmp/graph.fb", "wb") as fp:
42    fp.write(tflite_graph)
43
44How does it work?
45
46OpHint is a helper that you use when defining a vanilla python function.
47It allows you to wrap arguments with tf.identities with some custom attributes.
48These attributes allow you to find the original block of ops that was created.
49For example, if you use cool_activation above you essentially get:
50
51a_input = tf.identity()
52result = tf.multiply(tf.sigmoid(a_input), a_input)
53output = tf.identity()
54
55a_input and output are identities that carry attributes recording which
56argument they represent, the name of the function they should turn into in
57TF Lite, and a GUID that uniquely identifies a particular invocation.
58
59Once you have built your whole TensorFlow graph, you can run and train it as
60usual. After that, you need to convert the graph into a form that replaces
61these identity-wrapped subgraphs with stub ops. These ops don't actually exist
62in the normal TensorFlow runtime, but will be understood by toco later. The
63generated TensorFlow Lite flatbuffer file will contain a custom operator called
64"cool_activation". The developer needs to implement and register this operator
65in TensorFlow Lite in order to run inference.
66"""
67
68import collections as _collections
69import copy as _copy
70import json as _json
71import uuid as _uuid
72
73from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
74from tensorflow.core.framework import graph_pb2 as _graph_pb2
75from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
76from tensorflow.python.framework import dtypes as _dtypes
77from tensorflow.python.framework import ops as _ops
78from tensorflow.python.framework import tensor_util as _tensor_util
79from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
80from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
81from tensorflow.python.ops import array_ops as _array_ops
82from tensorflow.python.util import compat as _compat
83from tensorflow.python.util import deprecation as _deprecation
84from tensorflow.python.util.all_util import remove_undocumented
85from tensorflow.python.util.tf_export import tf_export as _tf_export
86
87
88@_tf_export(v1=["lite.OpHint"])
89@_deprecation.deprecated(
90    None,
91    "Please follow instructions under "
92    "https://www.tensorflow.org/lite/convert/operation_fusion for operation"
93    "fusion in tflite."
94)
95class OpHint:
96  """A class that helps build tflite function invocations.
97
98  It allows you to take a bunch of TensorFlow ops and annotate the construction
99  such that toco knows how to convert it to tflite. This embeds a pseudo
100  function in a TensorFlow graph. This allows embedding high-level API usage
101  information in a lower level TensorFlow implementation so that an alternative
102  implementation can be substituted later.
103
104  Essentially, any "input" into this pseudo op is fed into an identity, and
105  attributes are added to that input before being used by the constituent ops
106  that make up the pseudo op. A similar process is done to any output that
107  is to be exported from the current op.
108
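  A minimal sketch of typical usage (`scale_by_two` is a hypothetical user
  function, not part of this module):

    def scale_by_two(x):
      hint = OpHint("scale_by_two")
      x, = hint.add_inputs(x)
      y = x * 2.0
      y, = hint.add_outputs(y)
      return y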
109  """
110  # Attr constants that are used for representation in the GraphDef. These
111  # will be used on every Identity op that is involved in a total OpHint.
112
113  # Name of the OpHint function (cosmetic).
114  FUNCTION_NAME_ATTR = "_tflite_function_name"
115  # UUID of the function (each OpHint gets a new uuid).
116  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
117  # The input index of the input (or nothing if it is an output).
118  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
119  # The output index of the output (or nothing if it is an input).
120  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
121  # An index that orders aggregate arguments. Aggregate arguments are ones
122  # that are separate but will be fused horizontally. For example a static LSTM
123  # has an LSTM cell for each time step. Each one has a separate OpHint, but a
124  # fused SequentialLSTM will treat these as a single tensor.
125  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
126  # The way in which multiple parts of the aggregate argument will be joined
127  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
128  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
129  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
130  # On the fused OpHint stub, the order of inputs that the final LSTM call
131  # will have. What this means is that the TensorFlow order might be
132  # "foo", "bar", "stuff" and you might want the TF lite op order to be
133  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
134  # attribute to [2, 0, 1, -1].
135  TFLITE_INPUT_INDICES = "_tflite_input_indices"
136  # OpHint level.
137  FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
138  # OpHint internal mapping; this is for high-level OpHints only.
139  # This basically contains three kinds of mapping:
140  #   1) How parental ophinted inputs map to the first child ophinted inputs;
141  #   2) How internal children nodes are connected;
142  #   3) How parental ophinted outputs map to the last child ophinted outputs.
143  CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"
144
145  # Types of aggregations
146  #  stack: stacks all ophints with matching tags. i.e. for a static rnn.
147  #   specifically, this is good for an input or output to a static rnn cell.
148  AGGREGATE_STACK = "stack"
149  # first: only takes the first output (one with lowest sort index)
150  # of matching tags. This is good for the input state to an RNN.
151  AGGREGATE_FIRST = "first"
152  # last: takes only the last tag (the one with highest sort index).
153  # This is good for an output value on the last stack item of a
154  # static rnn.
155  AGGREGATE_LAST = "last"
156
157  class OpHintArgumentTracker:
158    """Conceptually tracks indices of arguments of "OpHint functions".
159
160    The inputs and outputs of these functions each use an instance
161    of this class so they can have independent numbering.
162    """
163
164    def __init__(self,
165                 function_name,
166                 unique_function_id,
167                 node_name_prefix,
168                 attr_name,
169                 level=1,
170                 children_inputs_mappings=None):
171      """Initialize ophint argument.
172
173      Args:
174        function_name: Name of the function that this tracks arguments for.
175        unique_function_id: UUID of function that this tracks arguments for.
176        node_name_prefix: How identities that are created are named.
177        attr_name: Name of attribute to use to store the index for this hint.
178          i.e. FUNCTION_INPUT_INDEX_ATTR or FUNCTION_OUTPUT_INDEX_ATTR
179        level: Hierarchical level of the Ophint node, a number.
180        children_inputs_mappings: Inputs/Outputs mapping for children hints.
181      """
182
183      # The global index is the argument index of the op. This is in contrast
184      # to the sort index which is the sequence number of a particular instance
185      # of a given global index. For example, you may have called add() twice
186      # with the tag "foo". Then the global index will be 0 for both
187      # and the sort index will be 0 for the first added and 1 for the second.
188      self._function_name = function_name
189      self._unique_function_id = unique_function_id
190      self._next_global_index = 0  # The absolute global index
191      self._used_global_indices = set()
192      self._tag_to_global_index = {}  # The argument index a given tag maps to
193      self._tag_to_next_sort_index = {}  # The current index for each tag
194      self._node_name_prefix = node_name_prefix
195      self._attr_name = attr_name
196      self._level = level
197      self._children_inputs_mappings = children_inputs_mappings
198
199    def _get_new_global_index(self, index_override):
200      """Return the next unused argument index in order or use an override.
201
202      Args:
203        index_override: An index to use instead of the next available or None
204          to use the next available.
205
206      Returns:
207        A valid global_index to use for the next hint argument.
208
209      Raises:
210        ValueError: If the index_override is already used by another hint.
211      """
212      if index_override is None:
213        global_index = self._next_global_index
214      else:
215        if index_override in self._used_global_indices:
216          raise ValueError("Index %d was already used by another call to "
                              "add." % index_override)
217        global_index = index_override
218      # Make next_global_index valid
219      self._used_global_indices.add(global_index)
220      while self._next_global_index in self._used_global_indices:
221        self._next_global_index += 1
222      return global_index
223
224    def add(self, arg, tag=None, name=None, aggregate=None,
225            index_override=None):
226      """Return a wrapped tensor of an input tensor as an argument.
227
228      Args:
229        arg: A TensorFlow tensor that should be considered an argument.
230        tag: String tag to identify arguments that should be packed.
231        name: Name of argument. This is included in the Identity hint op names.
232        aggregate: Strategy to aggregate. Acceptable values are
233          OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, and
234          OpHint.AGGREGATE_STACK. Note, aggregate is only valid if tag is
235          specified (see the example below).
236        index_override: Specify what input/output index this argument should
237          have in the final stub. I.e. add(arg0, index_override=1); add(arg1,
238          index_override=0) makes the final stub take [arg1, arg0] as inputs
239          rather than the default call-order based ordering.
240
241      Returns:
242        A tensor representing the wrapped argument.
243
244      Raises:
245        ValueError: When indices are not consistent.
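
      Example (an illustrative sketch; `tracker`, `t0`, `t1` and `t2` are
      assumed to be an OpHintArgumentTracker and three per-timestep tensors
      from a static rnn, none of which are defined here):

        t0 = tracker.add(t0, tag="cell_out", aggregate=OpHint.AGGREGATE_STACK)
        t1 = tracker.add(t1, tag="cell_out", aggregate=OpHint.AGGREGATE_STACK)
        t2 = tracker.add(t2, tag="cell_out", aggregate=OpHint.AGGREGATE_STACK)
        # All three share global index 0; their sort indices are 0, 1 and 2.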
246      """
247
248      # Find the appropriate index
249      if tag is None:
250        if aggregate is not None:
251          raise ValueError("You must specify `tag` if using aggregate.")
252        global_index = self._get_new_global_index(index_override)
253        sort_index = None
254      else:
255        if aggregate is None:
256          raise ValueError("You must specify `aggregate` if using tag.")
257        if tag not in self._tag_to_global_index:
258          self._tag_to_global_index[tag] = (
259              self._get_new_global_index(index_override))
260          self._tag_to_next_sort_index[tag] = 0
261        elif (index_override and
262              index_override != self._tag_to_global_index[tag]):
263          raise ValueError(
264              "Tag %r was called with two indices %r and %r" %
265              (tag, index_override, self._tag_to_global_index[tag]))
266        global_index = self._tag_to_global_index[tag]
267        sort_index = self._tag_to_next_sort_index[tag]
268        self._tag_to_next_sort_index[tag] += 1
269
270      uuid = self._unique_function_id
271      name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
272                                    uuid, global_index, sort_index, name)
273
274      identity_op = _array_ops.identity(arg, name=name)
275
276      # pylint: disable=protected-access
277      identity_op.op._set_attr(
278          OpHint.FUNCTION_NAME_ATTR,
279          _attr_value_pb2.AttrValue(
280              s=_compat.as_bytes(self._function_name)))
281      identity_op.op._set_attr(
282          OpHint.FUNCTION_UUID_ATTR,
283          _attr_value_pb2.AttrValue(
284              s=_compat.as_bytes(self._unique_function_id)))
285      identity_op.op._set_attr(
286          self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
287      identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
288                               _attr_value_pb2.AttrValue(i=self._level))
289      if self._children_inputs_mappings:
290        identity_op.op._set_attr(
291            OpHint.CHILDREN_INPUTS_MAPPINGS,
292            _attr_value_pb2.AttrValue(
293                s=_compat.as_bytes(_json.dumps(
294                    self._children_inputs_mappings))))
295
296      if sort_index is not None:
297        identity_op.op._set_attr(
298            OpHint.FUNCTION_SORT_INDEX_ATTR,
299            _attr_value_pb2.AttrValue(i=sort_index))
300      if aggregate is not None:
301        identity_op.op._set_attr(
302            OpHint.FUNCTION_AGGREGATE_ATTR,
303            _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate))))
304      # pylint: enable=protected-access
305      return identity_op
306
307  def __init__(self,
308               function_name,
309               level=1,
310               children_inputs_mappings=None,
311               **kwargs):
312    """Create a OpHint.
313
314    Args:
315      function_name: Name of the function (the custom op name in tflite)
316      level: OpHint level.
317      children_inputs_mappings: Children OpHint inputs/outputs mapping.
318        children_inputs_mappings should look like below (a concrete example
319        follows the Args list):
320        "parent_first_child_input": [{"parent_ophint_input_index": num,
321            "first_child_ophint_input_index": num}, ...]
322        "parent_last_child_output":
323            [{"parent_output_index": num, "child_output_index": num}, ...]
324        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
325      **kwargs: Keyword arguments of any constant attributes for the function.
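
    A concrete sketch of the expected mapping structure (the indices here are
    purely illustrative):

      children_inputs_mappings = {
          "parent_first_child_input": [
              {"parent_ophint_input_index": 0,
               "first_child_ophint_input_index": 0}],
          "internal_children_input_output": [
              {"child_input_index": 0, "child_output_index": 0}],
          "parent_last_child_output": [
              {"parent_output_index": 0, "child_output_index": 0}],
      }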
326    """
327    self._function_name = function_name
328    self._level = level
329    if self._level == 1:
330      assert children_inputs_mappings is None
331    else:
332      assert isinstance(children_inputs_mappings, dict)
333    self._children_inputs_mappings = children_inputs_mappings
334    if self._children_inputs_mappings is not None:
335      self._validate_children_inputs_mappings(self._children_inputs_mappings)
336    self._unique_function_id = _uuid.uuid1().hex
337    self._attrs_to_store_later = kwargs
338    self._stored_attrs = False
339    self._inputs = OpHint.OpHintArgumentTracker(
340        self._function_name, self._unique_function_id, "InputHint",
341        OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
342    self._outputs = OpHint.OpHintArgumentTracker(
343        self._function_name, self._unique_function_id, "OutputHint",
344        OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
345        self._children_inputs_mappings)
346
347  def _validate_children_inputs_mappings(self, children_inputs_mappings):
348    """Validate children inputs mappings is in the right format.
349
350    Args:
351      children_inputs_mappings: the Children ophint inputs/outputs mapping.
352    """
353    assert isinstance(children_inputs_mappings, dict)
354    assert "parent_first_child_input" in children_inputs_mappings
355    assert "parent_last_child_output" in children_inputs_mappings
356    assert "internal_children_input_output" in children_inputs_mappings
357
358    # validate parent_first_child_input.
359
360    def assert_dictlist_has_keys(dictlist, keys):
361      for dikt in dictlist:
362        assert isinstance(dikt, dict)
363        for key in keys:
364          assert key in dikt
365
366    assert_dictlist_has_keys(
367        children_inputs_mappings["parent_first_child_input"],
368        ["parent_ophint_input_index", "first_child_ophint_input_index"])
369    assert_dictlist_has_keys(
370        children_inputs_mappings["parent_last_child_output"],
371        ["parent_output_index", "child_output_index"])
372    assert_dictlist_has_keys(
373        children_inputs_mappings["internal_children_input_output"],
374        ["child_input_index", "child_output_index"])
375
376  def _setattr(self, dest_op, name, value):
377    tensor_value = _ops.convert_to_tensor(value)
378    # pylint: disable=protected-access
379    dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
380        tensor=tensor_value.op.node_def.attr["value"].tensor))
381    # pylint: enable=protected-access
382
383  def add_input(self, *args, **kwargs):
384    """Add a wrapped input argument to the hint.
385
386    Args:
387      *args: The input tensor.
388      **kwargs:
389        "name" label
390        "tag" a tag to group multiple arguments that will be aggregated. I.e.
391          a string like 'cool_input'. Basically multiple inputs can be added
392          to the same hint for parallel operations that will eventually be
393          combined. An example would be static_rnn which creates multiple copies
394          of state or inputs.
395        "aggregate" aggregation strategy that is valid only for tag non None.
396          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
397          and OpHint.AGGREGATE_STACK.
398        "index_override" The global index to use. This corresponds to the
399          argument order in the final stub that will be generated.
400    Returns:
401      The wrapped input tensor.
402    """
403    return self._inputs.add(*args, **kwargs)
404
405  def add_output(self, *args, **kwargs):
406    """Add a wrapped output argument to the hint.
407
408    Args:
409      *args: The output tensor.
410      **kwargs:
411        "name" label
412        "tag" a tag to group multiple arguments that will be aggregated. I.e.
413          a string like 'cool_input'. Basically multiple inputs can be added
414          to the same hint for parallel operations that will eventually be
415          combined. An example would be static_rnn which creates multiple copies
416          of state or inputs.
417        "aggregate" aggregation strategy that is valid only for tag non None.
418          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
419          and OpHint.AGGREGATE_STACK.
420        "index_override" The global index to use. This corresponds to the
421          argument order in the final stub that will be generated.
422    Returns:
423      The wrapped output tensor.
424    """
425    return self._outputs.add(*args, **kwargs)
426
427  def add_inputs(self, *args, **kwargs):
428    """Add a sequence of inputs to the function invocation.
429
430    Args:
431      *args: List of inputs to be converted (should be tf.Tensor).
432      **kwargs: This allows 'names' which should be a list of names.
433
434    Returns:
435      Wrapped inputs (identity standins that have additional metadata). These
436      are also tf.Tensor's.
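
    A minimal usage sketch (assumes `a` and `b` are existing tf.Tensors and
    `hint` is an OpHint instance):

      a, b = hint.add_inputs(a, b, names=["a", "b"])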
437    """
438    if "names" in kwargs:
439      return [
440          self._inputs.add(arg, name=name)
441          for arg, name in zip(args, kwargs["names"])
442      ]
443    else:
444      return [self._inputs.add(arg) for arg in args]
445
446  def add_outputs(self, *args, **kwargs):
447    """Add a sequence of outputs to the function invocation.
448
449    Args:
450      *args: List of outputs to be converted (should be tf.Tensor).
451      **kwargs: This allows 'names' which should be a list of names.
452
453    Returns:
454      Wrapped outputs (identity standins that have additional metadata). These
455      are also tf.Tensor's.
456    """
457    if "names" in kwargs:
458      return [
459          self._outputs.add(arg, name=name)
460          for arg, name in zip(args, kwargs["names"])
461      ]
462    else:
463      return [self._outputs.add(arg) for arg in args]
464
465
466class _LiteOperand:
467  """Abstract operand for a tflite hint function._dynamic_rnn_loop.
468
469  This is a base class that handles representing arguments to an OpHint.
470  It also is able to serialize operands to the stubbed graph_def.
471  Child classes are responsible for being able to
472  store information about the hint identity operators. They are also responsible
473  for knowing how to serialize to output graphdefs.
474
475  Typically this will be implemented by holding one or more identity nodes
476  that were previously discovered as hints.
477  """
478
479  def aggregate_and_return_name_for_input(self, out_graphdef):
480    """This adds the node(s) to out_graphdef and returns the input node name.
481
482    Args:
483      out_graphdef: A graphdef that is ready to have this input added.
484
485    Returns:
486      The output that the stub should use as an input for this operand.
487
488    Raises:
489      RuntimeError: if the method is not implemented.
490    """
491    del out_graphdef
492    raise RuntimeError("Unimplemented abstract method.")
493
494  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
495                                           out_graphdef):
496    """Add node(s) to graph representing output operands and returns type.
497
498    Args:
499      fused_op_name: name of the fused op stub name.
500      output_index: Output index that we are currently processing from stub.
501      out_graphdef: The destination graphdef we are currently building up.
502
503    Returns:
504      The datatype of this identity.
505
506    Raises:
507      RuntimeError: if the method is not implemented.
508    """
509    del fused_op_name, output_index, out_graphdef
510    raise RuntimeError("Unimplemented abstract method.")
511
512
513class _LiteSingleOperand(_LiteOperand):
514  """A simple operand that is non-aggregated (i.e. most hints)."""
515
516  def __init__(self, node):
517    _LiteOperand.__init__(self)
518    self.node = node
519    self.name = _tensor_name_base(node.name)
520
521  def flatten(self):
522    return [self.name]
523
524  def aggregate_and_return_name_for_input(self, out_graphdef):
525    return self.name
526
527  def aggregate_and_return_name_for_output(self, fused_op_name, index,
528                                           out_graphdef):
529    output_node = _copy.deepcopy(self.node)
530    del output_node.input[:]
531    output_node.input.append(_tensorflow_output_name(fused_op_name, index))
532    out_graphdef.node.extend([output_node])
533    return self.node.attr["T"].type
534
535  def __str__(self):
536    return str(self.name)
537
538
539class _LiteAggregateOperand(_LiteOperand):
540  """An operand for a tflite hint function that is aggregated from many.
541
542  For example, an LSTM is a grid of operators that are all related. Inputs
543  going into them may need to be fused, so they should all be tracked as
544  related arguments.
545  """
546
547  def __init__(self, aggregation):
548    _LiteOperand.__init__(self)
549    self.aggregation = aggregation
550    self.names = {}
551    self.nodes = {}
552    self.flattened = None
553
554  def add(self, sort, node):
555    self.names[sort] = _tensor_name_base(node.name)
556    self.nodes[sort] = node
557
558  def flatten_nodes(self):
559    """Return a list of all the node protos in aggregation sorted order."""
560    if not self.flattened:
561      self.flattened = [None] * len(self.nodes)
562      for idx, node in self.nodes.items():
563        self.flattened[idx] = node
564      for n in self.flattened:
565        if n is None:
566          raise RuntimeError("Aggregate was missing argument.")
567      if self.aggregation == OpHint.AGGREGATE_FIRST:
568        self.flattened = self.flattened[:1]
569      elif self.aggregation == OpHint.AGGREGATE_LAST:
570        self.flattened = self.flattened[-1:]
571      elif self.aggregation == OpHint.AGGREGATE_STACK:
572        pass
573      else:
574        raise ValueError("Invalid aggregation type %r specified" %
575                         self.aggregation)
576    return self.flattened
577
578  def flatten(self):
579    """Return a list of all node names in aggregation sorted sorter."""
580    return [_tensor_name_base(x.name) for x in self.flatten_nodes()]
581
582  def aggregate_and_return_name_for_input(self, out_graphdef):
583    """This adds the nodes to out_graphdef and returns an aggregated output.
584
585    In particular, if you have 4 inputs to a hint stub, this will be the
586    node whose output you feed to the stub. I.e. if you have 4 timesteps
587    from a static rnn, then a fused UnidirectionalLSTM will expect 1 input
588    with all 4 time steps, so here we make a Pack and return the output
589    name of that Pack.
590
591    Args:
592      out_graphdef: A graphdef that is ready to have this input added.
593
594    Returns:
595      The name of a pack that aggregates this node.
596    """
597    flattened = self.flatten_nodes()
598    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
599        self.aggregation == OpHint.AGGREGATE_LAST):
600      assert len(flattened) == 1
601    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
602      return _tensor_name_base(flattened[0].name)
603    else:
604      new_node = _node_def_pb2.NodeDef()
605      new_node.op = "Pack"
606      new_node.name = "OpHintStack-%s" % flattened[0].name
607      new_node.attr["N"].i = len(flattened)
608      new_node.attr["T"].type = flattened[0].attr["T"].type
609      for discrete in flattened:
610        new_node.input.append(_tensor_name_base(discrete.name))
611      out_graphdef.node.extend([new_node])
612      return new_node.name
613
614  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
615                                           out_graphdef):
616    """This adds to `out_graphdef` all the unaggregated outputs.
617
618    I.e. we are outputting from a fused stub, but we need to make it compatible
619    with the unfused original graph so we insert an unpack. Ideally in a later
620    stage the unpack -> pack sequences will be removed.
621
622    Args:
623      fused_op_name: The name of the stub we are in the process of fusing.
624      output_index: The output index this object represents.
625      out_graphdef: The graphdef we are in the process of building.
626
627    Returns:
628      The type of the aggregated output (so we can finish building the stub
629      op).
630    """
631    flattened = self.flatten_nodes()
632    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
633        self.aggregation == OpHint.AGGREGATE_LAST):
634      assert len(flattened) == 1
635    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
636      temp_op = _LiteSingleOperand(flattened[0])
637      return temp_op.aggregate_and_return_name_for_output(
638          fused_op_name, output_index, out_graphdef)
639    else:
640      stack_node = _node_def_pb2.NodeDef()
641      stack_node.op = "Unpack"
642      stack_node.name = "OpHintUnstack-%s" % flattened[0].name
643      stack_node.attr["num"].i = len(flattened)
644      output_type = flattened[0].attr["T"].type
645      stack_node.attr["T"].type = output_type
646      stack_node.input.append(
647          _tensorflow_output_name(fused_op_name, output_index))
648      out_graphdef.node.extend([stack_node])
649
650      for idx, discrete in enumerate(flattened):
651        output_node = _copy.deepcopy(discrete)
652        del output_node.input[:]
653        output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
654        out_graphdef.node.extend([output_node])
655
656      return output_type
657
658  def __str__(self):
659    s = "\t\t\tAGGREGATE %s\n" % self.aggregation
660    for sort, val in self.names.items():
661      s += "\t\t\t%d: %s\n" % (sort, val)
662    return s
663
664
665class _LiteFuncCall:
666  """Represent a TensorFlow Lite custom function.
667
668  This is used to accumulate found hints in the graphdef into a single
669  conceptual unit.
670
671  Attributes:
672    inputs: inputs to the op (hash from index # to argument)
673    outputs: outputs to the op (hash from index # to argument)
674    function_name: the tflite custom op name to use
675    uuid: a unique call id for this particular call (i.e. multiple function
676      calls would have the same function_name but different uuids).
677    params: A map from parameter name to value, for op constant data. I.e. for
678      axis on a reduction, strides on a convolution, etc.
679    level: Level of the OpHint.
680    children_inputs_mappings: If the Ophint has children, children inputs
681      mappings indicate how their inputs & outputs are mapped.
682  """
683
684  def __init__(self):
685    self.inputs = {}
686    self.outputs = {}
687    self.function_name = None
688    self.uuid = None
689    self.params = {}
690    self.level = -1
691    self.children_inputs_mappings = {}
692
693  def flattened_inputs_and_outputs(self):
694    """Return a list of inputs and outputs in a flattened format.
695
696    Returns:
697      Tuple of (inputs, outputs), where each is a list of names.
698    """
699
700    def _flatten(input_or_output_dict):
701      flattened_items = []
702      for item in input_or_output_dict.values():
703        flattened_items.extend(item.flatten())
704      return flattened_items
705
706    return _flatten(self.inputs), _flatten(self.outputs)
707
708  def __str__(self):
709
710    def format_args(items):
711      s = ""
712      for idx, item in items.items():
713        s += ("\t\t%d:\n" % idx) + str(item)
714      return s
715
716    inputs_str = "\tInputs\n" + format_args(self.inputs)
717    outputs_str = "\tOutputs\n" + format_args(self.outputs)
718
719    return (
720        "tflite function %s call %s level %d "
721        "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
722        (self.function_name, self.uuid, self.level, inputs_str, outputs_str))
723
724
725def _find_all_hints_in_nodes(nodes):
726  """Look at the all the input nodes and return a list of LiteFuncCall objs.
727
728  Args:
729    nodes: A TensorFlow graph_def to look for LiteFuncCalls.
730
731  Returns:
732    a list of `LifeFuncCall` objects in the form
733
734  """
735  func_calls = _collections.defaultdict(_LiteFuncCall)
736
737  for node in nodes:
738    attr = node.attr
739    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
740    if (OpHint.FUNCTION_UUID_ATTR not in attr or
741        not attr[OpHint.FUNCTION_UUID_ATTR].s):
742      continue
743    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s
744
745    # Start building function
746    call_def = func_calls[uuid]
747    call_def.uuid = uuid
748    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
749    call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
750    # Get sorting and aggregation information
751
752    sort = (
753        attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
754        if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
755    if sort == -1:
756      sort = None
757    aggregation = None
758    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
759      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)
760
761    if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
762      call_def.children_inputs_mappings = _json.loads(
763          _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))
764
765    # Add the input or output
766    def put_operand(stuff, index, sort, operand, aggregation):
767      """Add a given index into the function structure."""
768      if sort is None:
769        stuff[index] = _LiteSingleOperand(operand)
770      else:
771        if index not in stuff:
772          stuff[index] = _LiteAggregateOperand(aggregation)
773        stuff[index].add(sort, operand)
774
775    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
776      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
777                  sort, node, aggregation)
778    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
779      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
780                  sort, node, aggregation)
781
782    # Remember attributes
783    for a in attr:
784      if a.startswith("_tflite_attr_"):
785        call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor
786
787  return func_calls
788
789
790def _extract_topology_sequence_mapping(nodes):
791  return dict(
792      (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))
793
794
795def _find_children_hints_in_while_loop(function_def, nodes_mapping):
796  """Find children hints and all nodes inside the while loop.
797
798  Args:
799    function_def: Function def of the while loop.
800    nodes_mapping: Dict mapping while loop input_arg names to real node names.
801
802  Returns:
803    Ordered children hints and all re-mapped nodes inside the while loop.
804  """
805  new_nodes = []
806
807  # Make nodes inside function def inputs point to the real nodes.
808  for node in function_def.node_def:
809    for i, _ in enumerate(node.input):
810      if node.input[i] in nodes_mapping:
811        node.input[i] = nodes_mapping[node.input[i]]
812    new_nodes.append(_copy.deepcopy(node))
813  name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
814  children_hints = _find_all_hints_in_nodes(new_nodes)
815  children_hints_q = []
816  # Ordered by the outputs.
817  for hint in children_hints.values():
818    _, output_names = hint.flattened_inputs_and_outputs()
819    seq = name_to_seq_num[output_names[0]]
820    for output_name in output_names:
821      seq = min(seq, name_to_seq_num[output_name])
822    children_hints_q.append((seq, hint))
823  children_hints_q.sort(key=lambda tup: tup[0])
824  ordered_children_hints = [x[1] for x in children_hints_q]
825  return ordered_children_hints, new_nodes
826
827
828def _find_children_hints(call, graph_def):
829  """Find all children hints.
830
831  For a given OpHint, we find all children hints inside it. We also copy all
832  the nodes inside function defs (if applicable) into the original graph_def;
833  they are returned in a list as well.
834
835  Args:
836    call: Parent OpHint that contains children ophints.
837    graph_def: Original graph def.
838
839  Returns:
840    Ordered children hints inside the parent ophint; new graph def that contains
841    nodes inside function defs (if applicable); nodes inside function defs.
842  """
843  name_to_input_name, _, _ = _extract_graph_summary(graph_def)
844  input_names, output_names = call.flattened_inputs_and_outputs()
845
846  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
847  reachable_by_output = _bfs_for_reachable_nodes(output_names,
848                                                 name_to_input_name)
849  output_nodes_set = set(output_names)
850  children_hints = []
851  out = _graph_pb2.GraphDef()
852  out.library.CopyFrom(graph_def.library)
853  out.versions.CopyFrom(graph_def.versions)
854  function_def_nodes = set()
855  for node in graph_def.node:
856    out.node.extend([_copy.deepcopy(node)])
857    n = _tensor_name_base(node.name)
858    if n in reachable_by_output:
859      if n not in reachable_by_input and n not in output_nodes_set:
860        # Special handling for the while loop's function def.
861        if node.op == "While" or node.op == "StatelessWhile":
862          body_name = node.attr["body"].func.name
863          inputs_outside_loop = node.input
864          for function_def in graph_def.library.function:
865            if function_def.signature.name == body_name:
866              function_inputs = function_def.signature.input_arg
867              assert len(inputs_outside_loop) == len(function_inputs)
868              nodes_mapping = {}
869              for i, function_input in enumerate(function_inputs):
870                nodes_mapping[function_input.name] = inputs_outside_loop[i]
871              (children_hints_in_loop,
872               new_nodes) = _find_children_hints_in_while_loop(
873                   function_def, nodes_mapping)
874              function_def_nodes.update([x.name for x in new_nodes])
875              children_hints.extend(children_hints_in_loop)
876              out.node.extend(new_nodes)
877
878  return children_hints, out, function_def_nodes
879
880
881def _tensor_name_base(full_tensor_name):
882  """Removes the device assignment code from a tensor.
883
884  e.g. _tensor_name_base("foo:3") => "foo"
885
886  Args:
887    full_tensor_name: A tensor name that is annotated with a device placement
888      (this is what tensor flow introspection gives).
889
890  Returns:
891    A name without any device assignment.
892  """
893  if full_tensor_name.startswith("^"):
894    return full_tensor_name[1:]
895  return full_tensor_name.split(":")[0]
896
897
898def _tensorflow_output_name(tensor_name, output_index):
899  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
900                                                          output_index)
901
902
903def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
904                           name_to_input_name):
905  """Checks to make sure node only connects to predecessor graph through inputs.
906
907  Args:
908    n: Node to check
909    reachable_by_input: Nodes that are reachable by all inputs of subgraph
910    input_nodes_set: The set of nodes that are "inputs".
911    name_to_input_name: Maps from name to the list of inputs.
912
913  Raises:
914    TypeError: If the given node uses items past inputs directly.
915  """
916  next_to_visit = [n]
917  visited = set()
918  while next_to_visit:
919    current_node = next_to_visit.pop()
920    visited.add(current_node)
921    if (current_node in reachable_by_input and
922        current_node not in input_nodes_set):
923      raise TypeError("Node %s uses input %s not in input_nodes." %
924                      (n, current_node))
925    if current_node not in input_nodes_set:
926      next_to_visit += [
927          input_node for input_node in name_to_input_name[current_node]
928          if input_node not in visited
929      ]
930
931
932def _convert_single_op_hint_to_stub(call,
933                                    graph_def,
934                                    function_def_nodes=None,
935                                    is_last_run=True):
936  """Given a graph_def, converts `call` into a stub and returns a new graph_def.
937
938  Args:
939    call: A single function call to be converted.
940    graph_def: A graph_def to use as input (it must contain `call`).
941    function_def_nodes: Nodes inside the function def that are not connected
942      to the graph.
943    is_last_run: Whether it is the last run for a given pass (for an OpHint
944      that has children).
945
946  Returns:
947    A new transformed graph-def that has call as a stub (single op).
948
949  Note: after this process, the graph_def can no longer be loaded into
950      the tensorflow runtime, so all future manipulations are done at the
951      graph_def level.
952  """
953  if function_def_nodes is None:
954    function_def_nodes = set()
955  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
956      graph_def)
957  input_names, output_names = call.flattened_inputs_and_outputs()
958
959  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
960  reachable_by_output = _bfs_for_reachable_nodes(output_names,
961                                                 name_to_input_name)
962  output_nodes_set = set(output_names)
963  nodes_after_fuse = []
964  nodes_deleted_by_fuse = set()
965  # Classify each node. Everything reachable by input is kept as-is, nodes
966  # reachable only from the outputs are swallowed by the fuse, and nodes
967  # reachable by neither come after the fuse.
968  for node in graph_def.node:
969    n = _tensor_name_base(node.name)
970    if n in reachable_by_output:
971      if n not in reachable_by_input and n not in output_nodes_set:
972        nodes_deleted_by_fuse.add(n)
973    elif n not in reachable_by_input and n not in function_def_nodes:
974      # n is a node that comes after all the fusings, so keep it.
975      nodes_after_fuse.append(n)
976    else:
977      # In the last run, n is a node that sits in the graph but is not
978      # connected to the chain of dependencies; we delete it. In earlier runs
979      # we keep it.
980      if not is_last_run:
981        nodes_after_fuse.append(n)
982
983  # Make a new graphdef with all the pre-input and input nodes
984  out = _graph_pb2.GraphDef()
985  reachable_by_input_sorted = sorted(
986      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
987  for node in reachable_by_input_sorted:
988    out.node.extend([_copy.deepcopy(name_to_node[node])])
989
990  # Create any stacks to aggregate arguments into a single input
991  # i.e. for static_rnn's.
992  sorted_input_indices = list(call.inputs.keys())
993  sorted_input_indices.sort()
994  sorted_output_indices = list(call.outputs.keys())
995  sorted_output_indices.sort()
996  new_node = _node_def_pb2.NodeDef()
997  # Delegate to each operand to produce the proper new input for this stub node.
998  # In particular, an aggregate input will now be a Pack of some previously
999  # non-fused things.
1000
1001  optional_input_node = _node_def_pb2.NodeDef()
1002  optional_input_node.name = "Const" + str(_uuid.uuid1().hex)
1003  optional_input_node.op = "Const"
1004  optional_input_node.attr["dtype"].CopyFrom(
1005      _attr_value_pb2.AttrValue(type=_dtypes.float32.as_datatype_enum))
1006  optional_input_node.attr["value"].CopyFrom(
1007      _attr_value_pb2.AttrValue(
1008          tensor=_tensor_util.make_tensor_proto([-1], _dtypes.float32, [1])))
1009  out.node.extend([optional_input_node])
1010
1011  max_index = max(sorted_input_indices) + 1
1012  for cur_index in range(max_index):
1013    if cur_index in sorted_input_indices:
1014      inputs = call.inputs[cur_index]
1015      input_name = inputs.aggregate_and_return_name_for_input(out)
1016      new_node.input.append(input_name)
1017    else:
1018      new_node.input.append(optional_input_node.name)
1019
1020  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)
1021
1022  # Create the function
1023  new_node.op = call.function_name
1024  new_node.name = call.uuid
1025  out.node.extend([new_node])
1026
1027  # Now call each output argument to give them a chance to make the proper
1028  # output type and add it to our new_node.
1029  output_dtypes = []
1030  max_output_index = max(sorted_output_indices) + 1
1031  for cur_index in range(max_output_index):
1032    if cur_index in sorted_output_indices:
1033      output = call.outputs[cur_index]
1034      output_dtype = (
1035          output.aggregate_and_return_name_for_output(new_node.name, cur_index,
1036                                                      out))
1037    else:
1038      output_dtype = optional_input_node.attr["dtype"].type
1039    output_dtypes.append(output_dtype)
1040  new_node.attr["_output_types"].list.type[:] = output_dtypes
1041  new_node.attr["_output_quantized"].b = False
1042
1043  # Add post output nodes that do not depend on the outputs
1044  for n in nodes_after_fuse:
1045    should_keep = True
1046    for input_name in name_to_input_name[n]:
1047      if input_name in nodes_deleted_by_fuse:
1048        should_keep = False
1049    if should_keep:
1050      out.node.extend([_copy.deepcopy(name_to_node[n])])
1051
1052  # Misc. graph_def data that needs copying.
1053  out.library.CopyFrom(graph_def.library)
1054  out.versions.CopyFrom(graph_def.versions)
1055
1056  return out
1057
1058
1059def _remove_one_redundant_stack_unstack(in_graph_def):
1060  """Removes a stack->unstack pattern from in_graph_def in a returned graph.
1061
1062  Args:
1063    in_graph_def: Graph def to use as input.
1064
1065  Returns:
1066    Simplified tuple (graph_def, changed_something) where changed_something
1067    is true if anything was done.
1068  """
1069  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
1070      in_graph_def)
1071  del name_to_seq_num
1072
1073  do_generic_pack_unpack = True
1074
1075  out = _graph_pb2.GraphDef()
1076  out.library.CopyFrom(in_graph_def.library)
1077  out.versions.CopyFrom(in_graph_def.versions)
1078  for n in in_graph_def.node:
1079    node_name = _tensor_name_base(n.name)
1080    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
1081      continue
1082    next_to_visit = [node_name]
1083    visited = set()
1084
1085    unpack_nodes = set()
1086    pack_node = node_name
1087
1088    # Find a pattern of unstack connected to a stack (with identities
1089    # in between).
1090    matches_pattern = True
1091    is_hint_created_stack = False
1092    while next_to_visit:
1093      current_node_name = next_to_visit[0]
1094      visited.add(current_node_name)
1095      del next_to_visit[0]
1096      node = name_to_node[current_node_name]
1097      is_op_hint_stack = node.name.startswith("OpHintStack")
1098      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
1099      if (node.op == "Identity" or is_op_hint_stack or
1100          (do_generic_pack_unpack and node.op == "Pack")):
1101        is_hint_created_stack |= is_op_hint_stack
1102        next_to_visit += [
1103            input_node for input_node in name_to_input_name[current_node_name]
1104            if input_node not in visited
1105        ]
1106      elif (is_op_hint_unstack or
1107            (do_generic_pack_unpack and node.op == "Unpack")):
1108        unpack_nodes.add(node.name)
1109        is_hint_created_stack &= is_op_hint_unstack
1110      else:
1111        matches_pattern = False
1112        break
1113      visited.add(node.name)
1114
1115    if matches_pattern and len(unpack_nodes) == 1:
1116      pack_node = node_name
1117
1118      # Check to see if anyone depends on the intermediate identity or the
1119      # Unstacked form
1120      no_external_dependency = True
1121      for other_n in in_graph_def.node:
1122        if other_n.name in visited:
1123          continue
1124        for input_tensor in name_to_input_name[other_n.name]:
1125          input_op = _tensor_name_base(input_tensor)
1126          if input_op in visited and input_op != pack_node:
1127            no_external_dependency = False
1128      # Proceed with the substitution if the stack/unstack pair was created
1129      # through hints, or, if it was not, when nothing else consumes anything
1130      # between the stack and unstack.
1131      if is_hint_created_stack or no_external_dependency:
1132        end = unpack_nodes.pop()
1133        end_input = name_to_node[end].input[0]
1134        # Rewire consumers of the pack to read the unpack's input directly.
1135        for other_n in in_graph_def.node:
1136          node_name = _tensor_name_base(other_n.name)
1137          if node_name not in visited:
1138            new_node = _copy.deepcopy(other_n)
1139            new_node.input[:] = [
1140                (end_input if stripped == pack_node else non_stripped)
1141                for stripped, non_stripped in zip(name_to_input_name[node_name],
1142                                                  new_node.input[:])
1143            ]
1144            out.node.extend([new_node])
1145        return out, True
1146  return in_graph_def, False
1147
1148
1149def _remove_redundant_stack_unstack(graph_def):
1150  curr = graph_def
1151  del graph_def
1152  changed_stuff = True
1153  while changed_stuff:
1154    curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
1155  return curr
1156
1157
1158def _get_correct_mapping(original_index, nodes):
1159  # Special handling for the -1 index case:
1160  # if the index is -1, return the last index.
1161  if original_index == -1:
1162    node_indices = nodes.keys()
1163    node_indices = sorted(node_indices)
1164    return node_indices[-1]
1165  return original_index
1166
1167
1168def _convert_op_hints_to_stubs_helper(
1169    graph_def, write_callback=lambda graph_def, comments: None):
1170  """Converts a graph_def to a new graph_def where all op hints are stubbed.
1171
1172  Args:
1173    graph_def: A graph def that we should convert.
1174    write_callback: A function pointer that can be used to write intermediate
1175      steps of graph transformation (optional).
1176
1177  Returns:
1178    A new stubbed graph_def.
1179  """
1180  hints = _find_all_hints_in_nodes(graph_def.node)
1181
1182  hints_q = []
1183  for hint in hints.values():
1184    hints_q.append((hint.level, hint.uuid))
1185
1186  hints_q.sort(key=lambda tup: tup[0])
1189
1190  curr_graph_def = graph_def
1191  del graph_def  # prevent using graph_def again (common source of error)
1192  for i in range(len(hints_q) - 1, -1, -1):
1193    level, hint_uuid = hints_q[i]
1194    if level >= 2:
1195      children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
1196          hints[hint_uuid], curr_graph_def)
1197      # pylint: disable=superfluous-parens
1198      assert (len(children_hints) > 0)  #  pylint: disable=g-explicit-length-test
1199      # pylint: enable=superfluous-parens
1200
1201      # Re-wire the children hints' inputs/outputs, so a later child's inputs
1202      # connect to the previous child node's outputs.
1203      children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
1204      for j, child_hint in enumerate(children_hints):
1205        if j == 0:
1206          for mapping in children_inputs_mappings["parent_first_child_input"]:
1207            parent_input_index = _get_correct_mapping(
1208                mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
1209            child_input_index = _get_correct_mapping(
1210                mapping["first_child_ophint_input_index"], child_hint.inputs)
1211            child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
1212                parent_input_index]
1213        else:
1214          for mapping in children_inputs_mappings[
1215              "internal_children_input_output"]:
1216            input_index = _get_correct_mapping(mapping["child_input_index"],
1217                                               child_hint.inputs)
1218            output_index = _get_correct_mapping(mapping["child_output_index"],
1219                                                children_hints[j - 1].outputs)
1220            child_hint.inputs[input_index] = children_hints[
1221                j - 1].outputs[output_index]
1222        if j == len(children_hints) - 1:
1223          for mapping in children_inputs_mappings["parent_last_child_output"]:
1224            parent_output_index = _get_correct_mapping(
1225                mapping["parent_output_index"], hints[hint_uuid].outputs)
1226            child_output_index = _get_correct_mapping(
1227                mapping["child_output_index"], child_hint.outputs)
1228            child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
1229                parent_output_index]
1230
1231      for j, child_hint in enumerate(children_hints):
1232        curr_graph_def = _convert_single_op_hint_to_stub(
1233            child_hint, curr_graph_def, function_def_nodes,
1234            j == len(children_hints) - 1)
1235    else:
1236      curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
1237                                                       curr_graph_def)
1238      write_callback(curr_graph_def, "initial")
1239  # The stubbing process can create stacks/unstacks in the case of LSTMs
1240  # remove them.
1241  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
1242  return curr_graph_def
1243
1244
1245def find_all_hinted_output_nodes(session=None, graph_def=None):
1246  """Find all Ophints output nodes in the graph.
1247
1248  This is used to get all the output nodes that are ophinted; it is important
1249  for operations like convert_variables_to_constants to keep the whole ophint
1250  structure. Note: only one of session or graph_def should be used, not both.
1251  Why is this useful? Some TensorFlow ops (e.g. bidirectional rnn) can generate
1252  multiple outputs for an unfused subgraph. If not all output nodes are
1253  consumed, graph optimization can potentially drop the unused nodes and leave
1254  the ophints in an invalid state (due to missing ophinted output nodes). So
1255  it's important to find all those hinted output nodes and make sure they are
1256  not discarded.
1257
1258  Args:
1259    session: A TensorFlow session that contains the graph to convert.
1260    graph_def: A graph def that we should convert.
1261
1262  Returns:
1263    A list of OpHints output nodes.
1264  Raises:
1265    ValueError: If both session and graph_def are provided.
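
  Example (a rough sketch; assumes `sess` is a tf.compat.v1.Session whose graph
  was built with OpHints and `output_names` is an existing list of output node
  names):

    output_names += find_all_hinted_output_nodes(session=sess)
    # Pass the extended output_names to convert_variables_to_constants so the
    # hinted output nodes are not pruned away.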
1266  """
1267  if session is not None and graph_def is not None:
1268    raise ValueError("Provide only one of session and graph_def.")
1269  hinted_outputs_nodes = []
1270  if session is not None:
1271    hints = _find_all_hints_in_nodes(session.graph_def.node)
1272  elif graph_def is not None:
1273    hints = _find_all_hints_in_nodes(graph_def.node)
1274  for hint in hints.values():
1275    _, output_nodes = hint.flattened_inputs_and_outputs()
1276    hinted_outputs_nodes.extend(output_nodes)
1277  return hinted_outputs_nodes
1278
1279
1280def is_ophint_converted(graph_def):
1281  if graph_def is None:
1282    raise ValueError("Must provide the graph_def.")
1283  ophint_converted = False
1284  for node in graph_def.node:
1285    attr = node.attr
1286    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
1287      ophint_converted = True
1288      break
1289  return ophint_converted
1290
1291
1292@_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"])
1293@_deprecation.deprecated(
1294    None,
1295    "Please follow instructions under "
1296    "https://www.tensorflow.org/lite/convert/operation_fusion for operation"
1297    "fusion in tflite."
1298)
1299def convert_op_hints_to_stubs(session=None,
1300                              graph_def=None,
1301                              write_callback=lambda graph_def, comments: None):
1302  """Converts a graphdef with LiteOp hints into stub operations.
1303
1304  This is used to prepare for toco conversion of complex intrinsic usages.
1305  Note: only one of session or graph_def should be used, not both.
1306
1307  Args:
1308    session: A TensorFlow session that contains the graph to convert.
1309    graph_def: A graph def that we should convert.
1310    write_callback: A function pointer that can be used to write intermediate
1311      steps of graph transformation (optional).
1312
1313  Returns:
1314    A new graphdef with all ops contained in OpHints being replaced by
1315    a single op call with the right parameters.
1316  Raises:
1317    ValueError: If both session and graph_def are provided.
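
  Example (a minimal sketch; assumes `sess` is a tf.compat.v1.Session whose
  graph was built with OpHint annotations):

    stubbed = tf.compat.v1.lite.experimental.convert_op_hints_to_stubs(
        session=sess)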
1318  """
1319
1320  if session is not None and graph_def is not None:
1321    raise ValueError("Provide only one of session and graph_def.")
1322
1323  if session is not None:
1324    return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
1325  elif graph_def is not None:
1326    return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
1327  else:
1328    raise ValueError("Must specify session or graph_def as input.")
1329
1330
1331_allowed_symbols = [
1332    "OpHint",
1333    "convert_op_hints_to_stubs",
1334    "convert_op_hints_to_stubs_new",
1335    "find_all_hinted_output_nodes",
1336    "is_ophint_converted",
1337]
1338remove_undocumented(__name__, _allowed_symbols)
1339