• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Define tflite op hints (intrinsic operations).
16
17This essentially allows defining a TensorFlow API for tflite operations in
18Python with hints on how they are represented in TensorFlow Lite. This basically
19is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution
20graph and is useful for LSTMs and other complicated TensorFlow constructions
21that are difficult to pattern match in TOCO, but are represented by a single
22accelerated tflite op.
23
24Example:
25  def tflite_cool_activation(input):
26    # A cool activation function.
27    custom = tf.lite.OpHint("cool_activation")
28    input, = custom.add_inputs(input)
29    output = tf.sigmoid(input) * input
30    output, = custom.add_outputs(output)
31    return output
32
33  image = tf.compat.v1.placeholder(tf.float32, (1, 16, 16, 1))
34  output = tf.identity(tflite_cool_activation(image))
35
36  session = tf.compat.v1.Session()
37
38  graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session)
39  tflite_graph = tf.compat.v1.lite.toco_convert(
40      graphdef_to_convert, [image], [output], allow_custom_ops=True)
41  with open("/tmp/graph.fb", "wb") as fp:
42    fp.write(tflite_graph)
43
44How does it work?:
45
46OpHint is a helper that you use when defining a vanilla python function.
47It allows you to wrap arguments with tf.identities with some custom attributes.
48These attributes allow you to find the original block of ops that was created.
49For example, if you use cool_activation above you essentially get:
50
51a_input = tf.identity()
52result = tf.multiply(tf.sigmoid(a_input), a_input)
53output = tf.identity()
54
55a_input, output are identities that have parameters representing
56what argument they are, what the name of the function they should turn into
57in tf lite as well as a guid that uniquely identifies a particular invocation.
58
59Once you have built your whole tensorflow graph, you can run it and train it
60as usual, but after you have done that, you need to convert the graph into
61a form that replaces these subgraphs wrapped in identities to stub ops. These
62ops don't actually exist in the normal TensorFlow runtime, but will be
63understood by toco later. The generated TensorFlow Lite flatbuffer file will
64contain a custom operator called "cool_activation". Developer needs to implement
65and register this operator in TensorFlow Lite in order to do inference.
66"""
67
68from __future__ import absolute_import
69from __future__ import division
70from __future__ import print_function
71
72import collections as _collections
73import copy as _copy
74import json as _json
75import uuid as _uuid
76import six as _six
77
78from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
79from tensorflow.core.framework import graph_pb2 as _graph_pb2
80from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
81from tensorflow.python.framework import dtypes as _dtypes
82from tensorflow.python.framework import ops as _ops
83from tensorflow.python.framework import tensor_util as _tensor_util
84from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
85from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
86from tensorflow.python.ops import array_ops as _array_ops
87from tensorflow.python.util import compat as _compat
88from tensorflow.python.util import deprecation as _deprecation
89from tensorflow.python.util.all_util import remove_undocumented
90from tensorflow.python.util.tf_export import tf_export as _tf_export
91
92
@_tf_export(v1=["lite.OpHint"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
class OpHint(object):
  """A class that helps build tflite function invocations.

  It allows you to take a bunch of TensorFlow ops and annotate the construction
  such that toco knows how to convert it to tflite. This embeds a pseudo
  function in a TensorFlow graph. This allows embedding high-level API usage
  information in a lower level TensorFlow implementation so that an alternative
  implementation can be substituted later.

  Essentially, any "input" into this pseudo op is fed into an identity, and
  attributes are added to that input before being used by the constituent ops
  that make up the pseudo op. A similar process is done to any output that
  is to be exported from the current op.

  """
  # Attr constants that are used for representation in the GraphDef. These
  # will be used on every Identity op that is involved in a total OpHint.

  # Name of the OpHint function (cosmetic).
  FUNCTION_NAME_ATTR = "_tflite_function_name"
  # UUID of the function (each OpHint gets a new uuid).
  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
  # The input index of the input (or nothing if it is an output).
  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
  # The output index of the output (or nothing if it is an input).
  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
  # An index that orders aggregate arguments. Aggregate arguments are ones
  # that are separate but will be fused horizontally. For example a static LSTM
  # has a lstm cell for each time step. Each one has a separate opHint, but a
  # fused SequentialLSTM will treat this as a single tensor.
  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
  # The way in which multiple parts of the aggregate argument will be joined
  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
  # On fused OpHint stub, the order of inputs that the final LSTM call will
  # have. What this means is that the TensorFlow order might be
  # "foo", "bar", "stuff" and you might want the TF lite op order to be
  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
  # attribute to [2, 0, 1, -1].
  TFLITE_INPUT_INDICES = "_tflite_input_indices"
  # OpHint level.
  FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
  # Ophint internal mapping, this is for high level Ophint only.
  # This basically contains three kinds of mapping:
  #   1) How parental ophinted inputs map to the first child ophinted inputs;
  #   2) How internal children nodes are connected;
  #   3) How parental ophinted outputs map to the last child ophinted outputs.
  CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"

  # Types of aggregations
  #  stack: stacks all ophints with matching tags. i.e. for a static rnn.
  #   specifically, this is good for an input or output to a static rnn cell.
  AGGREGATE_STACK = "stack"
  # first: only takes the first output (one with lowest sort index)
  # of matching tags. This is good for the input state to an RNN.
  AGGREGATE_FIRST = "first"
  # aggregation last takes only the last tag (one with highest sort index).
  # This is good for an output value on the last stack item of a
  # static rnn.
  AGGREGATE_LAST = "last"

  class OpHintArgumentTracker(object):
    """Conceptually tracks indices of arguments of "OpHint functions".

    The inputs and arguments of these functions both use an instance
    of the class so they can have independent numbering.
    """

    def __init__(self,
                 function_name,
                 unique_function_id,
                 node_name_prefix,
                 attr_name,
                 level=1,
                 children_inputs_mappings=None):
      """Initialize ophint argument.

      Args:
        function_name: Name of the function that this tracks arguments for.
        unique_function_id: UUID of function that this tracks arguments for.
        node_name_prefix: How identities that are created are named.
        attr_name: Name of attribute to use to store the index for this hint.
          i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
        level: Hierarchical level of the Ophint node, a number.
        children_inputs_mappings: Inputs/Outputs mapping for children hints.
      """

      # The global index is the argument index of the op. This is in contrast
      # to the sort index which is the sequence number of a particular instance
      # of a given global index. For example, you may have called add hint
      # twice with the tag "foo". Then the global index will be 0 for both
      # and the sort index will be 0 for the first added and 1 for the second.
      self._function_name = function_name
      self._unique_function_id = unique_function_id
      self._next_global_index = 0  # The absolute global index
      self._used_global_indices = set()
      self._tag_to_global_index = {}  # The argument index a given tag maps to
      self._tag_to_next_sort_index = {}  # The current index for each tag
      self._node_name_prefix = node_name_prefix
      self._attr_name = attr_name
      self._level = level
      self._children_inputs_mappings = children_inputs_mappings

    def _get_new_global_index(self, index_override):
      """Return the next unused argument index in order or use an override.

      Args:
        index_override: An index to use instead of the next available or None
          to use the next available.

      Returns:
        A valid global_index to use for the next hint argument.

      Raises:
        ValueError: If the index_override is already used by another hint.
      """
      if index_override is None:
        global_index = self._next_global_index
      else:
        if index_override in self._used_global_indices:
          # Format the offending index into the message (the original code
          # left the "%d" placeholder unfilled).
          raise ValueError(
              "Index %d was already used by another call to add" %
              index_override)
        global_index = index_override
      # Make next_global_index valid
      self._used_global_indices.add(global_index)
      while self._next_global_index in self._used_global_indices:
        self._next_global_index += 1
      return global_index

    def add(self, arg, tag=None, name=None, aggregate=None,
            index_override=None):
      """Return a wrapped tensor of an input tensor as an argument.

      Args:
        arg: A TensorFlow tensor that should be considered an argument.
        tag: String tag to identify arguments that should be packed.
        name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
          Note, aggregate is only valid if tag is specified.
        index_override: Specify what input/output index should this be in the
          final stub. i.e. add(arg0, index=1); add(arg1, index=0) will make the
          final stub be as stub_func(inputs[arg1, arg0], outputs=[]) rather than
          the default call order based ordering.

      Returns:
        A tensor representing the wrapped argument.

      Raises:
        ValueError: When indices are not consistent.
      """

      # Find the appropriate index
      if tag is None:
        if aggregate is not None:
          raise ValueError("You must specify `tag` if using aggregate.")
        global_index = self._get_new_global_index(index_override)
        sort_index = None
      else:
        if aggregate is None:
          raise ValueError("You must specify `aggregate` if using tag.")
        if tag not in self._tag_to_global_index:
          self._tag_to_global_index[tag] = (
              self._get_new_global_index(index_override))
          self._tag_to_next_sort_index[tag] = 0
        # Compare against None so an explicit index_override of 0 is still
        # checked for consistency (a bare truthiness test skips 0).
        elif (index_override is not None and
              index_override != self._tag_to_global_index[tag]):
          raise ValueError(
              "Tag %r was called with two indices %r and %r" %
              (tag, index_override, self._tag_to_global_index[tag]))
        global_index = self._tag_to_global_index[tag]
        sort_index = self._tag_to_next_sort_index[tag]
        self._tag_to_next_sort_index[tag] += 1

      uuid = self._unique_function_id
      name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name,
                                    uuid, global_index, sort_index, name)

      identity_op = _array_ops.identity(arg, name=name)

      # pylint: disable=protected-access
      identity_op.op._set_attr(
          OpHint.FUNCTION_NAME_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._function_name)))
      identity_op.op._set_attr(
          OpHint.FUNCTION_UUID_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._unique_function_id)))
      identity_op.op._set_attr(
          self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
      identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
                               _attr_value_pb2.AttrValue(i=self._level))
      if self._children_inputs_mappings:
        identity_op.op._set_attr(
            OpHint.CHILDREN_INPUTS_MAPPINGS,
            _attr_value_pb2.AttrValue(
                s=_compat.as_bytes(_json.dumps(
                    self._children_inputs_mappings))))

      if sort_index is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_SORT_INDEX_ATTR,
            _attr_value_pb2.AttrValue(i=sort_index))
      if aggregate is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_AGGREGATE_ATTR,
            _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate))))
      # pylint: enable=protected-access
      return identity_op

  def __init__(self,
               function_name,
               level=1,
               children_inputs_mappings=None,
               **kwargs):
    """Create a OpHint.

    Args:
      function_name: Name of the function (the custom op name in tflite)
      level: OpHint level.
      children_inputs_mappings: Children OpHint inputs/outputs mapping.
        children_inputs_mappings should like below:
        "parent_first_child_input":
            [{"parent_input_index": num, "child_input_index": num}, ...]
        "parent_last_child_output":
            [{"parent_output_index": num, "child_output_index": num}, ...]
        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
      **kwargs: Keyword arguments of any constant attributes for the function.
    """
    self._function_name = function_name
    self._level = level
    # Level-1 (leaf) hints must not carry children mappings; higher levels
    # require them.
    if self._level == 1:
      assert children_inputs_mappings is None
    else:
      assert isinstance(children_inputs_mappings, dict)
    self._children_inputs_mappings = children_inputs_mappings
    if self._children_inputs_mappings is not None:
      self._validate_children_inputs_mappings(self._children_inputs_mappings)
    self._unique_function_id = _uuid.uuid1().hex
    self._attrs_to_store_later = kwargs
    self._stored_attrs = False
    self._inputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "InputHint",
        OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
    self._outputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "OutputHint",
        OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
        self._children_inputs_mappings)

  def _validate_children_inputs_mappings(self, children_inputs_mappings):
    """Validate children inputs mappings is in the right format.

    Args:
      children_inputs_mappings: the Children ophint inputs/outputs mapping.
    """
    assert isinstance(children_inputs_mappings, dict)
    assert "parent_first_child_input" in children_inputs_mappings
    assert "parent_last_child_output" in children_inputs_mappings
    assert "internal_children_input_output" in children_inputs_mappings

    # validate parent_first_child_input.

    def assert_dictlist_has_keys(dictlist, keys):
      for dikt in dictlist:
        assert isinstance(dikt, dict)
        for key in keys:
          assert key in dikt

    assert_dictlist_has_keys(
        children_inputs_mappings["parent_first_child_input"],
        ["parent_ophint_input_index", "first_child_ophint_input_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["parent_last_child_output"],
        ["parent_output_index", "child_output_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["internal_children_input_output"],
        ["child_input_index", "child_output_index"])

  def _setattr(self, dest_op, name, value):
    """Store `value` as a tensor-valued attr `name` on `dest_op`."""
    tensor_value = _ops.convert_to_tensor(value)
    # pylint: disable=protected-access
    dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
        tensor=tensor_value.op.node_def.attr["value"].tensor))
    # pylint: enable=protected-access

  def add_input(self, *args, **kwargs):
    """Add a wrapped input argument to the hint.

    Args:
      *args: The input tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple copies
          of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped input tensor.
    """
    return self._inputs.add(*args, **kwargs)

  def add_output(self, *args, **kwargs):
    """Add a wrapped output argument to the hint.

    Args:
      *args: The output tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple copies
          of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped output tensor.
    """
    return self._outputs.add(*args, **kwargs)

  def add_inputs(self, *args, **kwargs):
    """Add a sequence of inputs to the function invocation.

    Args:
      *args: List of inputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped inputs (identity standins that have additional metadata). These
      are also tf.Tensor's.
    """
    if "names" in kwargs:
      return [
          self._inputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._inputs.add(arg) for arg in args]

  def add_outputs(self, *args, **kwargs):
    """Add a sequence of outputs to the function invocation.

    Args:
      *args: List of outputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped outputs (identity standins that have additional metadata). These
      are also tf.Tensor's.
    """
    if "names" in kwargs:
      return [
          self._outputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._outputs.add(arg) for arg in args]
469
470
471class _LiteOperand(object):
472  """Abstract operand for a tflite hint function._dynamic_rnn_loop.
473
474  This is a base class that handles representing arguments to an OpHint.
475  It also is able to serialize operands to the stubbed graph_def.
476  Child classes are responsible for being able to
477  store information about the hint identity operators. They are also responsible
478  for knowing how to serialize to output graphdefs.
479
480  Typically this will be implemented by holding one or more identity nodes
481  that were previously discovered as hints.
482  """
483
484  def aggregate_and_return_name_for_input(self, out_graphdef):
485    """This adds the node(s) to out_graphdef and returns the input node name.
486
487    Args:
488      out_graphdef: A graphdef that is ready to have this input added.
489
490    Returns:
491      The output that the stub should use as an input for this operand.
492
493    Raises:
494      RuntimeError: if the method is not implemented.
495    """
496    del out_graphdef
497    raise RuntimeError("Unimplemented abstract method.")
498
499  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
500                                           out_graphdef):
501    """Add node(s) to graph representing output operands and returns type.
502
503    Args:
504      fused_op_name: name of the fused op stub name.
505      output_index: Output index that we are currently processing from stub.
506      out_graphdef: The destination graphdef we are currently building up.
507
508    Returns:
509      The datatype of this identity.
510
511    Raises:
512      RuntimeError: if the method is not implemented.
513    """
514    del fused_op_name, output_index, out_graphdef
515    raise RuntimeError("Unimplemented abstract method.")
516
517
class _LiteSingleOperand(_LiteOperand):
  """A simple operand that is non-aggregated (i.e. most hints)."""

  def __init__(self, node):
    _LiteOperand.__init__(self)
    # Keep both the node proto and its base tensor name for quick access.
    self.node = node
    self.name = _tensor_name_base(node.name)

  def flatten(self):
    """Return this operand's single name as a one-element list."""
    return [self.name]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """A single operand needs no aggregation; its own name is the input."""
    return self.name

  def aggregate_and_return_name_for_output(self, fused_op_name, index,
                                           out_graphdef):
    """Append a copy of the hint node rewired to the fused stub's output."""
    clone = _copy.deepcopy(self.node)
    # Replace all original inputs with the stub's output tensor.
    del clone.input[:]
    clone.input.append(_tensorflow_output_name(fused_op_name, index))
    out_graphdef.node.extend([clone])
    return self.node.attr["type"].i

  def __str__(self):
    return str(self.name)
542
543
class _LiteAggregateOperand(_LiteOperand):
  """An operand for a tflite hint function that is aggregated from many.

  For example, an LSTM is a grid of operators that are all related. Inputs
  going into them may need to be fused, so they should all be tracked as
  related arguments.
  """

  def __init__(self, aggregation):
    _LiteOperand.__init__(self)
    # One of OpHint.AGGREGATE_FIRST / AGGREGATE_LAST / AGGREGATE_STACK.
    self.aggregation = aggregation
    self.names = {}  # sort index -> tensor base name
    self.nodes = {}  # sort index -> node proto
    self.flattened = None  # cached result of flatten_nodes()

  def add(self, sort, node):
    """Record hint identity `node` at aggregation position `sort`."""
    self.names[sort] = _tensor_name_base(node.name)
    self.nodes[sort] = node

  def flatten_nodes(self):
    """Return a list of all the node protos in aggregation sorted order.

    Returns:
      List of node protos, reduced according to the aggregation strategy.

    Raises:
      RuntimeError: if an aggregation position was never filled in.
      ValueError: if the aggregation type is invalid.
    """
    # Compare against None explicitly so an empty cached result is not
    # recomputed on every call.
    if self.flattened is None:
      self.flattened = [None] * len(self.nodes)
      for idx, node in _six.iteritems(self.nodes):
        self.flattened[idx] = node
      # Check the densely indexed result (not the dict keys, which are ints
      # and can never be None) so a missing position is actually detected.
      for n in self.flattened:
        if n is None:
          raise RuntimeError("Aggregate was missing argument.")
      if self.aggregation == OpHint.AGGREGATE_FIRST:
        self.flattened = self.flattened[:1]
      elif self.aggregation == OpHint.AGGREGATE_LAST:
        self.flattened = self.flattened[-1:]
      elif self.aggregation == OpHint.AGGREGATE_STACK:
        pass
      else:
        raise ValueError("Invalid aggregation type %r specified" %
                         self.aggregation)
    return self.flattened

  def flatten(self):
    """Return a list of all node names in aggregation sorted order."""
    return [_tensor_name_base(x.name) for x in self.flatten_nodes()]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the nodes to out_graphdef and returns an aggregated output.

    In particular, if you have 4 inputs to a hint stub, this will be the
    node that you can use as an output. I.e. you have 4 timesteps from a
    static rnn, then a fused UnidirectionalLSTM will expect 1 input with
    all 4 time steps. So here we make a pack and return the output name of
    that pack.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The name of a pack that aggregates this node.
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      return _tensor_name_base(flattened[0].name)
    else:
      new_node = _node_def_pb2.NodeDef()
      new_node.op = "Pack"
      new_node.name = "OpHintStack-%s" % flattened[0].name
      new_node.attr["N"].i = len(flattened)
      new_node.attr["T"].type = flattened[0].attr["T"].type
      for discrete in flattened:
        new_node.input.append(_tensor_name_base(discrete.name))
      out_graphdef.node.extend([new_node])
      return new_node.name

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """This adds to `out_graphdef` all the unaggregated outputs.

    I.e. we are outputting from a fused stub, but we need to make it compatible
    with the unfused original graph so we insert an unpack. Ideally in a later
    stage the unpack -> pack sequences will be removed.

    Args:
      fused_op_name: The name of the stub we are in the process of fusing.
      output_index: The output output_index this object represents.
      out_graphdef: The graphdef we are in the process of building.

    Returns:
      The type of the aggregated output (so we can finish building the stub
      op).
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      # A single node needs no unpack; reuse the single-operand logic.
      temp_op = _LiteSingleOperand(flattened[0])
      return temp_op.aggregate_and_return_name_for_output(
          fused_op_name, output_index, out_graphdef)
    else:
      stack_node = _node_def_pb2.NodeDef()
      stack_node.op = "Unpack"
      stack_node.name = "OpHintUnstack-%s" % flattened[0].name
      stack_node.attr["num"].i = len(flattened)
      output_type = flattened[0].attr["T"].type
      stack_node.attr["T"].type = output_type
      stack_node.input.append(
          _tensorflow_output_name(fused_op_name, output_index))
      out_graphdef.node.extend([stack_node])

      for idx, discrete in enumerate(flattened):
        output_node = _copy.deepcopy(discrete)
        del output_node.input[:]
        output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
        out_graphdef.node.extend([output_node])

      return output_type

  def __str__(self):
    s = "\t\t\tAGGREGATE %s\n" % self.aggregation
    # Use six.iteritems: dict.iteritems does not exist on Python 3 and this
    # file otherwise uses _six.iteritems for dict iteration.
    for sort, val in _six.iteritems(self.names):
      s += "\t\t\t%d: %s\n" % (sort, val)
    return s
668
669
670class _LiteFuncCall(object):
671  """Represent a TensorFlow Lite custom function.
672
673  This is uses to accumulate found hints in the graphdef into a single
674  conceptual unit.
675
676  Attributes:
677    inputs: inputs to the op (hash from index # to argument)
678    outputs: outputs to the op (hash from index # to argument)
679    function_name: the tflite custom op name to use
680    uuid: a unique call id for this particular call  (i.e. multiple function
681      calls would have the same function_name but different uuids.
682    params: A param name to key value for op constant data. I.e. for axis on a
683      reduction, strides on a convolution, etc.
684    level: Level of the OpHint.
685    children_inputs_mappings: If the Ophint has children, children inputs
686      mappings indicate how their inputs & outputs are mapped.
687  """
688
689  def __init__(self):
690    self.inputs = {}
691    self.outputs = {}
692    self.function_name = None
693    self.uuid = None
694    self.params = {}
695    self.level = -1
696    self.children_inputs_mappings = {}
697
698  def flattened_inputs_and_outputs(self):
699    """Return a list of inputs and outputs in a flattened format.
700
701    Returns:
702      Tuple of (inputs, outputs). where input and output i a list of names.
703    """
704
705    def _flatten(input_or_output_dict):
706      flattened_items = []
707      for item in input_or_output_dict.values():
708        flattened_items.extend(item.flatten())
709      return flattened_items
710
711    return _flatten(self.inputs), _flatten(self.outputs)
712
713  def __str__(self):
714
715    def format_args(items):
716      s = ""
717      for idx, item in items.iteritems():
718        s += ("\t\t%d:\n" % idx) + str(item)
719      return s
720
721    inputs_str = "\tInputs\n" + format_args(self.inputs)
722    outputs_str = "\tOutputs\n" + format_args(self.outputs)
723
724    return (
725        "tflite function %s call %s level %d "
726        "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
727        (self.function_name, self.uuid, self.level, inputs_str, outputs_str))
728
729
def _find_all_hints_in_nodes(nodes):
  """Look at all the input nodes and collect LiteFuncCall objects.

  Args:
    nodes: An iterable of TensorFlow NodeDefs to look for LiteFuncCalls in.

  Returns:
    A dict mapping hint uuid -> `_LiteFuncCall` object describing that hint.
  """
  func_calls = _collections.defaultdict(_LiteFuncCall)

  # Hoisted out of the node loop (it captures nothing loop-dependent).
  def put_operand(stuff, index, sort, operand, aggregation):
    """Add a given index into the function structure."""
    if sort is None:
      stuff[index] = _LiteSingleOperand(operand)
    else:
      if index not in stuff:
        stuff[index] = _LiteAggregateOperand(aggregation)
      stuff[index].add(sort, operand)

  for node in nodes:
    attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
    if (OpHint.FUNCTION_UUID_ATTR not in attr or
        not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s

    # Start building function; all nodes sharing a uuid accumulate into the
    # same _LiteFuncCall (defaultdict creates it on first access).
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
    # Get sorting and aggregation information (-1 means "no sort index").

    sort = (
        attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
        if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1:
      sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
      call_def.children_inputs_mappings = _json.loads(
          _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))

    # Add the input or output operand this node represents (if any).
    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember constant attributes, keyed by param name with the
    # "_tflite_attr_" prefix stripped.
    # BUG FIX: the old code did a.replace("_tflite_attr_,", "") — the stray
    # comma meant the pattern never matched, so params kept the full
    # prefixed attribute name instead of the bare param name.
    for a in attr:
      if a.startswith("_tflite_attr_"):
        call_def.params[a[len("_tflite_attr_"):]] = attr[a].tensor

  return func_calls
793
794
def _extract_topology_sequence_mapping(nodes):
  """Map each node's base tensor name to its position in `nodes`."""
  return {
      _tensor_name_base(node.name): position
      for position, node in enumerate(nodes)
  }
798
799
def _find_children_hints_in_while_loop(function_def, nodes_mapping):
  """Find children hints and all nodes inside the while loop.

  Args:
    function_def: Function def of the while loop.
    nodes_mapping: While loop input_arg : real node name.

  Returns:
    Ordered children hints and all re-mapped nodes inside the while loop.
  """
  new_nodes = []

  # Make nodes inside function def inputs point to the real nodes.
  # NOTE(review): this rewrites function_def.node_def in place *before* the
  # deepcopy below, so the caller's function_def is mutated as a side effect —
  # confirm callers do not reuse it with the original input names.
  for node in function_def.node_def:
    for i, _ in enumerate(node.input):
      if node.input[i] in nodes_mapping:
        node.input[i] = nodes_mapping[node.input[i]]
    new_nodes.append(_copy.deepcopy(node))
  # Topological position of each (re-mapped) node, used to order hints.
  name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
  children_hints = _find_all_hints_in_nodes(new_nodes)
  children_hints_q = []
  # Ordered by the outputs: each hint is keyed by the smallest sequence
  # number among its output nodes.
  for hint in _six.itervalues(children_hints):
    _, output_names = hint.flattened_inputs_and_outputs()
    seq = name_to_seq_num[output_names[0]]
    for output_name in output_names:
      seq = min(seq, name_to_seq_num[output_name])
    children_hints_q.append((seq, hint))
  # Sort by that earliest-output key so hints come back in topological order.
  children_hints_q.sort(key=lambda tup: tup[0])
  ordered_children_hints = [x[1] for x in children_hints_q]
  return ordered_children_hints, new_nodes
831
832
def _find_children_hints(call, graph_def):
  """Find all children hints.

  For a given OpHint, we find all children hints inside it, we also copy all the
  nodes inside function defs (if applicable) to the original graph_def, they are
  returned in a list as well.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    Ordered children hints inside the parent ophint; new graph def that contains
    nodes inside function defs (if applicable); nodes inside function defs.
  """
  name_to_input_name, _, _ = _extract_graph_summary(graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  # Nodes reachable backwards from the hint's inputs vs. its outputs; the
  # difference identifies the interior of the hinted region.
  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  children_hints = []
  # Start the output graph as a copy of the input graph's library/versions;
  # nodes are appended below.
  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  function_def_nodes = set()
  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    n = _tensor_name_base(node.name)
    # Only interior nodes (reachable from outputs but not from inputs, and
    # not outputs themselves) are candidates for holding a while loop.
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # special handle for while loop function def.
        if node.op == "While" or node.op == "StatelessWhile":
          body_name = node.attr["body"].func.name
          inputs_outside_loop = node.input
          # Locate the loop-body FunctionDef in the graph's library.
          for function_def in graph_def.library.function:
            if function_def.signature.name == body_name:
              function_inputs = function_def.signature.input_arg
              assert len(inputs_outside_loop) == len(function_inputs)
              # Map each body input_arg to the real node feeding the loop.
              nodes_mapping = {}
              for i, function_input in enumerate(function_inputs):
                nodes_mapping[function_input.name] = inputs_outside_loop[i]
              (children_hints_in_loop,
               new_nodes) = _find_children_hints_in_while_loop(
                   function_def, nodes_mapping)
              # Remember which nodes came from the function def so later
              # passes can treat them specially.
              function_def_nodes.update([x.name for x in new_nodes])
              children_hints.extend(children_hints_in_loop)
              out.node.extend(new_nodes)

  return children_hints, out, function_def_nodes
884
885
def _tensor_name_base(full_tensor_name):
  """Removes the device assignment code from a tensor.

  e.g. _tensor_name_base("foo:3") => "foo"

  Args:
    full_tensor_name: A tensor name that is annotated with a device placement
      (this is what tensor flow introspection gives).

  Returns:
    A name without any device assignment.
  """
  # Control-dependency inputs are prefixed with "^"; drop only the marker.
  if full_tensor_name.startswith("^"):
    return full_tensor_name[1:]
  # Otherwise keep everything before the output-index suffix, if present.
  return full_tensor_name.partition(":")[0]
901
902
def _tensorflow_output_name(tensor_name, output_index):
  """Return the canonical TF name for output `output_index` of `tensor_name`.

  Output 0 is spelled bare ("name"); any other index uses "name:index".
  """
  if output_index == 0:
    return tensor_name
  return "%s:%d" % (tensor_name, output_index)
906
907
def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                           name_to_input_name):
  """Checks to make sure node only connects to predecessor graph through inputs.

  Args:
    n: Node to check
    reachable_by_input: Nodes that are reachable by all inputs of subgraph
    input_nodes_set: The set of nodes that are "inputs".
    name_to_input_name: Maps from name to the list of inputs.

  Raises:
    TypeError: If the given node uses items past inputs directly.
  """
  # Depth-first walk backwards along input edges, stopping at declared inputs.
  stack = [n]
  seen = set()
  while stack:
    current = stack.pop()
    seen.add(current)
    # Reaching a predecessor node that is not a declared input means the
    # subgraph leaks past its input boundary.
    if current in reachable_by_input and current not in input_nodes_set:
      raise TypeError("Node %s uses input %s not in input_nodes." %
                      (n, current))
    if current in input_nodes_set:
      continue
    for predecessor in name_to_input_name[current]:
      if predecessor not in seen:
        stack.append(predecessor)
935
936
def _convert_single_op_hint_to_stub(call,
                                    graph_def,
                                    function_def_nodes=None,
                                    is_last_run=True):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (that has call obviously).
    function_def_nodes: Nodes inside the function def those are not connected to
      the graph.
    is_last_run: Whether it is the last run for a given pass (for OpHint has
      children).

  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
      the tensorflow runtime, so all future manipulations are done in graph_def
      level.
  """
  if function_def_nodes is None:
    function_def_nodes = set()
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  # Backwards reachability from the hint's inputs/outputs partitions the
  # graph into "before", "inside" (to be fused away), and "after" regions.
  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node. We want to keep everything reachable by input, but
  # we don't know if things that are not reachable by output or input (things
  # after fusing).
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # Interior of the fused region: replaced by the stub op below.
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input and n not in function_def_nodes:
      # n is a node that after all the fusings, so keep it.
      nodes_after_fuse.append(n)
    else:
      # In the last run, n is a node that is randomly in the graph but not
      # connected to the chain of dependencies, we will delete n, otherwise
      # we keep them.
      if not is_last_run:
        nodes_after_fuse.append(n)

  # Make a new graphdef with all the pre-input and input nodes
  out = _graph_pb2.GraphDef()
  # Preserve the original topological order of the kept predecessor nodes.
  reachable_by_input_sorted = sorted(
      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into to a single input
  # i.e. for static_rnn's.
  sorted_input_indices = list(call.inputs.keys())
  sorted_input_indices.sort()
  sorted_output_indices = list(call.outputs.keys())
  sorted_output_indices.sort()
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub node.
  # In particular, an aggregate input will now be a Pack of some previously
  # non-fused things.

  # Sentinel float constant wired into any missing (optional) input slot so
  # the stub op keeps a dense, gap-free input list.
  optional_input_node = _node_def_pb2.NodeDef()
  optional_input_node.name = "Const" + str(_uuid.uuid1().hex)
  optional_input_node.op = "Const"
  optional_input_node.attr["dtype"].CopyFrom(
      _attr_value_pb2.AttrValue(type=_dtypes.float32.as_datatype_enum))
  optional_input_node.attr["value"].CopyFrom(
      _attr_value_pb2.AttrValue(
          tensor=_tensor_util.make_tensor_proto([-1], _dtypes.float32, [1])))
  out.node.extend([optional_input_node])

  # Fill input slots 0..max_index-1, using the sentinel for absent indices.
  max_index = max(sorted_input_indices) + 1
  for cur_index in range(max_index):
    if cur_index in sorted_input_indices:
      inputs = call.inputs[cur_index]
      input_name = inputs.aggregate_and_return_name_for_input(out)
      new_node.input.append(input_name)
    else:
      new_node.input.append(optional_input_node.name)

  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

  # Create the function
  new_node.op = call.function_name
  new_node.name = call.uuid
  out.node.extend([new_node])

  # Now call each output argument to give them a chance to make the proper
  # output type and add it to our new_node.
  output_dtypes = []
  max_output_index = max(sorted_output_indices) + 1
  for cur_index in range(max_output_index):
    if cur_index in sorted_output_indices:
      output = call.outputs[cur_index]
      output_dtype = (
          output.aggregate_and_return_name_for_output(new_node.name, cur_index,
                                                      out))
    else:
      # NOTE(review): "type" is never set on optional_input_node above (only
      # "dtype" and "value" are), so this reads the proto default (0) —
      # presumably an intentional placeholder dtype, but confirm.
      output_dtype = optional_input_node.attr["type"].i
    output_dtypes.append(output_dtype)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on the outputs
  for n in nodes_after_fuse:
    should_keep = True
    for input_name in name_to_input_name[n]:
      if input_name in nodes_deleted_by_fuse:
        should_keep = False
    if should_keep:
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
1062
1063
def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes a stack->unstack pattern from in_graph_def in a returned graph.

  Args:
    in_graph_def: Graph def to use as input.

  Returns:
    Simplified tuple (graph_def, changed_something) where changed_something
    is true if anything was done.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      in_graph_def)
  del name_to_seq_num

  # Also match generic Pack/Unpack pairs, not just hint-created ones.
  do_generic_pack_unpack = True

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(in_graph_def.library)
  out.versions.CopyFrom(in_graph_def.versions)
  for n in in_graph_def.node:
    node_name = _tensor_name_base(n.name)
    # Candidate roots are hint-created stacks or generic Pack ops.
    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
      continue
    next_to_visit = [node_name]
    visited = set()

    unpack_nodes = set()
    pack_node = node_name

    # Find a pattern of unstack connected to a stack (with identities
    # in between.
    matches_pattern = True
    is_hint_created_stack = False
    # Walk backwards (FIFO) from the pack through Identity/Pack nodes until
    # hitting unpack nodes; anything else breaks the pattern.
    while next_to_visit:
      current_node_name = next_to_visit[0]
      visited.add(current_node_name)
      del next_to_visit[0]
      node = name_to_node[current_node_name]
      is_op_hint_stack = node.name.startswith("OpHintStack")
      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
      if (node.op == "Identity" or is_op_hint_stack or
          (do_generic_pack_unpack and node.op == "Pack")):
        is_hint_created_stack |= is_op_hint_stack
        next_to_visit += [
            input_node for input_node in name_to_input_name[current_node_name]
            if input_node not in visited
        ]
      elif (is_op_hint_unstack or
            (do_generic_pack_unpack and node.op == "Unpack")):
        unpack_nodes.add(node.name)
        is_hint_created_stack &= is_op_hint_unstack
      else:
        matches_pattern = False
        break
      visited.add(node.name)

    if matches_pattern and len(unpack_nodes) == 1:
      pack_node = node_name

      # Check to see if anyone depends on the intermediate identity or the
      # Unstacked form
      no_external_dependency = True
      for other_n in in_graph_def.node:
        if other_n.name in visited:
          continue
        for input_tensor in name_to_input_name[other_n.name]:
          input_op = _tensor_name_base(input_tensor)
          if input_op in visited and input_op != pack_node:
            no_external_dependency = False
      # Proceed with the substitution if the stack/unstack pair was created
      # through hints, or that it was not, but nobody is consuming things
      # between the stack and unstack.
      if is_hint_created_stack or no_external_dependency:
        end = unpack_nodes.pop()
        end_input = name_to_node[end].input[0]
        # All nodes that depend on the final stack need to be redone to use
        # the unpack's original input instead.
        for other_n in in_graph_def.node:
          node_name = _tensor_name_base(other_n.name)
          if node_name not in visited:
            new_node = _copy.deepcopy(other_n)
            new_node.input[:] = [
                (end_input if stripped == pack_node else non_stripped)
                for stripped, non_stripped in zip(name_to_input_name[node_name],
                                                  new_node.input[:])
            ]
            out.node.extend([new_node])
        # Only one pattern is removed per call; the caller loops to a
        # fixpoint.
        return out, True
  return in_graph_def, False
1152
1153
def _remove_redundant_stack_unstack(graph_def):
  """Strip stack->unstack patterns repeatedly until a fixpoint is reached."""
  simplified = graph_def
  del graph_def
  while True:
    simplified, changed = _remove_one_redundant_stack_unstack(simplified)
    if not changed:
      return simplified
1161
1162
def _get_correct_mapping(original_index, nodes):
  """Resolve `original_index`, treating -1 as "the last (largest) index"."""
  if original_index != -1:
    return original_index
  # -1 is a sentinel meaning "last": the largest key among the node indices
  # (equivalent to sorting the keys and taking the final one).
  return max(nodes)
1171
1172
def _convert_op_hints_to_stubs_helper(
    graph_def, write_callback=lambda sess, graph_def: None):
  """Converts a graph_def to a new graph_def where all op hints are stubbed.

  Args:
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new stubbed graph_def.
  """
  hints = _find_all_hints_in_nodes(graph_def.node)

  # Sort hints by level ascending; the loop below walks the queue from the
  # back so higher-level (nested) hints are stubbed first.
  hints_q = []
  for hint in _six.itervalues(hints):
    hints_q.append((hint.level, hint.uuid))
  hints_q.sort(key=lambda tup: tup[0])
  # BUG FIX: a duplicated dead loop over hints_q (it only assigned two locals
  # that were immediately overwritten below) was removed here.

  curr_graph_def = graph_def
  del graph_def  # prevent using graph_def again (common source of error)
  for i in range(len(hints_q) - 1, -1, -1):
    level, hint_uuid = hints_q[i]
    if level >= 2:
      # A hint with children: first locate the children (possibly inside
      # while-loop function defs), then wire and stub them one by one.
      children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
          hints[hint_uuid], curr_graph_def)
      # pylint: disable=superfluous-parens
      assert (len(children_hints) > 0)  #  pylint: disable=g-explicit-length-test
      # pylint: enable=superfluous-parens

      # Re-wire the children hints inputs/outputs, so latter child's inputs
      # connect to previous child node's outputs.
      children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
      for j, child_hint in enumerate(children_hints):
        if j == 0:
          # First child: its inputs come from the parent hint's inputs.
          for mapping in children_inputs_mappings["parent_first_child_input"]:
            parent_input_index = _get_correct_mapping(
                mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
            child_input_index = _get_correct_mapping(
                mapping["first_child_ophint_input_index"], child_hint.inputs)
            child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
                parent_input_index]
        else:
          # Middle children: chain inputs to the previous child's outputs.
          for mapping in children_inputs_mappings[
              "internal_children_input_output"]:
            input_index = _get_correct_mapping(mapping["child_input_index"],
                                               child_hint.inputs)
            output_index = _get_correct_mapping(mapping["child_output_index"],
                                                children_hints[j - 1].outputs)
            child_hint.inputs[input_index] = children_hints[
                j - 1].outputs[output_index]
        if j == len(children_hints) - 1:
          # Last child: its outputs become the parent hint's outputs.
          for mapping in children_inputs_mappings["parent_last_child_output"]:
            parent_output_index = _get_correct_mapping(
                mapping["parent_output_index"], hints[hint_uuid].outputs)
            child_output_index = _get_correct_mapping(
                mapping["child_output_index"], child_hint.outputs)
            child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
                parent_output_index]

      for j, child_hint in enumerate(children_hints):
        # Only the final child runs the "last run" cleanup pass.
        curr_graph_def = _convert_single_op_hint_to_stub(
            child_hint, curr_graph_def, function_def_nodes,
            j == len(children_hints) - 1)
    else:
      curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
                                                       curr_graph_def)
      write_callback(curr_graph_def, "initial")
  # The stubbing process can create stacks/unstacks in the case of LSTMs
  # remove them.
  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
  return curr_graph_def
1248
1249
def find_all_hinted_output_nodes(session=None, graph_def=None):
  """Find all Ophints output nodes in the graph.

  This is used to get all the output nodes those are ophinted, it is important
  for operation like convert_variables_to_constants keep all ophints structure.
  Note: only one of session or graph_def should be used, not both.
  Why this can be useful? Some TensorFlow ops (e.g. bidirectional rnn), can
  generate multiple outputs for unfused subgraph. If not all output nodes are
  consumed, graph optimization can potentially drop the unused nodes and cause
  ophints in an invalid states (due to missing ophinted output nodes). So it's
  important for us to find all those hinted output nodes and make sure they're
  not discarded away.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.

  Returns:
    A list of OpHints output nodes.
  Raises:
    ValueError: If both session and graph_def are provided, or if neither is.
  """
  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")
  if session is not None:
    hints = _find_all_hints_in_nodes(session.graph_def.node)
  elif graph_def is not None:
    hints = _find_all_hints_in_nodes(graph_def.node)
  else:
    # BUG FIX: previously, calling with neither argument left `hints` unbound
    # and crashed with UnboundLocalError; fail with a clear error instead
    # (mirrors convert_op_hints_to_stubs).
    raise ValueError("Must specify session or graph_def as input.")
  hinted_outputs_nodes = []
  for hint in _six.itervalues(hints):
    _, output_nodes = hint.flattened_inputs_and_outputs()
    hinted_outputs_nodes.extend(output_nodes)
  return hinted_outputs_nodes
1283
1284
def is_ophint_converted(graph_def):
  """Return True if any node in `graph_def` carries an OpHint input index attr."""
  if graph_def is None:
    raise ValueError("Must provide the graph_def.")
  # A single hinted node is enough; any() short-circuits on the first match.
  return any(
      OpHint.FUNCTION_INPUT_INDEX_ATTR in node.attr for node in graph_def.node)
1295
1296
@_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"])
@_deprecation.deprecated(
    None,
    # BUG FIX: the implicit string concatenation previously produced
    # "operationfusion" (missing space between the two literals).
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
def convert_op_hints_to_stubs(session=None,
                              graph_def=None,
                              write_callback=lambda graph_def, comments: None):
  """Converts a graphdef with LiteOp hints into stub operations.

  This is used to prepare for toco conversion of complex intrinsic usages.
  Note: only one of session or graph_def should be used, not both.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new graphdef with all ops contained in OpHints being replaced by
    a single op call with the right parameters.
  Raises:
    ValueError: If both session and graph_def are provided, or if neither is.
  """

  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")

  if session is not None:
    return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
  elif graph_def is not None:
    return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
  else:
    raise ValueError("Must specify session or graph_def as input.")
1334
1335
# Public API surface of this module; remove_undocumented strips everything
# else from the exported documentation.
_allowed_symbols = [
    "OpHint",
    "convert_op_hints_to_stubs",
    "convert_op_hints_to_stubs_new",
    "find_all_hinted_output_nodes",
    "is_ophint_converted",
]
remove_undocumented(__name__, _allowed_symbols)
1344