# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define tflite op hints (intrinsic operations).

This essentially allows defining a TensorFlow API for tflite operations in
Python with hints on how they are represented in TensorFlow Lite. This
basically is a form of tflite intrinsic. It wraps a subpart of a TensorFlow
execution graph and is useful for LSTMs and other complicated TensorFlow
constructions that are difficult to pattern match in TOCO, but are represented
by a single accelerated tflite op.

Example:
  def tflite_cool_activation(input):
    # A cool activation function.
    custom = tf.lite.OpHint("cool_activation")
    input, = custom.add_inputs(input)
    output = tf.sigmoid(input) * input
    output, = custom.add_outputs(output)
    return output

  image = tf.compat.v1.placeholder(tf.float32, (1, 16, 16, 1))
  output = tf.identity(tflite_cool_activation(image))

  session = tf.compat.v1.Session()

  graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session)
  tflite_graph = tf.compat.v1.lite.toco_convert(
      graphdef_to_convert, [image], [output], allow_custom_ops=True)
  with open("/tmp/graph.fb", "wb") as fp:
    fp.write(tflite_graph)

How does it work?:

OpHint is a helper that you use when defining a vanilla python function.
It allows you to wrap arguments with tf.identities with some custom attributes.
These attributes allow you to find the original block of ops that was created.
For example, if you use cool_activation above you essentially get:

  a_input = tf.identity()
  result = tf.multiply(tf.sigmoid(a_input), a_input)
  output = tf.identity()

a_input and output are identities that have parameters representing
what argument they are, what the name of the function they should turn into
in tf lite is, as well as a guid that uniquely identifies a particular
invocation.

Once you have built your whole TensorFlow graph, you can run it and train it
as usual, but after you have done that, you need to convert the graph into
a form that replaces these subgraphs wrapped in identities with stub ops. These
ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later. The generated TensorFlow Lite flatbuffer file will
contain a custom operator called "cool_activation". Developers need to
implement and register this operator in TensorFlow Lite in order to do
inference.
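
For illustration, the wrapping identities carry attributes roughly like the
following (a conceptual sketch only; the attribute names mirror the OpHint.*
constants defined below, and the uuid value shown is made up):

  a_input = tf.identity(input)   # _tflite_function_name="cool_activation"
                                 # _tflite_function_uuid="abc123..."
                                 # _tflite_function_input_index=0
  output = tf.identity(result)   # _tflite_function_name="cool_activation"
                                 # _tflite_function_uuid="abc123..."
                                 # _tflite_function_output_index=0

convert_op_hints_to_stubs then collapses everything between these identities
into a single node whose op is "cool_activation".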
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections as _collections
import copy as _copy
import json as _json
import uuid as _uuid
import six as _six

from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.util import compat as _compat
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export as _tf_export


@_tf_export(v1=["lite.OpHint"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
class OpHint(object):
  """A class that helps build tflite function invocations.

  It allows you to take a bunch of TensorFlow ops and annotate the construction
  such that toco knows how to convert it to tflite. This embeds a pseudo
  function in a TensorFlow graph. This allows embedding high-level API usage
  information in a lower level TensorFlow implementation so that an alternative
  implementation can be substituted later.

  Essentially, any "input" into this pseudo op is fed into an identity, and
  attributes are added to that input before being used by the constituent ops
  that make up the pseudo op. A similar process is done to any output that
  is to be exported from the current op.
  """
  # Attr constants that are used for representation in the GraphDef. These
  # will be used on every Identity op that is involved in a total OpHint.

  # Name of the OpHint function (cosmetic).
  FUNCTION_NAME_ATTR = "_tflite_function_name"
  # UUID of the function (each OpHint gets a new uuid).
  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
  # The input index of the input (or nothing if it is an output).
  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
  # The output index of the output (or nothing if it is an input).
  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
  # An index that orders aggregate arguments. Aggregate arguments are ones
  # that are separate but will be fused horizontally. For example a static LSTM
  # has an LSTM cell for each time step. Each one has a separate OpHint, but a
  # fused SequentialLSTM will treat this as a single tensor.
  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
  # The way in which multiple parts of the aggregate argument will be joined
  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
  # On the fused OpHint stub, the order of inputs that the final LSTM call will
  # have.
  # What this means is that the TensorFlow order might be
  # "foo", "bar", "stuff" and you might want the TF lite op order to be
  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
  # attribute to [2, 0, 1, -1].
  TFLITE_INPUT_INDICES = "_tflite_input_indices"
  # OpHint level.
  FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
  # OpHint internal mapping; this is for high level OpHints only.
  # This basically contains three kinds of mapping:
  #   1) How parental ophinted inputs map to the first child ophinted inputs;
  #   2) How internal children nodes are connected;
  #   3) How parental ophinted outputs map to the last child ophinted outputs.
  CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"

  # Types of aggregations
  #  stack: stacks all ophints with matching tags. i.e. for a static rnn.
  #   specifically, this is good for an input or output to a static rnn cell.
  AGGREGATE_STACK = "stack"
  #  first: only takes the first output (one with lowest sort index)
  #   of matching tags. This is good for the input state to an RNN.
  AGGREGATE_FIRST = "first"
  #  last: takes only the last tag (one with highest sort index).
  #   This is good for an output value on the last stack item of a
  #   static rnn.
  AGGREGATE_LAST = "last"

  class OpHintArgumentTracker(object):
    """Conceptually tracks indices of arguments of "OpHint functions".

    The inputs and outputs of these functions both use an instance
    of the class so they can have independent numbering.
    """

    def __init__(self,
                 function_name,
                 unique_function_id,
                 node_name_prefix,
                 attr_name,
                 level=1,
                 children_inputs_mappings=None):
      """Initialize ophint argument.

      Args:
        function_name: Name of the function that this tracks arguments for.
        unique_function_id: UUID of function that this tracks arguments for.
        node_name_prefix: How identities that are created are named.
        attr_name: Name of attribute to use to store the index for this hint.
          i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX
        level: Hierarchical level of the Ophint node, a number.
        children_inputs_mappings: Inputs/Outputs mapping for children hints.
      """

      # The global index is the argument index of the op. This is in contrast
      # to the sort index which is the sequence number of a particular instance
      # of a given global index. For example, you may have called add hint
      # twice with the tag "foo". Then the global index will be 0 for both
      # and the sort index will be 0 for the first added and 1 for the second.
      self._function_name = function_name
      self._unique_function_id = unique_function_id
      self._next_global_index = 0  # The absolute global index
      self._used_global_indices = set()
      self._tag_to_global_index = {}  # The argument index a given tag maps to
      self._tag_to_next_sort_index = {}  # The current index for each tag
      self._node_name_prefix = node_name_prefix
      self._attr_name = attr_name
      self._level = level
      self._children_inputs_mappings = children_inputs_mappings

    def _get_new_global_index(self, index_override):
      """Return the next unused argument index in order or use an override.

      Args:
        index_override: An index to use instead of the next available or None
          to use the next available.

      Returns:
        A valid global_index to use for the next hint argument.

      Raises:
        ValueError: If the index_override is already used by another hint.
      """
      if index_override is None:
        global_index = self._next_global_index
      else:
        if index_override in self._used_global_indices:
          raise ValueError("Index %d was already used by another call to add"
                           % index_override)
        global_index = index_override
      # Make next_global_index valid
      self._used_global_indices.add(global_index)
      while self._next_global_index in self._used_global_indices:
        self._next_global_index += 1
      return global_index

    def add(self, arg, tag=None, name=None, aggregate=None,
            index_override=None):
      """Return a wrapped tensor of an input tensor as an argument.

      Args:
        arg: A TensorFlow tensor that should be considered an argument.
        tag: String tag to identify arguments that should be packed.
        name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
          Note, aggregate is only valid if tag is specified.
        index_override: Specify what input/output index this should be in the
          final stub. i.e. add(arg0, index=1); add(arg1, index=0) will make the
          final stub be stub_func(inputs=[arg1, arg0], outputs=[]) rather than
          the default call-order-based ordering.

      Returns:
        A tensor representing the wrapped argument.

      Raises:
        ValueError: When indices are not consistent.
      """

      # Find the appropriate index
      if tag is None:
        if aggregate is not None:
          raise ValueError("You must specify `tag` if using aggregate.")
        global_index = self._get_new_global_index(index_override)
        sort_index = None
      else:
        if aggregate is None:
          raise ValueError("You must specify `aggregate` if using tag.")
        if tag not in self._tag_to_global_index:
          self._tag_to_global_index[tag] = (
              self._get_new_global_index(index_override))
          self._tag_to_next_sort_index[tag] = 0
        elif (index_override and
              index_override != self._tag_to_global_index[tag]):
          raise ValueError(
              "Tag %r was called with two indices %r and %r" %
              (tag, index_override, self._tag_to_global_index[tag]))
        global_index = self._tag_to_global_index[tag]
        sort_index = self._tag_to_next_sort_index[tag]
        self._tag_to_next_sort_index[tag] += 1

      uuid = self._unique_function_id
      name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix,
                                    self._function_name, uuid, global_index,
                                    sort_index, name)

      identity_op = _array_ops.identity(arg, name=name)

      # pylint: disable=protected-access
      identity_op.op._set_attr(
          OpHint.FUNCTION_NAME_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._function_name)))
      identity_op.op._set_attr(
          OpHint.FUNCTION_UUID_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._unique_function_id)))
      identity_op.op._set_attr(
          self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
      identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
                               _attr_value_pb2.AttrValue(i=self._level))
      if self._children_inputs_mappings:
        identity_op.op._set_attr(
            OpHint.CHILDREN_INPUTS_MAPPINGS,
            _attr_value_pb2.AttrValue(
                s=_compat.as_bytes(_json.dumps(
                    self._children_inputs_mappings))))

      if sort_index is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_SORT_INDEX_ATTR,
            _attr_value_pb2.AttrValue(i=sort_index))
      if aggregate is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_AGGREGATE_ATTR,
            _attr_value_pb2.AttrValue(s=_compat.as_bytes(aggregate)))
      # pylint: enable=protected-access
      return identity_op

  def __init__(self,
               function_name,
               level=1,
               children_inputs_mappings=None,
               **kwargs):
    """Create an OpHint.

    Args:
      function_name: Name of the function (the custom op name in tflite).
      level: OpHint level.
      children_inputs_mappings: Children OpHint inputs/outputs mapping.
        children_inputs_mappings should look like below:
        "parent_first_child_input":
            [{"parent_input_index": num, "child_input_index": num}, ...]
        "parent_last_child_output":
            [{"parent_output_index": num, "child_output_index": num}, ...]
        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
      **kwargs: Keyword arguments of any constant attributes for the function.
    """
    self._function_name = function_name
    self._level = level
    if self._level == 1:
      assert children_inputs_mappings is None
    else:
      assert isinstance(children_inputs_mappings, dict)
    self._children_inputs_mappings = children_inputs_mappings
    if self._children_inputs_mappings is not None:
      self._validate_children_inputs_mappings(self._children_inputs_mappings)
    self._unique_function_id = _uuid.uuid1().hex
    self._attrs_to_store_later = kwargs
    self._stored_attrs = False
    self._inputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "InputHint",
        OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings)
    self._outputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "OutputHint",
        OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
        self._children_inputs_mappings)

  def _validate_children_inputs_mappings(self, children_inputs_mappings):
    """Validate that children inputs mappings are in the right format.

    Args:
      children_inputs_mappings: the children OpHint inputs/outputs mapping.
    """
    assert isinstance(children_inputs_mappings, dict)
    assert "parent_first_child_input" in children_inputs_mappings
    assert "parent_last_child_output" in children_inputs_mappings
    assert "internal_children_input_output" in children_inputs_mappings

    # Validate parent_first_child_input.

    def assert_dictlist_has_keys(dictlist, keys):
      for dikt in dictlist:
        assert isinstance(dikt, dict)
        for key in keys:
          assert key in dikt

    assert_dictlist_has_keys(
        children_inputs_mappings["parent_first_child_input"],
        ["parent_ophint_input_index", "first_child_ophint_input_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["parent_last_child_output"],
        ["parent_output_index", "child_output_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["internal_children_input_output"],
        ["child_input_index", "child_output_index"])

  def _setattr(self, dest_op, name, value):
    tensor_value = _ops.convert_to_tensor(value)
    # pylint: disable=protected-access
    dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
        tensor=tensor_value.op.node_def.attr["value"].tensor))
    # pylint: enable=protected-access

  def add_input(self, *args, **kwargs):
    """Add a wrapped input argument to the hint.

    Args:
      *args: The input tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'.
          Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped input tensor.
    """
    return self._inputs.add(*args, **kwargs)

  def add_output(self, *args, **kwargs):
    """Add a wrapped output argument to the hint.

    Args:
      *args: The output tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated. I.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.
    Returns:
      The wrapped output tensor.
    """
    return self._outputs.add(*args, **kwargs)

  def add_inputs(self, *args, **kwargs):
    """Add a sequence of inputs to the function invocation.

    Args:
      *args: List of inputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped inputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    """
    if "names" in kwargs:
      return [
          self._inputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._inputs.add(arg) for arg in args]

  def add_outputs(self, *args, **kwargs):
    """Add a sequence of outputs to the function invocation.

    Args:
      *args: List of outputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped outputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    """
    if "names" in kwargs:
      return [
          self._outputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._outputs.add(arg) for arg in args]


class _LiteOperand(object):
  """Abstract operand for a tflite hint function.

  This is a base class that handles representing arguments to an OpHint.
  It also is able to serialize operands to the stubbed graph_def.
  Child classes are responsible for being able to store information about the
  hint identity operators. They are also responsible for knowing how to
  serialize to output graphdefs.

  Typically this will be implemented by holding one or more identity nodes
  that were previously discovered as hints.
  """

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the node(s) to out_graphdef and returns the input node name.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The output that the stub should use as an input for this operand.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del out_graphdef
    raise RuntimeError("Unimplemented abstract method.")

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """Add node(s) to graph representing output operands and returns type.

    Args:
      fused_op_name: name of the fused op stub name.
      output_index: Output index that we are currently processing from stub.
      out_graphdef: The destination graphdef we are currently building up.

    Returns:
      The datatype of this identity.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del fused_op_name, output_index, out_graphdef
    raise RuntimeError("Unimplemented abstract method.")


class _LiteSingleOperand(_LiteOperand):
  """A simple operand that is non-aggregated (i.e. most hints)."""

  def __init__(self, node):
    _LiteOperand.__init__(self)
    self.node = node
    self.name = _tensor_name_base(node.name)

  def flatten(self):
    return [self.name]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    return self.name

  def aggregate_and_return_name_for_output(self, fused_op_name, index,
                                           out_graphdef):
    output_node = _copy.deepcopy(self.node)
    del output_node.input[:]
    output_node.input.append(_tensorflow_output_name(fused_op_name, index))
    out_graphdef.node.extend([output_node])
    return self.node.attr["type"].i

  def __str__(self):
    return str(self.name)


class _LiteAggregateOperand(_LiteOperand):
  """An operand for a tflite hint function that is aggregated from many.

  For example, an LSTM is a grid of operators that are all related. Inputs
  going into them may need to be fused, so they should all be tracked as
  related arguments.
  """

  def __init__(self, aggregation):
    _LiteOperand.__init__(self)
    self.aggregation = aggregation
    self.names = {}
    self.nodes = {}
    self.flattened = None

  def add(self, sort, node):
    self.names[sort] = _tensor_name_base(node.name)
    self.nodes[sort] = node

  def flatten_nodes(self):
    """Return a list of all the node protos in aggregation sort order."""
    if not self.flattened:
      self.flattened = [None] * len(self.nodes)
      for idx, node in _six.iteritems(self.nodes):
        self.flattened[idx] = node
      for n in self.flattened:
        if n is None:
          raise RuntimeError("Aggregate was missing argument.")
      if self.aggregation == OpHint.AGGREGATE_FIRST:
        self.flattened = self.flattened[:1]
      elif self.aggregation == OpHint.AGGREGATE_LAST:
        self.flattened = self.flattened[-1:]
      elif self.aggregation == OpHint.AGGREGATE_STACK:
        pass
      else:
        raise ValueError("Invalid aggregation type %r specified" %
                         self.aggregation)
    return self.flattened

  def flatten(self):
    """Return a list of all node names in aggregation sort order."""
    return [_tensor_name_base(x.name) for x in self.flatten_nodes()]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the nodes to out_graphdef and returns an aggregated output.

    In particular, if you have 4 inputs to a hint stub, this will be the
    node that you can use as an output. I.e. if you have 4 timesteps from a
    static rnn, then a fused UnidirectionalLSTM will expect 1 input with
    all 4 time steps. So here we make a pack and return the output name of
    that pack.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The name of a pack that aggregates this node.
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      return _tensor_name_base(flattened[0].name)
    else:
      new_node = _node_def_pb2.NodeDef()
      new_node.op = "Pack"
      new_node.name = "OpHintStack-%s" % flattened[0].name
      new_node.attr["N"].i = len(flattened)
      new_node.attr["T"].type = flattened[0].attr["T"].type
      for discrete in flattened:
        new_node.input.append(_tensor_name_base(discrete.name))
      out_graphdef.node.extend([new_node])
      return new_node.name

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """This adds to `out_graphdef` all the unaggregated outputs.

    I.e. we are outputting from a fused stub, but we need to make it compatible
    with the unfused original graph so we insert an unpack. Ideally in a later
    stage the unpack -> pack sequences will be removed.

    Args:
      fused_op_name: The name of the stub we are in the process of fusing.
      output_index: The output index this object represents.
      out_graphdef: The graphdef we are in the process of building.

    Returns:
      The type of the aggregated output (so we can finish building the stub
      op).
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      temp_op = _LiteSingleOperand(flattened[0])
      return temp_op.aggregate_and_return_name_for_output(
          fused_op_name, output_index, out_graphdef)
    else:
      stack_node = _node_def_pb2.NodeDef()
      stack_node.op = "Unpack"
      stack_node.name = "OpHintUnstack-%s" % flattened[0].name
      stack_node.attr["num"].i = len(flattened)
      output_type = flattened[0].attr["T"].type
      stack_node.attr["T"].type = output_type
      stack_node.input.append(
          _tensorflow_output_name(fused_op_name, output_index))
      out_graphdef.node.extend([stack_node])

      for idx, discrete in enumerate(flattened):
        output_node = _copy.deepcopy(discrete)
        del output_node.input[:]
        output_node.input.append(_tensorflow_output_name(stack_node.name, idx))
        out_graphdef.node.extend([output_node])

      return output_type

  def __str__(self):
    s = "\t\t\tAGGREGATE %s\n" % self.aggregation
    for sort, val in _six.iteritems(self.names):
      s += "\t\t\t%d: %s\n" % (sort, val)
    return s


class _LiteFuncCall(object):
  """Represents a TensorFlow Lite custom function.

  This is used to accumulate found hints in the graphdef into a single
  conceptual unit.

  Attributes:
    inputs: inputs to the op (hash from index # to argument)
    outputs: outputs to the op (hash from index # to argument)
    function_name: the tflite custom op name to use
    uuid: a unique call id for this particular call (i.e. multiple function
      calls would have the same function_name but different uuids).
    params: A param name to key value for op constant data. I.e. for axis on a
      reduction, strides on a convolution, etc.
    level: Level of the OpHint.
    children_inputs_mappings: If the OpHint has children, children inputs
      mappings indicate how their inputs & outputs are mapped.
  """

  def __init__(self):
    self.inputs = {}
    self.outputs = {}
    self.function_name = None
    self.uuid = None
    self.params = {}
    self.level = -1
    self.children_inputs_mappings = {}

  def flattened_inputs_and_outputs(self):
    """Return a list of inputs and outputs in a flattened format.

    Returns:
      Tuple of (inputs, outputs), where each element is a list of names.
    """

    def _flatten(input_or_output_dict):
      flattened_items = []
      for item in input_or_output_dict.values():
        flattened_items.extend(item.flatten())
      return flattened_items

    return _flatten(self.inputs), _flatten(self.outputs)

  def __str__(self):

    def format_args(items):
      s = ""
      for idx, item in _six.iteritems(items):
        s += ("\t\t%d:\n" % idx) + str(item)
      return s

    inputs_str = "\tInputs\n" + format_args(self.inputs)
    outputs_str = "\tOutputs\n" + format_args(self.outputs)

    return (
        "tflite function %s call %s level %d "
        "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
        (self.function_name, self.uuid, self.level, inputs_str, outputs_str))


def _find_all_hints_in_nodes(nodes):
  """Look at all the input nodes and return the hinted function calls.

  Args:
    nodes: A list of TensorFlow NodeDef protos to look for LiteFuncCalls in.

  Returns:
    A dict mapping function uuid to `_LiteFuncCall` objects.
  """
  func_calls = _collections.defaultdict(_LiteFuncCall)

  for node in nodes:
    attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
    if (OpHint.FUNCTION_UUID_ATTR not in attr or
        not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s

    # Start building function
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
    # Get sorting and aggregation information

    sort = (
        attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
        if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1:
      sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
      call_def.children_inputs_mappings = _json.loads(
          _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))

    # Add the input or output
    def put_operand(stuff, index, sort, operand, aggregation):
      """Add a given index into the function structure."""
      if sort is None:
        stuff[index] = _LiteSingleOperand(operand)
      else:
        if index not in stuff:
          stuff[index] = _LiteAggregateOperand(aggregation)
        stuff[index].add(sort, operand)

    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember attributes
    for a in attr:
      if a.startswith("_tflite_attr_"):
        call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor

  return func_calls


def _extract_topology_sequence_mapping(nodes):
  return dict(
      (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))


def _find_children_hints_in_while_loop(function_def, nodes_mapping):
  """Find children hints and all nodes inside the while loop.

  Args:
    function_def: Function def of the while loop.
    nodes_mapping: Dict mapping the while loop's input_arg names to the real
      node names in the outer graph.

  Returns:
    Ordered children hints and all re-mapped nodes inside the while loop.
  """
  new_nodes = []

  # Make the inputs of nodes inside the function def point to the real nodes.
  for node in function_def.node_def:
    for i, _ in enumerate(node.input):
      if node.input[i] in nodes_mapping:
        node.input[i] = nodes_mapping[node.input[i]]
    new_nodes.append(_copy.deepcopy(node))
  name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
  children_hints = _find_all_hints_in_nodes(new_nodes)
  children_hints_q = []
  # Ordered by the outputs.
  for hint in _six.itervalues(children_hints):
    _, output_names = hint.flattened_inputs_and_outputs()
    seq = name_to_seq_num[output_names[0]]
    for output_name in output_names:
      seq = min(seq, name_to_seq_num[output_name])
    children_hints_q.append((seq, hint))
  children_hints_q.sort(key=lambda tup: tup[0])
  ordered_children_hints = [x[1] for x in children_hints_q]
  return ordered_children_hints, new_nodes


def _find_children_hints(call, graph_def):
  """Find all children hints.

  For a given OpHint, we find all children hints inside it. We also copy all
  the nodes inside function defs (if applicable) to the original graph_def;
  they are returned in a list as well.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    Ordered children hints inside the parent ophint; new graph def that contains
    nodes inside function defs (if applicable); nodes inside function defs.
  """
  name_to_input_name, _, _ = _extract_graph_summary(graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  children_hints = []
  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  function_def_nodes = set()
  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # Special handling for while loop function defs.
        if node.op == "While" or node.op == "StatelessWhile":
          body_name = node.attr["body"].func.name
          inputs_outside_loop = node.input
          for function_def in graph_def.library.function:
            if function_def.signature.name == body_name:
              function_inputs = function_def.signature.input_arg
              assert len(inputs_outside_loop) == len(function_inputs)
              nodes_mapping = {}
              for i, function_input in enumerate(function_inputs):
                nodes_mapping[function_input.name] = inputs_outside_loop[i]
              (children_hints_in_loop,
               new_nodes) = _find_children_hints_in_while_loop(
                   function_def, nodes_mapping)
              function_def_nodes.update([x.name for x in new_nodes])
              children_hints.extend(children_hints_in_loop)
              out.node.extend(new_nodes)

  return children_hints, out, function_def_nodes


def _tensor_name_base(full_tensor_name):
  """Removes the output index and control prefix from a tensor name.

  e.g. _tensor_name_base("foo:3") => "foo"

  Args:
    full_tensor_name: A tensor name that is annotated with an output index
      (this is what TensorFlow introspection gives).

  Returns:
    A bare node name without any output index or control marker.
  """
  if full_tensor_name.startswith("^"):
    return full_tensor_name[1:]
  return full_tensor_name.split(":")[0]


def _tensorflow_output_name(tensor_name, output_index):
  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
                                                          output_index)


def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                           name_to_input_name):
  """Checks that a node only connects to the predecessor graph through inputs.

  Args:
    n: Node to check.
    reachable_by_input: Nodes that are reachable by all inputs of the subgraph.
    input_nodes_set: The set of nodes that are "inputs".
    name_to_input_name: Maps from name to the list of inputs.

  Raises:
    TypeError: If the given node uses items past inputs directly.
  """
  next_to_visit = [n]
  visited = set()
  while next_to_visit:
    current_node = next_to_visit.pop()
    visited.add(current_node)
    if (current_node in reachable_by_input and
        current_node not in input_nodes_set):
      raise TypeError("Node %s uses input %s not in input_nodes." %
                      (n, current_node))
    if current_node not in input_nodes_set:
      next_to_visit += [
          input_node for input_node in name_to_input_name[current_node]
          if input_node not in visited
      ]


def _convert_single_op_hint_to_stub(call,
                                    graph_def,
                                    function_def_nodes=None,
                                    is_last_run=True):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (that obviously contains `call`).
    function_def_nodes: Nodes inside the function def that are not connected to
      the graph.
    is_last_run: Whether this is the last run for a given pass (for OpHints
      that have children).

  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
    the tensorflow runtime, so all future manipulations are done at the
    graph_def level.
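
  As a rough sketch (not an exact dump; the values shown are illustrative),
  the stub produced below looks roughly like:

    new_node.op = call.function_name   # e.g. "cool_activation"
    new_node.name = call.uuid
    new_node.attr["_tflite_input_indices"]   # which stub inputs are real
    new_node.attr["_output_types"]           # dtype of each stub output
    new_node.attr["_output_quantized"] = False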
  """
  if function_def_nodes is None:
    function_def_nodes = set()
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node. We want to keep everything reachable by input;
  # nodes that are reachable only from the outputs are fused away, and
  # anything else is kept depending on whether it comes after the fuse.
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input and n not in function_def_nodes:
      # n is a node that comes after all the fusings, so keep it.
      nodes_after_fuse.append(n)
    else:
      # In the last run, n is a node that is randomly in the graph but not
      # connected to the chain of dependencies; we delete n in that case,
      # otherwise we keep it.
      if not is_last_run:
        nodes_after_fuse.append(n)

  # Make a new graphdef with all the pre-input and input nodes
  out = _graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into a single input
  # i.e. for static_rnn's.
  sorted_input_indices = list(call.inputs.keys())
  sorted_input_indices.sort()
  sorted_output_indices = list(call.outputs.keys())
  sorted_output_indices.sort()
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub
  # node. In particular, an aggregate input will now be a Pack of some
  # previously non-fused things.

  optional_input_node = _node_def_pb2.NodeDef()
  optional_input_node.name = "Const" + str(_uuid.uuid1().hex)
  optional_input_node.op = "Const"
  optional_input_node.attr["dtype"].CopyFrom(
      _attr_value_pb2.AttrValue(type=_dtypes.float32.as_datatype_enum))
  optional_input_node.attr["value"].CopyFrom(
      _attr_value_pb2.AttrValue(
          tensor=_tensor_util.make_tensor_proto([-1], _dtypes.float32, [1])))
  out.node.extend([optional_input_node])

  max_index = max(sorted_input_indices) + 1
  for cur_index in range(max_index):
    if cur_index in sorted_input_indices:
      inputs = call.inputs[cur_index]
      input_name = inputs.aggregate_and_return_name_for_input(out)
      new_node.input.append(input_name)
    else:
      new_node.input.append(optional_input_node.name)

  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices)

  # Create the function
  new_node.op = call.function_name
  new_node.name = call.uuid
  out.node.extend([new_node])

  # Now call each output argument to give them a chance to make the proper
  # output type and add it to our new_node.
  output_dtypes = []
  max_output_index = max(sorted_output_indices) + 1
  for cur_index in range(max_output_index):
    if cur_index in sorted_output_indices:
      output = call.outputs[cur_index]
      output_dtype = (
          output.aggregate_and_return_name_for_output(new_node.name, cur_index,
                                                      out))
    else:
      output_dtype = optional_input_node.attr["type"].i
    output_dtypes.append(output_dtype)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on the outputs
  for n in nodes_after_fuse:
    should_keep = True
    for input_name in name_to_input_name[n]:
      if input_name in nodes_deleted_by_fuse:
        should_keep = False
    if should_keep:
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out


def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes a stack->unstack pattern from in_graph_def in a returned graph.

  Args:
    in_graph_def: Graph def to use as input.

  Returns:
    Simplified tuple (graph_def, changed_something) where changed_something
    is true if anything was done.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      in_graph_def)
  del name_to_seq_num

  do_generic_pack_unpack = True

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(in_graph_def.library)
  out.versions.CopyFrom(in_graph_def.versions)
  for n in in_graph_def.node:
    node_name = _tensor_name_base(n.name)
    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
      continue
    next_to_visit = [node_name]
    visited = set()

    unpack_nodes = set()
    pack_node = node_name

    # Find a pattern of unstack connected to a stack (with identities
    # in between).
    matches_pattern = True
    is_hint_created_stack = False
    while next_to_visit:
      current_node_name = next_to_visit[0]
      visited.add(current_node_name)
      del next_to_visit[0]
      node = name_to_node[current_node_name]
      is_op_hint_stack = node.name.startswith("OpHintStack")
      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
      if (node.op == "Identity" or is_op_hint_stack or
          (do_generic_pack_unpack and node.op == "Pack")):
        is_hint_created_stack |= is_op_hint_stack
        next_to_visit += [
            input_node for input_node in name_to_input_name[current_node_name]
            if input_node not in visited
        ]
      elif (is_op_hint_unstack or
            (do_generic_pack_unpack and node.op == "Unpack")):
        unpack_nodes.add(node.name)
        is_hint_created_stack &= is_op_hint_unstack
      else:
        matches_pattern = False
        break
      visited.add(node.name)

    if matches_pattern and len(unpack_nodes) == 1:
      pack_node = node_name

      # Check to see if anyone depends on the intermediate identity or the
      # Unstacked form
      no_external_dependency = True
      for other_n in in_graph_def.node:
        if other_n.name in visited:
          continue
        for input_tensor in name_to_input_name[other_n.name]:
          input_op = _tensor_name_base(input_tensor)
          if input_op in visited and input_op != pack_node:
            no_external_dependency = False
      # Proceed with the substitution if the stack/unstack pair was created
      # through hints, or if it was not, but nobody consumes anything
      # between the stack and unstack.
      if is_hint_created_stack or no_external_dependency:
        end = unpack_nodes.pop()
        end_input = name_to_node[end].input[0]
        # All nodes that depend on the final stack need to be rewired to use
        # the stack's input instead.
        for other_n in in_graph_def.node:
          node_name = _tensor_name_base(other_n.name)
          if node_name not in visited:
            new_node = _copy.deepcopy(other_n)
            new_node.input[:] = [
                (end_input if stripped == pack_node else non_stripped)
                for stripped, non_stripped in zip(name_to_input_name[node_name],
                                                  new_node.input[:])
            ]
            out.node.extend([new_node])
        return out, True
  return in_graph_def, False


def _remove_redundant_stack_unstack(graph_def):
  curr = graph_def
  del graph_def
  changed_stuff = True
  while changed_stuff:
    curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
  return curr


def _get_correct_mapping(original_index, nodes):
  # Special handling for the -1 index case: if it is -1, return the last index.
  if original_index == -1:
    node_indices = nodes.keys()
    node_indices = sorted(node_indices)
    return node_indices[-1]
  return original_index


def _convert_op_hints_to_stubs_helper(
    graph_def, write_callback=lambda graph_def, comments: None):
  """Converts a graph_def to a new graph_def where all op hints are stubbed.

  Args:
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new stubbed graph_def.
  """
  hints = _find_all_hints_in_nodes(graph_def.node)

  hints_q = []
  for hint in _six.itervalues(hints):
    hints_q.append((hint.level, hint.uuid))

  hints_q.sort(key=lambda tup: tup[0])

  curr_graph_def = graph_def
  del graph_def  # prevent using graph_def again (common source of error)
  for i in range(len(hints_q) - 1, -1, -1):
    level, hint_uuid = hints_q[i]
    if level >= 2:
      children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
          hints[hint_uuid], curr_graph_def)
      # pylint: disable=superfluous-parens
      assert (len(children_hints) > 0)  # pylint: disable=g-explicit-length-test
      # pylint: enable=superfluous-parens

      # Re-wire the children hints inputs/outputs, so a later child's inputs
      # connect to the previous child node's outputs.
      children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
      for j, child_hint in enumerate(children_hints):
        if j == 0:
          for mapping in children_inputs_mappings["parent_first_child_input"]:
            parent_input_index = _get_correct_mapping(
                mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
            child_input_index = _get_correct_mapping(
                mapping["first_child_ophint_input_index"], child_hint.inputs)
            child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
                parent_input_index]
        else:
          for mapping in children_inputs_mappings[
              "internal_children_input_output"]:
            input_index = _get_correct_mapping(mapping["child_input_index"],
                                               child_hint.inputs)
            output_index = _get_correct_mapping(mapping["child_output_index"],
                                                children_hints[j - 1].outputs)
            child_hint.inputs[input_index] = children_hints[
                j - 1].outputs[output_index]
        if j == len(children_hints) - 1:
          for mapping in children_inputs_mappings["parent_last_child_output"]:
            parent_output_index = _get_correct_mapping(
                mapping["parent_output_index"], hints[hint_uuid].outputs)
            child_output_index = _get_correct_mapping(
                mapping["child_output_index"], child_hint.outputs)
            child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
                parent_output_index]

      for j, child_hint in enumerate(children_hints):
        curr_graph_def = _convert_single_op_hint_to_stub(
            child_hint, curr_graph_def, function_def_nodes,
            j == len(children_hints) - 1)
    else:
      curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
                                                       curr_graph_def)
      write_callback(curr_graph_def, "initial")
  # The stubbing process can create stacks/unstacks in the case of LSTMs;
  # remove them.
  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
  return curr_graph_def


def find_all_hinted_output_nodes(session=None, graph_def=None):
  """Find all OpHint output nodes in the graph.

  This is used to get all the output nodes that are ophinted; it is important
  for operations like convert_variables_to_constants to keep all the ophint
  structure.
  Note: only one of session or graph_def should be used, not both.
  Why can this be useful? Some TensorFlow ops (e.g. bidirectional rnn) can
  generate multiple outputs for the unfused subgraph. If not all output nodes
  are consumed, graph optimization can potentially drop the unused nodes and
  leave the ophints in an invalid state (due to missing ophinted output nodes).
  So it's important for us to find all those hinted output nodes and make sure
  they are not discarded.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.

  Returns:
    A list of OpHints output nodes.

  Raises:
    ValueError: If both session and graph_def are provided.
  """
  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")
  hinted_outputs_nodes = []
  if session is not None:
    hints = _find_all_hints_in_nodes(session.graph_def.node)
  elif graph_def is not None:
    hints = _find_all_hints_in_nodes(graph_def.node)
  for hint in _six.itervalues(hints):
    _, output_nodes = hint.flattened_inputs_and_outputs()
    hinted_outputs_nodes.extend(output_nodes)
  return hinted_outputs_nodes


def is_ophint_converted(graph_def):
  if graph_def is None:
    raise ValueError("Must provide the graph_def.")
  ophint_converted = False
  for node in graph_def.node:
    attr = node.attr
    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      ophint_converted = True
      break
  return ophint_converted


@_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
def convert_op_hints_to_stubs(session=None,
                              graph_def=None,
                              write_callback=lambda graph_def, comments: None):
  """Converts a graphdef with LiteOp hints into stub operations.

  This is used to prepare for toco conversion of complex intrinsic usages.
  Note: only one of session or graph_def should be used, not both.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new graphdef with all ops contained in OpHints being replaced by
    a single op call with the right parameters.

  Raises:
    ValueError: If both session and graph_def are provided.
  """

  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")

  if session is not None:
    return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
  elif graph_def is not None:
    return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
  else:
    raise ValueError("Must specify session or graph_def as input.")


_allowed_symbols = [
    "OpHint",
    "convert_op_hints_to_stubs",
    "convert_op_hints_to_stubs_new",
    "find_all_hinted_output_nodes",
    "is_ophint_converted",
]
remove_undocumented(__name__, _allowed_symbols)