# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Define tflite op hints (intrinsic operations).

This essentially allows defining a TensorFlow API for tflite operations in
Python with hints on how they are represented in TensorFlow Lite. This is a
form of tflite intrinsic: it wraps a subpart of a TensorFlow execution graph
and is useful for LSTMs and other complicated TensorFlow constructions that
are difficult to pattern match in TOCO, but are represented by a single
accelerated tflite op.

Example:
  def tflite_cool_activation(input):
    # A cool activation function.
    custom = tf.lite.OpHint("cool_activation")
    input, = custom.add_inputs(input)
    output = tf.sigmoid(input) * input
    output, = custom.add_outputs(output)
    return output

  image = tf.compat.v1.placeholder(tf.float32, (1, 16, 16, 1))
  output = tf.identity(tflite_cool_activation(image))

  session = tf.compat.v1.Session()

  graphdef_to_convert = tf.lite.experimental.convert_op_hints_to_stubs(session)
  tflite_graph = tf.compat.v1.lite.toco_convert(
      graphdef_to_convert, [image], [output], allow_custom_ops=True)
  with open("/tmp/graph.fb", "wb") as fp:
    fp.write(tflite_graph)

How does it work?:

OpHint is a helper that you use when defining a vanilla python function.
It allows you to wrap arguments with tf.identities with some custom attributes.
These attributes allow you to find the original block of ops that was created.
For example, if you use cool_activation above you essentially get:

  a_input = tf.identity()
  result = tf.multiply(tf.sigmoid(a_input), a_input)
  output = tf.identity()

a_input and output are identities that carry parameters describing which
argument they are, the name of the function they should turn into in tf lite,
as well as a guid that uniquely identifies a particular invocation.

Once you have built your whole tensorflow graph, you can run it and train it
as usual, but after you have done that, you need to convert the graph into
a form that replaces these subgraphs wrapped in identities with stub ops. These
ops don't actually exist in the normal TensorFlow runtime, but will be
understood by toco later. The generated TensorFlow Lite flatbuffer file will
contain a custom operator called "cool_activation". Developers need to
implement and register this operator in TensorFlow Lite in order to run
inference.
"""

import collections as _collections
import copy as _copy
import json as _json
import uuid as _uuid

from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.framework import node_def_pb2 as _node_def_pb2
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.util import compat as _compat
from tensorflow.python.util import deprecation as _deprecation
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export as _tf_export


@_tf_export(v1=["lite.OpHint"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
class OpHint:
  """A class that helps build tflite function invocations.

  It allows you to take a bunch of TensorFlow ops and annotate the construction
  such that toco knows how to convert it to tflite. This embeds a pseudo
  function in a TensorFlow graph. This allows embedding high-level API usage
  information in a lower level TensorFlow implementation so that an alternative
  implementation can be substituted later.

  Essentially, any "input" into this pseudo op is fed into an identity, and
  attributes are added to that input before being used by the constituent ops
  that make up the pseudo op. A similar process is done to any output that
  is to be exported from the current op.
  """
  # Attr constants that are used for representation in the GraphDef. These
  # will be used on every Identity op that is involved in a total OpHint.

  # Name of the OpHint function (cosmetic).
  FUNCTION_NAME_ATTR = "_tflite_function_name"
  # UUID of the function (each OpHint gets a new uuid).
  FUNCTION_UUID_ATTR = "_tflite_function_uuid"
  # The input index of the input (or nothing if it is an output).
  FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index"
  # The output index of the output (or nothing if it is an input).
  FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index"
  # An index that orders aggregate arguments. Aggregate arguments are ones
  # that are separate but will be fused horizontally. For example a static LSTM
  # has a lstm cell for each time step. Each one has a separate opHint, but a
  # fused SequentialLSTM will treat this as a single tensor.
  FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index"
  # The way in which multiple parts of the aggregate argument will be joined
  # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST,
  # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK.
  FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate"
  # On fused OpHint stub, the order of inputs that the final LSTM call will
  # have. What this means is that the TensorFlow order might be
  # "foo", "bar", "stuff" and you might want the TF lite op order to be
  # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this
  # attribute to [2, 0, 1, -1].
  TFLITE_INPUT_INDICES = "_tflite_input_indices"
  # OpHint level.
  FUNCTION_LEVEL_ATTR = "_tflite_ophint_level"
  # Ophint internal mapping, this is for high level Ophint only.
  # This basically contains three kinds of mapping:
  # 1) How parental ophinted inputs map to the first child ophinted inputs;
  # 2) How internal children nodes are connected;
  # 3) How parental ophinted outputs map to the last child ophinted outputs.
  CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping"

  # Types of aggregations
  #  stack: stacks all ophints with matching tags, i.e. for a static rnn.
  #    Specifically, this is good for an input or output to a static rnn cell.
  AGGREGATE_STACK = "stack"
  #  first: only takes the first output (one with lowest sort index)
  #    of matching tags. This is good for the input state to an RNN.
  AGGREGATE_FIRST = "first"
  #  last: takes only the last tag (one with highest sort index). This is
  #    good for an output value on the last stack item of a static rnn.
  AGGREGATE_LAST = "last"

  class OpHintArgumentTracker:
    """Conceptually tracks indices of arguments of "OpHint functions".

    The inputs and arguments of these functions both use an instance
    of the class so they can have independent numbering.
    """

    def __init__(self,
                 function_name,
                 unique_function_id,
                 node_name_prefix,
                 attr_name,
                 level=1,
                 children_inputs_mappings=None):
      """Initialize ophint argument.

      Args:
        function_name: Name of the function that this tracks arguments for.
        unique_function_id: UUID of function that this tracks arguments for.
        node_name_prefix: How identities that are created are named.
        attr_name: Name of attribute to use to store the index for this hint,
          i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX.
        level: Hierarchical level of the Ophint node, a number.
        children_inputs_mappings: Inputs/Outputs mapping for children hints.
      """

      # The global index is the argument index of the op. This is in contrast
      # to the sort index, which is the sequence number of a particular
      # instance of a given global index. For example, you may have called
      # add hint twice with the tag "foo". Then the global index will be 0
      # for both and the sort index will be 0 for the first added and 1 for
      # the second.
      self._function_name = function_name
      self._unique_function_id = unique_function_id
      self._next_global_index = 0  # The absolute global index
      self._used_global_indices = set()
      self._tag_to_global_index = {}  # The argument index a given tag maps to
      self._tag_to_next_sort_index = {}  # The current index for each tag
      self._node_name_prefix = node_name_prefix
      self._attr_name = attr_name
      self._level = level
      self._children_inputs_mappings = children_inputs_mappings

    def _get_new_global_index(self, index_override):
      """Return the next unused argument index in order or use an override.

      Args:
        index_override: An index to use instead of the next available or None
          to use the next available.

      Returns:
        A valid global_index to use for the next hint argument.

      Raises:
        ValueError: If the index_override is already used by another hint.
      """
      if index_override is None:
        global_index = self._next_global_index
      else:
        if index_override in self._used_global_indices:
          raise ValueError(
              "Index %d was already used by another call to add" %
              index_override)
        global_index = index_override
      # Make next_global_index valid
      self._used_global_indices.add(global_index)
      while self._next_global_index in self._used_global_indices:
        self._next_global_index += 1
      return global_index

    def add(self, arg, tag=None, name=None, aggregate=None,
            index_override=None):
      """Return a wrapped tensor of an input tensor as an argument.

      Args:
        arg: A TensorFlow tensor that should be considered an argument.
        tag: String tag to identify arguments that should be packed.
        name: Name of argument. This is included in the Identity hint op names.
        aggregate: Strategy to aggregate. Acceptable values are
          OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, and
          OpHint.AGGREGATE_STACK. Note, aggregate is only valid if tag is
          specified.
        index_override: Specify what input/output index this argument should
          have in the final stub, i.e. add(arg0, index_override=1);
          add(arg1, index_override=0) will make the final stub be
          stub_func(inputs=[arg1, arg0], outputs=[]) rather than the default
          call-order based ordering.

      Returns:
        A tensor representing the wrapped argument.

      Raises:
        ValueError: When indices are not consistent.
      """

      # Find the appropriate index
      if tag is None:
        if aggregate is not None:
          raise ValueError("You must specify `tag` if using aggregate.")
        global_index = self._get_new_global_index(index_override)
        sort_index = None
      else:
        if aggregate is None:
          raise ValueError("You must specify `aggregate` if using tag.")
        if tag not in self._tag_to_global_index:
          self._tag_to_global_index[tag] = (
              self._get_new_global_index(index_override))
          self._tag_to_next_sort_index[tag] = 0
        elif (index_override and
              index_override != self._tag_to_global_index[tag]):
          raise ValueError(
              "Tag %r was called with two indices %r and %r" %
              (tag, index_override, self._tag_to_global_index[tag]))
        global_index = self._tag_to_global_index[tag]
        sort_index = self._tag_to_next_sort_index[tag]
        self._tag_to_next_sort_index[tag] += 1

      uuid = self._unique_function_id
      name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix,
                                    self._function_name, uuid, global_index,
                                    sort_index, name)

      identity_op = _array_ops.identity(arg, name=name)

      # pylint: disable=protected-access
      identity_op.op._set_attr(
          OpHint.FUNCTION_NAME_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._function_name)))
      identity_op.op._set_attr(
          OpHint.FUNCTION_UUID_ATTR,
          _attr_value_pb2.AttrValue(
              s=_compat.as_bytes(self._unique_function_id)))
      identity_op.op._set_attr(
          self._attr_name, _attr_value_pb2.AttrValue(i=global_index))
      identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR,
                               _attr_value_pb2.AttrValue(i=self._level))
      if self._children_inputs_mappings:
        identity_op.op._set_attr(
            OpHint.CHILDREN_INPUTS_MAPPINGS,
            _attr_value_pb2.AttrValue(
                s=_compat.as_bytes(_json.dumps(
                    self._children_inputs_mappings))))

      if sort_index is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_SORT_INDEX_ATTR,
            _attr_value_pb2.AttrValue(i=sort_index))
      if aggregate is not None:
        identity_op.op._set_attr(
            OpHint.FUNCTION_AGGREGATE_ATTR,
            _attr_value_pb2.AttrValue(s=_compat.as_bytes(aggregate)))
      # pylint: enable=protected-access
      return identity_op
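
  # Illustrative sketch (not part of the API): how OpHintArgumentTracker
  # assigns indices. Calling `add` repeatedly with the same tag reuses one
  # global (argument) index and bumps the per-tag sort index; untagged calls
  # each get a fresh global index. The tracker and tensor names below are
  # hypothetical.
  #
  #   tracker = OpHint.OpHintArgumentTracker(
  #       "my_func", _uuid.uuid1().hex, "InputHint",
  #       OpHint.FUNCTION_INPUT_INDEX_ATTR)
  #   tracker.add(x0, tag="x", aggregate=OpHint.AGGREGATE_STACK)  # global 0, sort 0
  #   tracker.add(x1, tag="x", aggregate=OpHint.AGGREGATE_STACK)  # global 0, sort 1
  #   tracker.add(state)                                          # global 1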

  def __init__(self,
               function_name,
               level=1,
               children_inputs_mappings=None,
               **kwargs):
    """Create a OpHint.

    Args:
      function_name: Name of the function (the custom op name in tflite).
      level: OpHint level.
      children_inputs_mappings: Children OpHint inputs/outputs mapping.
        children_inputs_mappings should look like the following:
        "parent_first_child_input":
            [{"parent_ophint_input_index": num,
              "first_child_ophint_input_index": num}, ...]
        "parent_last_child_output":
            [{"parent_output_index": num, "child_output_index": num}, ...]
        "internal_children_input_output":
            [{"child_input_index": num, "child_output_index": num}, ...]
      **kwargs: Keyword arguments of any constant attributes for the function.
    """
    self._function_name = function_name
    self._level = level
    if self._level == 1:
      assert children_inputs_mappings is None
    else:
      assert isinstance(children_inputs_mappings, dict)
    self._children_inputs_mappings = children_inputs_mappings
    if self._children_inputs_mappings is not None:
      self._validate_children_inputs_mappings(self._children_inputs_mappings)
    self._unique_function_id = _uuid.uuid1().hex
    self._attrs_to_store_later = kwargs
    self._stored_attrs = False
    self._inputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "InputHint",
        OpHint.FUNCTION_INPUT_INDEX_ATTR, level,
        self._children_inputs_mappings)
    self._outputs = OpHint.OpHintArgumentTracker(
        self._function_name, self._unique_function_id, "OutputHint",
        OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level,
        self._children_inputs_mappings)

  def _validate_children_inputs_mappings(self, children_inputs_mappings):
    """Validate that children inputs mappings are in the right format.

    Args:
      children_inputs_mappings: the children ophint inputs/outputs mapping.
    """
    assert isinstance(children_inputs_mappings, dict)
    assert "parent_first_child_input" in children_inputs_mappings
    assert "parent_last_child_output" in children_inputs_mappings
    assert "internal_children_input_output" in children_inputs_mappings

    # Validate each mapping list.

    def assert_dictlist_has_keys(dictlist, keys):
      for dikt in dictlist:
        assert isinstance(dikt, dict)
        for key in keys:
          assert key in dikt

    assert_dictlist_has_keys(
        children_inputs_mappings["parent_first_child_input"],
        ["parent_ophint_input_index", "first_child_ophint_input_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["parent_last_child_output"],
        ["parent_output_index", "child_output_index"])
    assert_dictlist_has_keys(
        children_inputs_mappings["internal_children_input_output"],
        ["child_input_index", "child_output_index"])
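
  # A minimal sketch (assumed shape, matching _validate_children_inputs_mappings
  # above) of a `children_inputs_mappings` dict for a level-2 OpHint with two
  # children; the indices are made up for illustration:
  #
  #   children_inputs_mappings = {
  #       "parent_first_child_input": [
  #           {"parent_ophint_input_index": 0,
  #            "first_child_ophint_input_index": 0}],
  #       "internal_children_input_output": [
  #           {"child_input_index": 0, "child_output_index": 0}],
  #       "parent_last_child_output": [
  #           {"parent_output_index": 0, "child_output_index": 0}],
  #   }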

  def _setattr(self, dest_op, name, value):
    tensor_value = _ops.convert_to_tensor(value)
    # pylint: disable=protected-access
    dest_op.op._set_attr(name, _attr_value_pb2.AttrValue(
        tensor=tensor_value.op.node_def.attr["value"].tensor))
    # pylint: enable=protected-access

  def add_input(self, *args, **kwargs):
    """Add a wrapped input argument to the hint.

    Args:
      *args: The input tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated, i.e.
          a string like 'cool_input'. Basically multiple inputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.

    Returns:
      The wrapped input tensor.
    """
    return self._inputs.add(*args, **kwargs)

  def add_output(self, *args, **kwargs):
    """Add a wrapped output argument to the hint.

    Args:
      *args: The output tensor.
      **kwargs:
        "name" label
        "tag" a tag to group multiple arguments that will be aggregated, i.e.
          a string like 'cool_output'. Basically multiple outputs can be added
          to the same hint for parallel operations that will eventually be
          combined. An example would be static_rnn which creates multiple
          copies of state or inputs.
        "aggregate" aggregation strategy that is valid only for tag non None.
          Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST,
          and OpHint.AGGREGATE_STACK.
        "index_override" The global index to use. This corresponds to the
          argument order in the final stub that will be generated.

    Returns:
      The wrapped output tensor.
    """
    return self._outputs.add(*args, **kwargs)

  def add_inputs(self, *args, **kwargs):
    """Add a sequence of inputs to the function invocation.

    Args:
      *args: List of inputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped inputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    """
    if "names" in kwargs:
      return [
          self._inputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._inputs.add(arg) for arg in args]

  def add_outputs(self, *args, **kwargs):
    """Add a sequence of outputs to the function invocation.

    Args:
      *args: List of outputs to be converted (should be tf.Tensor).
      **kwargs: This allows 'names' which should be a list of names.

    Returns:
      Wrapped outputs (identity standins that have additional metadata). These
      are also tf.Tensors.
    """
    if "names" in kwargs:
      return [
          self._outputs.add(arg, name=name)
          for arg, name in zip(args, kwargs["names"])
      ]
    else:
      return [self._outputs.add(arg) for arg in args]
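

# Illustrative sketch (not part of the API): a typical aggregated use of
# OpHint, e.g. hinting the per-step tensors of a static RNN so they can later
# be fused into a single stacked operand. `rnn_step`, `init_state`, `inputs`
# and the tag names are hypothetical.
#
#   hint = OpHint("my_fused_rnn")
#   state = hint.add_input(init_state, tag="state",
#                          aggregate=OpHint.AGGREGATE_FIRST)
#   outputs = []
#   for x_t in inputs:
#     x_t = hint.add_input(x_t, tag="x", aggregate=OpHint.AGGREGATE_STACK)
#     out_t, state = rnn_step(x_t, state)
#     outputs.append(hint.add_output(out_t, tag="out",
#                                    aggregate=OpHint.AGGREGATE_STACK))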


class _LiteOperand:
  """Abstract operand for a tflite hint function.

  This is a base class that handles representing arguments to an OpHint.
  It also is able to serialize operands to the stubbed graph_def.
  Child classes are responsible for being able to store information about the
  hint identity operators. They are also responsible for knowing how to
  serialize to output graphdefs.

  Typically this will be implemented by holding one or more identity nodes
  that were previously discovered as hints.
  """

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the node(s) to out_graphdef and returns the input node name.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The output that the stub should use as an input for this operand.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del out_graphdef
    raise RuntimeError("Unimplemented abstract method.")

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """Add node(s) to graph representing output operands and returns type.

    Args:
      fused_op_name: name of the fused op stub name.
      output_index: Output index that we are currently processing from stub.
      out_graphdef: The destination graphdef we are currently building up.

    Returns:
      The datatype of this identity.

    Raises:
      RuntimeError: if the method is not implemented.
    """
    del fused_op_name, output_index, out_graphdef
    raise RuntimeError("Unimplemented abstract method.")


class _LiteSingleOperand(_LiteOperand):
  """A simple operand that is non-aggregated (i.e. most hints)."""

  def __init__(self, node):
    _LiteOperand.__init__(self)
    self.node = node
    self.name = _tensor_name_base(node.name)

  def flatten(self):
    return [self.name]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    return self.name

  def aggregate_and_return_name_for_output(self, fused_op_name, index,
                                           out_graphdef):
    output_node = _copy.deepcopy(self.node)
    del output_node.input[:]
    output_node.input.append(_tensorflow_output_name(fused_op_name, index))
    out_graphdef.node.extend([output_node])
    # The hint identity node stores its element type in the "T" attr.
    return self.node.attr["T"].type

  def __str__(self):
    return str(self.name)


class _LiteAggregateOperand(_LiteOperand):
  """An operand for a tflite hint function that is aggregated from many.

  For example, an LSTM is a grid of operators that are all related. Inputs
  going into them may need to be fused, so they should all be tracked as
  related arguments.
  """

  def __init__(self, aggregation):
    _LiteOperand.__init__(self)
    self.aggregation = aggregation
    self.names = {}
    self.nodes = {}
    self.flattened = None

  def add(self, sort, node):
    self.names[sort] = _tensor_name_base(node.name)
    self.nodes[sort] = node

  def flatten_nodes(self):
    """Return a list of all the node protos in aggregation sorted order."""
    if not self.flattened:
      self.flattened = [None] * len(self.nodes)
      for idx, node in self.nodes.items():
        self.flattened[idx] = node
      for n in self.flattened:
        if n is None:
          raise RuntimeError("Aggregate was missing argument.")
      if self.aggregation == OpHint.AGGREGATE_FIRST:
        self.flattened = self.flattened[:1]
      elif self.aggregation == OpHint.AGGREGATE_LAST:
        self.flattened = self.flattened[-1:]
      elif self.aggregation == OpHint.AGGREGATE_STACK:
        pass
      else:
        raise ValueError("Invalid aggregation type %r specified" %
                         self.aggregation)
    return self.flattened

  def flatten(self):
    """Return a list of all node names in aggregation sorted order."""
    return [_tensor_name_base(x.name) for x in self.flatten_nodes()]

  def aggregate_and_return_name_for_input(self, out_graphdef):
    """This adds the nodes to out_graphdef and returns an aggregated output.

    In particular, if you have 4 inputs to a hint stub, this will be the
    node that you can use as an output. I.e. if you have 4 timesteps from a
    static rnn, then a fused UnidirectionalLSTM will expect 1 input with
    all 4 time steps. So here we make a pack and return the output name of
    that pack.

    Args:
      out_graphdef: A graphdef that is ready to have this input added.

    Returns:
      The name of a pack that aggregates this node.
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      return _tensor_name_base(flattened[0].name)
    else:
      new_node = _node_def_pb2.NodeDef()
      new_node.op = "Pack"
      new_node.name = "OpHintStack-%s" % flattened[0].name
      new_node.attr["N"].i = len(flattened)
      new_node.attr["T"].type = flattened[0].attr["T"].type
      for discrete in flattened:
        new_node.input.append(_tensor_name_base(discrete.name))
      out_graphdef.node.extend([new_node])
      return new_node.name

  def aggregate_and_return_name_for_output(self, fused_op_name, output_index,
                                           out_graphdef):
    """This adds to `out_graphdef` all the unaggregated outputs.

    I.e. we are outputting from a fused stub, but we need to make it
    compatible with the unfused original graph, so we insert an unpack.
    Ideally in a later stage the unpack -> pack sequences will be removed.

    Args:
      fused_op_name: The name of the stub we are in the process of fusing.
      output_index: The output index this object represents.
      out_graphdef: The graphdef we are in the process of building.

    Returns:
      The type of the aggregated output (so we can finish building the stub
      op).
    """
    flattened = self.flatten_nodes()
    if (self.aggregation == OpHint.AGGREGATE_FIRST) or (
        self.aggregation == OpHint.AGGREGATE_LAST):
      assert len(flattened) == 1
    if len(flattened) == 1 and self.aggregation != OpHint.AGGREGATE_STACK:
      temp_op = _LiteSingleOperand(flattened[0])
      return temp_op.aggregate_and_return_name_for_output(
          fused_op_name, output_index, out_graphdef)
    else:
      stack_node = _node_def_pb2.NodeDef()
      stack_node.op = "Unpack"
      stack_node.name = "OpHintUnstack-%s" % flattened[0].name
      stack_node.attr["num"].i = len(flattened)
      output_type = flattened[0].attr["T"].type
      stack_node.attr["T"].type = output_type
      stack_node.input.append(
          _tensorflow_output_name(fused_op_name, output_index))
      out_graphdef.node.extend([stack_node])

      for idx, discrete in enumerate(flattened):
        output_node = _copy.deepcopy(discrete)
        del output_node.input[:]
        output_node.input.append(
            _tensorflow_output_name(stack_node.name, idx))
        out_graphdef.node.extend([output_node])

      return output_type

  def __str__(self):
    s = "\t\t\tAGGREGATE %s\n" % self.aggregation
    for sort, val in self.names.items():
      s += "\t\t\t%d: %s\n" % (sort, val)
    return s
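

# Illustrative sketch (not part of the API): what _LiteAggregateOperand
# produces for each aggregation type, assuming nodes n0, n1, n2 were added
# with sort indices 0, 1, 2:
#
#   AGGREGATE_FIRST -> flatten_nodes() == [n0]
#   AGGREGATE_LAST  -> flatten_nodes() == [n2]
#   AGGREGATE_STACK -> flatten_nodes() == [n0, n1, n2]; when used as a stub
#                      input, a "Pack" node over n0..n2 is emitted instead.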


class _LiteFuncCall:
  """Represent a TensorFlow Lite custom function.

  This is used to accumulate found hints in the graphdef into a single
  conceptual unit.

  Attributes:
    inputs: inputs to the op (hash from index # to argument)
    outputs: outputs to the op (hash from index # to argument)
    function_name: the tflite custom op name to use
    uuid: a unique call id for this particular call (i.e. multiple function
      calls would have the same function_name but different uuids).
    params: A param name to key value for op constant data, i.e. for axis on a
      reduction, strides on a convolution, etc.
    level: Level of the OpHint.
    children_inputs_mappings: If the Ophint has children, children inputs
      mappings indicate how their inputs & outputs are mapped.
  """

  def __init__(self):
    self.inputs = {}
    self.outputs = {}
    self.function_name = None
    self.uuid = None
    self.params = {}
    self.level = -1
    self.children_inputs_mappings = {}

  def flattened_inputs_and_outputs(self):
    """Return a list of inputs and outputs in a flattened format.

    Returns:
      Tuple of (inputs, outputs), where each is a list of names.
    """

    def _flatten(input_or_output_dict):
      flattened_items = []
      for item in input_or_output_dict.values():
        flattened_items.extend(item.flatten())
      return flattened_items

    return _flatten(self.inputs), _flatten(self.outputs)

  def __str__(self):

    def format_args(items):
      s = ""
      for idx, item in items.items():
        s += ("\t\t%d:\n" % idx) + str(item)
      return s

    inputs_str = "\tInputs\n" + format_args(self.inputs)
    outputs_str = "\tOutputs\n" + format_args(self.outputs)

    return (
        "tflite function %s call %s level %d "
        "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" %
        (self.function_name, self.uuid, self.level, inputs_str, outputs_str))


def _find_all_hints_in_nodes(nodes):
  """Look at all the input nodes and return the LiteFuncCalls found.

  Args:
    nodes: A list of TensorFlow NodeDefs to look for LiteFuncCalls in.

  Returns:
    A dict of `_LiteFuncCall` objects, keyed by the hint uuid.
  """
  func_calls = _collections.defaultdict(_LiteFuncCall)

  for node in nodes:
    attr = node.attr
    # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip
    if (OpHint.FUNCTION_UUID_ATTR not in attr or
        not attr[OpHint.FUNCTION_UUID_ATTR].s):
      continue
    uuid = attr[OpHint.FUNCTION_UUID_ATTR].s

    # Start building function
    call_def = func_calls[uuid]
    call_def.uuid = uuid
    call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s
    call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i
    # Get sorting and aggregation information

    sort = (
        attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i
        if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None)
    if sort == -1:
      sort = None
    aggregation = None
    if OpHint.FUNCTION_AGGREGATE_ATTR in attr:
      aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s)

    if OpHint.CHILDREN_INPUTS_MAPPINGS in attr:
      call_def.children_inputs_mappings = _json.loads(
          _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s))

    # Add the input or output
    def put_operand(stuff, index, sort, operand, aggregation):
      """Add a given index into the function structure."""
      if sort is None:
        stuff[index] = _LiteSingleOperand(operand)
      else:
        if index not in stuff:
          stuff[index] = _LiteAggregateOperand(aggregation)
        stuff[index].add(sort, operand)

    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i,
                  sort, node, aggregation)
    if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr:
      put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i,
                  sort, node, aggregation)

    # Remember attributes
    for a in attr:
      if a.startswith("_tflite_attr_"):
        call_def.params[a.replace("_tflite_attr_", "")] = attr[a].tensor

  return func_calls
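

# Illustrative sketch (not part of the API): the attributes that
# _find_all_hints_in_nodes looks for on an Identity node written by
# OpHint.OpHintArgumentTracker.add (the values here are made up):
#
#   node.attr["_tflite_function_name"].s         = b"cool_activation"
#   node.attr["_tflite_function_uuid"].s         = b"<hex uuid>"
#   node.attr["_tflite_function_input_index"].i  = 0
#   node.attr["_tflite_ophint_level"].i          = 1
#
# The function returns a dict keyed by that uuid, whose values are
# _LiteFuncCall objects with `inputs`/`outputs` populated per argument index.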


def _extract_topology_sequence_mapping(nodes):
  return dict(
      (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes))


def _find_children_hints_in_while_loop(function_def, nodes_mapping):
  """Find children hints and all nodes inside the while loop.

  Args:
    function_def: Function def of the while loop.
    nodes_mapping: While loop input_arg : real node name.

  Returns:
    Ordered children hints and all re-mapped nodes inside the while loop.
  """
  new_nodes = []

  # Make nodes inside function def inputs point to the real nodes.
  for node in function_def.node_def:
    for i, _ in enumerate(node.input):
      if node.input[i] in nodes_mapping:
        node.input[i] = nodes_mapping[node.input[i]]
    new_nodes.append(_copy.deepcopy(node))
  name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def)
  children_hints = _find_all_hints_in_nodes(new_nodes)
  children_hints_q = []
  # Ordered by the outputs.
  for hint in children_hints.values():
    _, output_names = hint.flattened_inputs_and_outputs()
    seq = name_to_seq_num[output_names[0]]
    for output_name in output_names:
      seq = min(seq, name_to_seq_num[output_name])
    children_hints_q.append((seq, hint))
  children_hints_q.sort(key=lambda tup: tup[0])
  ordered_children_hints = [x[1] for x in children_hints_q]
  return ordered_children_hints, new_nodes


def _find_children_hints(call, graph_def):
  """Find all children hints.

  For a given OpHint, we find all children hints inside it. We also copy all
  the nodes inside function defs (if applicable) to the original graph_def;
  they are returned in a list as well.

  Args:
    call: Parent OpHint that contains children ophints.
    graph_def: Original graph def.

  Returns:
    Ordered children hints inside the parent ophint; new graph def that
    contains nodes inside function defs (if applicable); nodes inside
    function defs.
  """
  name_to_input_name, _, _ = _extract_graph_summary(graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names,
                                                name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  children_hints = []
  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)
  function_def_nodes = set()
  for node in graph_def.node:
    out.node.extend([_copy.deepcopy(node)])
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        # Special handling for while loop function defs.
        if node.op == "While" or node.op == "StatelessWhile":
          body_name = node.attr["body"].func.name
          inputs_outside_loop = node.input
          for function_def in graph_def.library.function:
            if function_def.signature.name == body_name:
              function_inputs = function_def.signature.input_arg
              assert len(inputs_outside_loop) == len(function_inputs)
              nodes_mapping = {}
              for i, function_input in enumerate(function_inputs):
                nodes_mapping[function_input.name] = inputs_outside_loop[i]
              (children_hints_in_loop,
               new_nodes) = _find_children_hints_in_while_loop(
                   function_def, nodes_mapping)
              function_def_nodes.update([x.name for x in new_nodes])
              children_hints.extend(children_hints_in_loop)
              out.node.extend(new_nodes)

  return children_hints, out, function_def_nodes


def _tensor_name_base(full_tensor_name):
  """Removes the output port (or control marker) from a tensor name.

  e.g. _tensor_name_base("foo:3") => "foo"

  Args:
    full_tensor_name: A tensor name that is annotated with a port number
      (this is what tensorflow introspection gives).

  Returns:
    A name without any port assignment.
  """
  if full_tensor_name.startswith("^"):
    return full_tensor_name[1:]
  return full_tensor_name.split(":")[0]


def _tensorflow_output_name(tensor_name, output_index):
  return tensor_name if output_index == 0 else "%s:%d" % (tensor_name,
                                                          output_index)


def _check_subgraph_closed(n, reachable_by_input, input_nodes_set,
                           name_to_input_name):
  """Checks that a node only connects to the predecessor graph through inputs.

  Args:
    n: Node to check.
    reachable_by_input: Nodes that are reachable by all inputs of subgraph.
    input_nodes_set: The set of nodes that are "inputs".
    name_to_input_name: Maps from name to the list of inputs.

  Raises:
    TypeError: If the given node uses items past inputs directly.
  """
  next_to_visit = [n]
  visited = set()
  while next_to_visit:
    current_node = next_to_visit.pop()
    visited.add(current_node)
    if (current_node in reachable_by_input and
        current_node not in input_nodes_set):
      raise TypeError("Node %s uses input %s not in input_nodes." %
                      (n, current_node))
    if current_node not in input_nodes_set:
      next_to_visit += [
          input_node for input_node in name_to_input_name[current_node]
          if input_node not in visited
      ]


def _convert_single_op_hint_to_stub(call,
                                    graph_def,
                                    function_def_nodes=None,
                                    is_last_run=True):
  """Given a graph_def, converts `call` into a stub and returns a new graph_def.

  Args:
    call: A single function call to be converted.
    graph_def: A graph_def to use as input (that contains `call`).
    function_def_nodes: Nodes inside the function def that are not connected
      to the graph.
    is_last_run: Whether it is the last run for a given pass (for OpHints that
      have children).

  Returns:
    A new transformed graph-def that has call as a stub (single op).

  Note: after this process, the graph_def can no longer be loaded into
  the tensorflow runtime, so all future manipulations are done at the
  graph_def level.
  """
  if function_def_nodes is None:
    function_def_nodes = set()
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  input_names, output_names = call.flattened_inputs_and_outputs()

  reachable_by_input = _bfs_for_reachable_nodes(input_names,
                                                name_to_input_name)
  reachable_by_output = _bfs_for_reachable_nodes(output_names,
                                                 name_to_input_name)
  output_nodes_set = set(output_names)
  nodes_after_fuse = []
  nodes_deleted_by_fuse = set()
  # Classify each node. We keep everything reachable by input; nodes that are
  # reachable by output but not by input are swallowed by the fuse, and the
  # rest either come after the fuse or are unconnected.
  for node in graph_def.node:
    n = _tensor_name_base(node.name)
    if n in reachable_by_output:
      if n not in reachable_by_input and n not in output_nodes_set:
        nodes_deleted_by_fuse.add(n)
    elif n not in reachable_by_input and n not in function_def_nodes:
      # n is a node that comes after all the fusings, so keep it.
      nodes_after_fuse.append(n)
    else:
      # In the last run, n is a node that is randomly in the graph but not
      # connected to the chain of dependencies; we delete n, otherwise
      # we keep it.
      if not is_last_run:
        nodes_after_fuse.append(n)

  # Make a new graphdef with all the pre-input and input nodes
  out = _graph_pb2.GraphDef()
  reachable_by_input_sorted = sorted(
      list(reachable_by_input), key=lambda n: name_to_seq_num[n])
  for node in reachable_by_input_sorted:
    out.node.extend([_copy.deepcopy(name_to_node[node])])

  # Create any stacks to aggregate arguments into a single input
  # i.e. for static_rnn's.
  sorted_input_indices = list(call.inputs.keys())
  sorted_input_indices.sort()
  sorted_output_indices = list(call.outputs.keys())
  sorted_output_indices.sort()
  new_node = _node_def_pb2.NodeDef()
  # Delegate to each operand to produce the proper new input for this stub
  # node. In particular, an aggregate input will now be a Pack of some
  # previously non-fused things.

  optional_input_node = _node_def_pb2.NodeDef()
  optional_input_node.name = "Const" + str(_uuid.uuid1().hex)
  optional_input_node.op = "Const"
  optional_input_node.attr["dtype"].CopyFrom(
      _attr_value_pb2.AttrValue(type=_dtypes.float32.as_datatype_enum))
  optional_input_node.attr["value"].CopyFrom(
      _attr_value_pb2.AttrValue(
          tensor=_tensor_util.make_tensor_proto([-1], _dtypes.float32, [1])))
  out.node.extend([optional_input_node])

  max_index = max(sorted_input_indices) + 1
  for cur_index in range(max_index):
    if cur_index in sorted_input_indices:
      inputs = call.inputs[cur_index]
      input_name = inputs.aggregate_and_return_name_for_input(out)
      new_node.input.append(input_name)
    else:
      new_node.input.append(optional_input_node.name)

  new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(
      sorted_input_indices)

  # Create the function
  new_node.op = call.function_name
  new_node.name = call.uuid
  out.node.extend([new_node])

  # Now call each output argument to give them a chance to make the proper
  # output type and add it to our new_node.
  output_dtypes = []
  max_output_index = max(sorted_output_indices) + 1
  for cur_index in range(max_output_index):
    if cur_index in sorted_output_indices:
      output = call.outputs[cur_index]
      output_dtype = (
          output.aggregate_and_return_name_for_output(new_node.name, cur_index,
                                                      out))
    else:
      # Unused output slots fall back to the dtype of the placeholder const.
      output_dtype = optional_input_node.attr["dtype"].type
    output_dtypes.append(output_dtype)
  new_node.attr["_output_types"].list.type[:] = output_dtypes
  new_node.attr["_output_quantized"].b = False

  # Add post output nodes that do not depend on the outputs
  for n in nodes_after_fuse:
    should_keep = True
    for input_name in name_to_input_name[n]:
      if input_name in nodes_deleted_by_fuse:
        should_keep = False
    if should_keep:
      out.node.extend([_copy.deepcopy(name_to_node[n])])

  # Misc. graph_def data that needs copying.
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
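

# Illustrative sketch (not part of the API): the stub emitted by
# _convert_single_op_hint_to_stub for the "cool_activation" example in the
# module docstring is roughly the following NodeDef (names abbreviated, and
# the exact identity-node name format follows OpHintArgumentTracker.add):
#
#   op: "cool_activation"             # call.function_name
#   name: "<function uuid>"           # call.uuid
#   input: ["InputHint-cool_activation-<uuid>-0-None-None"]
#   attr { key: "_tflite_input_indices" ... }
#   attr { key: "_output_types" ... }
#   attr { key: "_output_quantized" value { b: false } }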


def _remove_one_redundant_stack_unstack(in_graph_def):
  """Removes a stack->unstack pattern from in_graph_def in a returned graph.

  Args:
    in_graph_def: Graph def to use as input.

  Returns:
    Simplified tuple (graph_def, changed_something) where changed_something
    is true if anything was done.
  """
  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      in_graph_def)
  del name_to_seq_num

  do_generic_pack_unpack = True

  out = _graph_pb2.GraphDef()
  out.library.CopyFrom(in_graph_def.library)
  out.versions.CopyFrom(in_graph_def.versions)
  for n in in_graph_def.node:
    node_name = _tensor_name_base(n.name)
    if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"):
      continue
    next_to_visit = [node_name]
    visited = set()

    unpack_nodes = set()
    pack_node = node_name

    # Find a pattern of unstack connected to a stack (with identities
    # in between).
    matches_pattern = True
    is_hint_created_stack = False
    while next_to_visit:
      current_node_name = next_to_visit[0]
      visited.add(current_node_name)
      del next_to_visit[0]
      node = name_to_node[current_node_name]
      is_op_hint_stack = node.name.startswith("OpHintStack")
      is_op_hint_unstack = node.name.startswith("OpHintUnstack")
      if (node.op == "Identity" or is_op_hint_stack or
          (do_generic_pack_unpack and node.op == "Pack")):
        is_hint_created_stack |= is_op_hint_stack
        next_to_visit += [
            input_node for input_node in name_to_input_name[current_node_name]
            if input_node not in visited
        ]
      elif (is_op_hint_unstack or
            (do_generic_pack_unpack and node.op == "Unpack")):
        unpack_nodes.add(node.name)
        is_hint_created_stack &= is_op_hint_unstack
      else:
        matches_pattern = False
        break
      visited.add(node.name)

    if matches_pattern and len(unpack_nodes) == 1:
      pack_node = node_name

      # Check to see if anyone depends on the intermediate identity or the
      # Unstacked form.
      no_external_dependency = True
      for other_n in in_graph_def.node:
        if other_n.name in visited:
          continue
        for input_tensor in name_to_input_name[other_n.name]:
          input_op = _tensor_name_base(input_tensor)
          if input_op in visited and input_op != pack_node:
            no_external_dependency = False
      # Proceed with the substitution if the stack/unstack pair was created
      # through hints, or if it was not but nobody is consuming things
      # between the stack and unstack.
      if is_hint_created_stack or no_external_dependency:
        end = unpack_nodes.pop()
        end_input = name_to_node[end].input[0]
        # All nodes that depend on the final stack need to be rewritten to
        # consume the input of the stack/unstack pair instead.
        for other_n in in_graph_def.node:
          node_name = _tensor_name_base(other_n.name)
          if node_name not in visited:
            new_node = _copy.deepcopy(other_n)
            new_node.input[:] = [
                (end_input if stripped == pack_node else non_stripped)
                for stripped, non_stripped in zip(
                    name_to_input_name[node_name], new_node.input[:])
            ]
            out.node.extend([new_node])
        return out, True
  return in_graph_def, False


def _remove_redundant_stack_unstack(graph_def):
  curr = graph_def
  del graph_def
  changed_stuff = True
  while changed_stuff:
    curr, changed_stuff = _remove_one_redundant_stack_unstack(curr)
  return curr


def _get_correct_mapping(original_index, nodes):
  # Special handling for the index -1 case: return the last index.
  if original_index == -1:
    node_indices = nodes.keys()
    node_indices = sorted(node_indices)
    return node_indices[-1]
  return original_index


def _convert_op_hints_to_stubs_helper(
    graph_def, write_callback=lambda sess, graph_def: None):
  """Converts a graph_def to a new graph_def where all op hints are stubbed.

  Args:
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new stubbed graph_def.
  """
  hints = _find_all_hints_in_nodes(graph_def.node)

  hints_q = []
  for hint in hints.values():
    hints_q.append((hint.level, hint.uuid))

  hints_q.sort(key=lambda tup: tup[0])

  curr_graph_def = graph_def
  del graph_def  # prevent using graph_def again (common source of error)
  # Process hints from the highest level down to level 1.
  for i in range(len(hints_q) - 1, -1, -1):
    level, hint_uuid = hints_q[i]
    if level >= 2:
      children_hints, curr_graph_def, function_def_nodes = _find_children_hints(
          hints[hint_uuid], curr_graph_def)
      # pylint: disable=superfluous-parens
      assert (len(children_hints) > 0)  # pylint: disable=g-explicit-length-test
      # pylint: enable=superfluous-parens

      # Re-wire the children hints inputs/outputs, so the latter child's
      # inputs connect to the previous child node's outputs.
      children_inputs_mappings = hints[hint_uuid].children_inputs_mappings
      for j, child_hint in enumerate(children_hints):
        if j == 0:
          for mapping in children_inputs_mappings["parent_first_child_input"]:
            parent_input_index = _get_correct_mapping(
                mapping["parent_ophint_input_index"], hints[hint_uuid].inputs)
            child_input_index = _get_correct_mapping(
                mapping["first_child_ophint_input_index"], child_hint.inputs)
            child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[
                parent_input_index]
        else:
          for mapping in children_inputs_mappings[
              "internal_children_input_output"]:
            input_index = _get_correct_mapping(mapping["child_input_index"],
                                               child_hint.inputs)
            output_index = _get_correct_mapping(mapping["child_output_index"],
                                                children_hints[j - 1].outputs)
            child_hint.inputs[input_index] = children_hints[
                j - 1].outputs[output_index]
        if j == len(children_hints) - 1:
          for mapping in children_inputs_mappings["parent_last_child_output"]:
            parent_output_index = _get_correct_mapping(
                mapping["parent_output_index"], hints[hint_uuid].outputs)
            child_output_index = _get_correct_mapping(
                mapping["child_output_index"], child_hint.outputs)
            child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[
                parent_output_index]

      for j, child_hint in enumerate(children_hints):
        curr_graph_def = _convert_single_op_hint_to_stub(
            child_hint, curr_graph_def, function_def_nodes,
            j == len(children_hints) - 1)
    else:
      curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid],
                                                       curr_graph_def)
      write_callback(curr_graph_def, "initial")
  # The stubbing process can create stacks/unstacks in the case of LSTMs;
  # remove them.
  curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def)
  return curr_graph_def


def find_all_hinted_output_nodes(session=None, graph_def=None):
  """Find all OpHint output nodes in the graph.

  This is used to get all the output nodes that are ophinted; it is important
  for operations like convert_variables_to_constants to keep the ophint
  structure intact.
  Note: only one of session or graph_def should be used, not both.
  Why can this be useful? Some TensorFlow ops (e.g. bidirectional rnn) can
  generate multiple outputs for the unfused subgraph. If not all output nodes
  are consumed, graph optimization can potentially drop the unused nodes and
  leave the ophints in an invalid state (due to missing ophinted output
  nodes). So it's important for us to find all those hinted output nodes and
  make sure they're not discarded.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.

  Returns:
    A list of OpHint output node names.

  Raises:
    ValueError: If both session and graph_def are provided.
  """
  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")
  hinted_outputs_nodes = []
  hints = {}
  if session is not None:
    hints = _find_all_hints_in_nodes(session.graph_def.node)
  elif graph_def is not None:
    hints = _find_all_hints_in_nodes(graph_def.node)
  for hint in hints.values():
    _, output_nodes = hint.flattened_inputs_and_outputs()
    hinted_outputs_nodes.extend(output_nodes)
  return hinted_outputs_nodes
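

# Illustrative sketch (not part of the API): how find_all_hinted_output_nodes
# is typically combined with freezing so that graph pruning does not drop
# hinted output nodes. `session` and `output_tensors` are hypothetical.
#
#   hinted = find_all_hinted_output_nodes(session=session)
#   frozen = tf.compat.v1.graph_util.convert_variables_to_constants(
#       session, session.graph_def,
#       [t.op.name for t in output_tensors] + hinted)
#   stubbed = convert_op_hints_to_stubs(graph_def=frozen)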


def is_ophint_converted(graph_def):
  """Returns True if the graph_def contains nodes carrying OpHint attributes."""
  if graph_def is None:
    raise ValueError("Must provide the graph_def.")
  ophint_converted = False
  for node in graph_def.node:
    attr = node.attr
    if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr:
      ophint_converted = True
      break
  return ophint_converted


@_tf_export(v1=["lite.experimental.convert_op_hints_to_stubs"])
@_deprecation.deprecated(
    None,
    "Please follow instructions under "
    "https://www.tensorflow.org/lite/convert/operation_fusion for operation "
    "fusion in tflite."
)
def convert_op_hints_to_stubs(session=None,
                              graph_def=None,
                              write_callback=lambda graph_def, comments: None):
  """Converts a graphdef with LiteOp hints into stub operations.

  This is used to prepare for toco conversion of complex intrinsic usages.
  Note: only one of session or graph_def should be used, not both.

  Args:
    session: A TensorFlow session that contains the graph to convert.
    graph_def: A graph def that we should convert.
    write_callback: A function pointer that can be used to write intermediate
      steps of graph transformation (optional).

  Returns:
    A new graphdef with all ops contained in OpHints being replaced by
    a single op call with the right parameters.

  Raises:
    ValueError: If both session and graph_def are provided.
  """

  if session is not None and graph_def is not None:
    raise ValueError("Provide only one of session and graph_def.")

  if session is not None:
    return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback)
  elif graph_def is not None:
    return _convert_op_hints_to_stubs_helper(graph_def, write_callback)
  else:
    raise ValueError("Must specify session or graph_def as input.")


_allowed_symbols = [
    "OpHint",
    "convert_op_hints_to_stubs",
    "convert_op_hints_to_stubs_new",
    "find_all_hinted_output_nodes",
    "is_ophint_converted",
]
remove_undocumented(__name__, _allowed_symbols)