1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Define tflite op hints (intrinsic operations). 16 17This essentially allows defining a TensorFlow API for tflite operations in 18Python with hints on how they are represented in TensorFlow Lite. This basically 19is a form of tflite intrinsic. It wraps a subpart of a TensorFlow execution 20graph and is useful for LSTMs and other complicated TensorFlow constructions 21that are difficult to pattern match in TOCO, but are represented by a single 22accelerated tflite op. 23 24Example: 25 def tflite_cool_activation(input): 26 # A cool activation function. 27 custom = tf.lite.OpHint("cool_activation") 28 input, = custom.add_inputs(input) 29 output = tf.sigmoid(input) * input 30 output, = custom.add_outputs(output) 31 return output 32 33 image = tf.placeholder(tf.float32, (1, 16, 16, 1)) 34 output = tf.identity(tflite_cool_activation(image)) 35 36 session = tf.Session() 37 38 graphdef_to_convert = tf.lite.convert_op_hints_to_stubs(session) 39 tflite_graph = tf.lite.toco_convert(graphdef_to_convert, [image], [output]) 40 with open("/tmp/graph.fb", "wb") as fp: 41 fp.write(tflite_graph) 42 43How does it work?: 44 45OpHint is a helper that you use when defining a vanilla python function. 46It allows you to wrap arguments with tf.identities with some custom attributes. 47These attributes allow you to find the original block of ops that was created. 48For example, if you use cool_activation above you essentially get: 49 50a_input = tf.identity() 51result = tf.multiply(tf.sigmoid(a_input), a_input) 52output = tf.identity() 53 54a_input, output are identities that have parameters representing 55what argument they are, what the name of the function they should turn into 56in tf lite as well as a guid that uniquely identifies a particular invocation. 57 58Once you have built your whole tensorflow graph, you can run it and train it 59as usual, but after you have done that, you need to convert the graph into 60a form that replaces these subgraphs wrapped in identities to stub ops. These 61ops don't actually exist in the normal TensorFlow runtime, but will be 62understood by toco later. 63""" 64 65# TODO(aselle): Make this use generic graph transformations. 66# TODO(aselle): _tensor_name_base should be called _tensor_name_to_op_name. 67 68from __future__ import absolute_import 69from __future__ import division 70from __future__ import print_function 71 72import collections as _collections 73import copy as _copy 74import json as _json 75import uuid as _uuid 76import six as _six 77 78from tensorflow.core.framework import attr_value_pb2 as _attr_value_pb2 79from tensorflow.core.framework import graph_pb2 as _graph_pb2 80from tensorflow.core.framework import node_def_pb2 as _node_def_pb2 81from tensorflow.python.framework import ops as _ops 82# TODO(aselle): publicize these apis if we continue to use these. 83from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes 84from tensorflow.python.framework.graph_util_impl import _extract_graph_summary 85from tensorflow.python.ops import array_ops as _array_ops 86from tensorflow.python.util import compat as _compat 87from tensorflow.python.util.all_util import remove_undocumented 88from tensorflow.python.util.tf_export import tf_export as _tf_export 89 90 91@_tf_export("lite.OpHint") 92class OpHint(object): 93 """A class that helps build tflite function invocations. 94 95 It allows you to take a bunch of TensorFlow ops and annotate the construction 96 such that toco knows how to convert it to tflite. This embeds a pseudo 97 function in a TensorFlow graph. This allows embedding high-level API usage 98 information in a lower level TensorFlow implementation so that an alternative 99 implementation can be substituted later. 100 101 Essentially, any "input" into this pseudo op is fed into an identity, and 102 attributes are added to that input before being used by the constituent ops 103 that make up the pseudo op. A similar process is done to any output that 104 is to be exported from the current op. 105 106 """ 107 # TODO(aselle): When TensorFlow functions functionality works for arbitrary 108 # constructs, this mechanism can be retired and changed to use python defun's. 109 110 # Attr constants that are used for representation in the GraphDef. These 111 # will be used on every Identity op that is involved in a total OpHint. 112 113 # Name of the OpHint function (cosmetic). 114 FUNCTION_NAME_ATTR = "_tflite_function_name" 115 # UUID of the function (each OpHint gets a new uuid). 116 FUNCTION_UUID_ATTR = "_tflite_function_uuid" 117 # The index index of the input (or nothing if it is an output). 118 FUNCTION_INPUT_INDEX_ATTR = "_tflite_function_input_index" 119 # The output index of the output (or nothing if it is an input). 120 FUNCTION_OUTPUT_INDEX_ATTR = "_tflite_function_output_index" 121 # An index that orders aggregate arguments. Aggregate arguments are ones 122 # that are separate but will be fused horizontally. For example a static LSTM 123 # has a lstm cell for each time step. Each one has a separate opHint, but a 124 # fused SequentialLSTM will treat this as a single tensor. 125 FUNCTION_SORT_INDEX_ATTR = "_tflite_function_sort_index" 126 # The way in which multiple parts of the aggregate argument will be joined 127 # into a fused operand. Valid options are OpHint.AGGREGATE_FIRST, 128 # OpHint.AGGREGATE_LAST, OpHint.AGGREGATE_STACK. 129 FUNCTION_AGGREGATE_ATTR = "_tflite_function_aggregate" 130 # On fused OpHint stub, the order of inputs that the final LSTM call will 131 # have. What this means is that the TensorFlow order might be 132 # "foo", "bar", "stuff" and you might want the TF lite op order to be 133 # "stuff", "foo", "bar", -1 (where -1 is unused). So you would set this 134 # attribute to [2, 0, 1, -1]. 135 TFLITE_INPUT_INDICES = "_tflite_input_indices" 136 # OpHint level. 137 FUNCTION_LEVEL_ATTR = "_tflite_ophint_level" 138 # Ophint internal mapping, this is for high level Ophint only. 139 # This basically contains three kinds of mapping: 140 # 1) How parental ophinted inputs map to the first child ophinted inputs; 141 # 2) How internal children nodes are connected; 142 # 3) How parental ophinted outputs map to the last child ophinted outputs. 143 CHILDREN_INPUTS_MAPPINGS = "_tflite_children_ophint_inputs_mapping" 144 145 # Types of aggregations 146 # stack: stacks all ophints with matching tags. i.e. for a static rnn. 147 # specifically, this is good for an input or output to a static rnn cell. 148 AGGREGATE_STACK = "stack" 149 # first: only takes the first output (one with lowest sort index) 150 # of matching tags. This is good for the input state to an RNN. 151 AGGREGATE_FIRST = "first" 152 # aggregation last takes only the last tag (one with highest sort index). 153 # This is good for an output value on the last stack item of a 154 # static rnn. 155 AGGREGATE_LAST = "last" 156 157 class OpHintArgumentTracker(object): 158 """Conceptually tracks indices of arguments of "OpHint functions". 159 160 The inputs and arguments of these functions both use an instance 161 of the class so they can have independent numbering. 162 """ 163 164 def __init__(self, 165 function_name, 166 unique_function_id, 167 node_name_prefix, 168 attr_name, 169 level=1, 170 children_inputs_mappings=None): 171 """Initialize ophint argument. 172 173 Args: 174 function_name: Name of the function that this tracks arguments for. 175 unique_function_id: UUID of function that this tracks arguments for. 176 node_name_prefix: How identities that are created are named. 177 attr_name: Name of attribute to use to store the index for this hint. 178 i.e. FUNCTION_INPUT_INDEX or FUNCTION_OUTPUT_INDEX 179 level: Hierarchical level of the Ophint node, a number. 180 children_inputs_mappings: Inputs/Outputs mapping for children hints. 181 """ 182 183 # The global index is the argument index of the op. This is in contrast 184 # to the sort index which is the sequence number of a particular instance 185 # of a given global index. For example, you may have called add hint 186 # twice with the tag "foo". Then the global index will be 0 for both 187 # and the sort index will be 0 for the first added and 1 for the second. 188 self._function_name = function_name 189 self._unique_function_id = unique_function_id 190 self._next_global_index = 0 # The absolute global index 191 self._used_global_indices = set() 192 self._tag_to_global_index = {} # The argument index a given tag maps to 193 self._tag_to_next_sort_index = {} # The current index for each tag 194 self._node_name_prefix = node_name_prefix 195 self._attr_name = attr_name 196 self._level = level 197 self._children_inputs_mappings = children_inputs_mappings 198 199 def _get_new_global_index(self, index_override): 200 """Return the next unused argument index in order or use an override. 201 202 Args: 203 index_override: An index to use instead of the next available or None 204 to use the next available. 205 206 Returns: 207 A valid global_index to use for the next hint argument. 208 209 Raises: 210 ValueError: If the index_override is already used by another hint. 211 """ 212 if index_override is None: 213 global_index = self._next_global_index 214 else: 215 if index_override in self._used_global_indices: 216 raise ValueError("Index %d was already used by another call to add") 217 global_index = index_override 218 # Make next_global_index valid 219 self._used_global_indices.add(global_index) 220 while self._next_global_index in self._used_global_indices: 221 self._next_global_index += 1 222 return global_index 223 224 def add(self, arg, tag=None, name=None, aggregate=None, 225 index_override=None): 226 """Return a wrapped tensor of an input tensor as an argument. 227 228 Args: 229 arg: A TensorFlow tensor that should be considered an argument. 230 tag: String tag to identify arguments that should be packed. 231 name: Name of argument. This is included in the Identity hint op names. 232 aggregate: Strategy to aggregate. 233 Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, 234 and OpHint.AGGREGATE_STACK. 235 Note, aggregate is only valid if tag is specified. 236 index_override: Specify what input/output index should this be in the 237 final stub. i.e. add(arg0, index=1); add(arg1, index=0) will make the 238 final stub be as stub_func(inputs[arg1, arg0], outputs=[]) rather than 239 the default call order based ordering. 240 241 Returns: 242 A tensor representing the wrapped argument. 243 244 Raises: 245 ValueError: When indices are not consistent. 246 """ 247 248 # Find the appropriate index 249 if tag is None: 250 if aggregate is not None: 251 raise ValueError("You must specify `tag` if using aggregate.") 252 global_index = self._get_new_global_index(index_override) 253 sort_index = None 254 else: 255 if aggregate is None: 256 raise ValueError("You must specify `aggregate` if using tag.") 257 if tag not in self._tag_to_global_index: 258 self._tag_to_global_index[tag] = ( 259 self._get_new_global_index(index_override)) 260 self._tag_to_next_sort_index[tag] = 0 261 elif (index_override and 262 index_override != self._tag_to_global_index[tag]): 263 raise ValueError( 264 "Tag %r was called with two indices %r and %r" % 265 (tag, index_override, self._tag_to_global_index[tag])) 266 global_index = self._tag_to_global_index[tag] 267 sort_index = self._tag_to_next_sort_index[tag] 268 self._tag_to_next_sort_index[tag] += 1 269 270 uuid = self._unique_function_id 271 name = "%s-%s-%s-%r-%r-%s" % (self._node_name_prefix, self._function_name, 272 uuid, global_index, sort_index, name) 273 274 identity_op = _array_ops.identity(arg, name=name) 275 276 # pylint: disable=protected-access 277 identity_op.op._set_attr( 278 OpHint.FUNCTION_NAME_ATTR, 279 _attr_value_pb2.AttrValue( 280 s=_compat.as_bytes(self._function_name))) 281 identity_op.op._set_attr( 282 OpHint.FUNCTION_UUID_ATTR, 283 _attr_value_pb2.AttrValue( 284 s=_compat.as_bytes(self._unique_function_id))) 285 identity_op.op._set_attr( 286 self._attr_name, _attr_value_pb2.AttrValue(i=global_index)) 287 identity_op.op._set_attr(OpHint.FUNCTION_LEVEL_ATTR, 288 _attr_value_pb2.AttrValue(i=self._level)) 289 if self._children_inputs_mappings: 290 identity_op.op._set_attr( 291 OpHint.CHILDREN_INPUTS_MAPPINGS, 292 _attr_value_pb2.AttrValue( 293 s=_compat.as_bytes(_json.dumps( 294 self._children_inputs_mappings)))) 295 296 if sort_index is not None: 297 identity_op.op._set_attr( 298 OpHint.FUNCTION_SORT_INDEX_ATTR, 299 _attr_value_pb2.AttrValue(i=sort_index)) 300 if aggregate is not None: 301 identity_op.op._set_attr( 302 OpHint.FUNCTION_AGGREGATE_ATTR, 303 _attr_value_pb2.AttrValue(s=_compat.as_bytes((aggregate)))) 304 # pylint: enable=protected-access 305 return identity_op 306 307 def __init__(self, 308 function_name, 309 level=1, 310 children_inputs_mappings=None, 311 **kwargs): 312 """Create a OpHint. 313 314 Args: 315 function_name: Name of the function (the custom op name in tflite) 316 level: OpHint level. 317 children_inputs_mappings: Children OpHint inputs/outputs mapping. 318 children_inputs_mappings should like below: 319 "parent_first_child_input": 320 [{"parent_input_index": num, "child_input_index": num}, ...] 321 "parent_last_child_output": 322 [{"parent_output_index": num, "child_output_index": num}, ...] 323 "internal_children_input_output": 324 [{"child_input_index": num, "child_output_index": num}, ...] 325 **kwargs: Keyword arguments of any constant attributes for the function. 326 """ 327 self._function_name = function_name 328 self._level = level 329 if self._level == 1: 330 assert children_inputs_mappings is None 331 else: 332 assert isinstance(children_inputs_mappings, dict) 333 self._children_inputs_mappings = children_inputs_mappings 334 if self._children_inputs_mappings is not None: 335 self._validate_children_inputs_mappings(self._children_inputs_mappings) 336 self._unique_function_id = _uuid.uuid1().hex # TODO(aselle): Unique enough? 337 self._attrs_to_store_later = kwargs 338 self._stored_attrs = False 339 self._inputs = OpHint.OpHintArgumentTracker( 340 self._function_name, self._unique_function_id, "InputHint", 341 OpHint.FUNCTION_INPUT_INDEX_ATTR, level, self._children_inputs_mappings) 342 self._outputs = OpHint.OpHintArgumentTracker( 343 self._function_name, self._unique_function_id, "OutputHint", 344 OpHint.FUNCTION_OUTPUT_INDEX_ATTR, level, 345 self._children_inputs_mappings) 346 347 def _validate_children_inputs_mappings(self, children_inputs_mappings): 348 """Validate children inputs mappings is in the right format. 349 350 Args: 351 children_inputs_mappings: the Children ophint inputs/outputs mapping. 352 """ 353 assert isinstance(children_inputs_mappings, dict) 354 assert "parent_first_child_input" in children_inputs_mappings 355 assert "parent_last_child_output" in children_inputs_mappings 356 assert "internal_children_input_output" in children_inputs_mappings 357 358 # validate parent_first_child_input. 359 360 def assert_dictlist_has_keys(dictlist, keys): 361 for dikt in dictlist: 362 assert isinstance(dikt, dict) 363 for key in keys: 364 assert key in dikt 365 366 assert_dictlist_has_keys( 367 children_inputs_mappings["parent_first_child_input"], 368 ["parent_ophint_input_index", "first_child_ophint_input_index"]) 369 assert_dictlist_has_keys( 370 children_inputs_mappings["parent_last_child_output"], 371 ["parent_output_index", "child_output_index"]) 372 assert_dictlist_has_keys( 373 children_inputs_mappings["internal_children_input_output"], 374 ["child_input_index", "child_output_index"]) 375 376 def _setattr(self, dest_op, name, value): 377 tensor_value = _ops.convert_to_tensor(value) 378 # pylint: disable=protected-access 379 dest_op.op._set_attr(name, _attr_value_pb2.AttrValue( 380 tensor=tensor_value.op.node_def.attr["value"].tensor)) 381 # pylint: enable=protected-access 382 383 def add_input(self, *args, **kwargs): 384 """Add a wrapped input argument to the hint. 385 386 Args: 387 *args: The input tensor. 388 **kwargs: 389 "name" label 390 "tag" a tag to group multiple arguments that will be aggregated. I.e. 391 a string like 'cool_input'. Basically multiple inputs can be added 392 to the same hint for parallel operations that will eventually be 393 combined. An example would be static_rnn which creates multiple copies 394 of state or inputs. 395 "aggregate" aggregation strategy that is valid only for tag non None. 396 Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, 397 and OpHint.AGGREGATE_STACK. 398 "index_override" The global index to use. This corresponds to the 399 argument order in the final stub that will be generated. 400 Returns: 401 The wrapped input tensor. 402 """ 403 return self._inputs.add(*args, **kwargs) 404 405 def add_output(self, *args, **kwargs): 406 """Add a wrapped output argument to the hint. 407 408 Args: 409 *args: The output tensor. 410 **kwargs: 411 "name" label 412 "tag" a tag to group multiple arguments that will be aggregated. I.e. 413 a string like 'cool_input'. Basically multiple inputs can be added 414 to the same hint for parallel operations that will eventually be 415 combined. An example would be static_rnn which creates multiple copies 416 of state or inputs. 417 "aggregate" aggregation strategy that is valid only for tag non None. 418 Acceptable values are OpHint.AGGREGATE_FIRST, OpHint.AGGREGATE_LAST, 419 and OpHint.AGGREGATE_STACK. 420 "index_override" The global index to use. This corresponds to the 421 argument order in the final stub that will be generated. 422 Returns: 423 The wrapped output tensor. 424 """ 425 return self._outputs.add(*args, **kwargs) 426 427 def add_inputs(self, *args, **kwargs): 428 """Add a sequence of inputs to the function invocation. 429 430 Args: 431 *args: List of inputs to be converted (should be Tf.Tensor). 432 **kwargs: This allows 'names' which should be a list of names. 433 Returns: 434 Wrapped inputs (identity standins that have additional metadata). These 435 are also are also tf.Tensor's. 436 """ 437 if "names" in kwargs: 438 return [ 439 self._inputs.add(arg, name=name) 440 for arg, name in zip(args, kwargs["names"]) 441 ] 442 else: 443 return [self._inputs.add(arg) for arg in args] 444 445 def add_outputs(self, *args, **kwargs): 446 """Add a sequence of outputs to the function invocation. 447 448 Args: 449 *args: List of outputs to be converted (should be tf.Tensor). 450 **kwargs: See 451 Returns: 452 Wrapped outputs (identity standins that have additional metadata). These 453 are also tf.Tensor's. 454 """ 455 if "names" in kwargs: 456 return [ 457 self._outputs.add(arg, name=name) 458 for arg, name in zip(args, kwargs["names"]) 459 ] 460 else: 461 return [self._outputs.add(arg) for arg in args] 462 463 464class _LiteOperand(object): 465 """Abstract operand for a tflite hint function._dynamic_rnn_loop. 466 467 This is a base class that handles representing arguments to an OpHint. 468 It also is able to serialize operands to the stubbed graph_def. 469 Child classes are responsible for being able to 470 store information about the hint identity operators. They are also responsible 471 for knowing how to serialize to output graphdefs. 472 473 Typically this will be implemented by holding one or more identity nodes 474 that were previously discovered as hints. 475 """ 476 477 def aggregate_and_return_name_for_input(self, out_graphdef): 478 """This adds the node(s) to out_graphdef and returns the input node name. 479 480 Args: 481 out_graphdef: A graphdef that is ready to have this input added. 482 483 Returns: 484 The output that the stub should use as an input for this operand. 485 486 Raises: 487 RuntimeError: if the method is not implemented. 488 """ 489 del out_graphdef 490 raise RuntimeError("Unimplemented abstract method.") 491 492 def aggregate_and_return_name_for_output(self, fused_op_name, output_index, 493 out_graphdef): 494 """Add node(s) to graph representing output operands and returns type. 495 496 Args: 497 fused_op_name: name of the fused op stub name. 498 output_index: Output index that we are currently processing from stub. 499 out_graphdef: The destination graphdef we are currently building up. 500 501 Returns: 502 The datatype of this identity. 503 504 Raises: 505 RuntimeError: if the method is not implemented. 506 """ 507 del fused_op_name, output_index, out_graphdef 508 raise RuntimeError("Unimplemented abstract method.") 509 510 511class _LiteSingleOperand(_LiteOperand): 512 """A simple operand that is non-aggregated (i.e. most hints).""" 513 514 def __init__(self, node): 515 _LiteOperand.__init__(self) 516 self.node = node 517 self.name = _tensor_name_base(node.name) 518 519 def flatten(self): 520 return [self.name] 521 522 def aggregate_and_return_name_for_input(self, out_graphdef): 523 return self.name 524 525 def aggregate_and_return_name_for_output(self, fused_op_name, index, 526 out_graphdef): 527 output_node = _copy.deepcopy(self.node) 528 del output_node.input[:] 529 output_node.input.append(_tensorflow_output_name(fused_op_name, index)) 530 out_graphdef.node.extend([output_node]) 531 return self.node.attr["type"].i 532 533 def __str__(self): 534 return str(self.name) 535 536 537class _LiteAggregateOperand(_LiteOperand): 538 """An operand for a tflite hint function that is aggregated from many. 539 540 For example, an LSTM is a grid of operators that are all related. Inputs 541 going into them may need to be fused, so they should all be tracked as 542 related arguments. 543 """ 544 545 def __init__(self, aggregation): 546 _LiteOperand.__init__(self) 547 self.aggregation = aggregation 548 self.names = {} 549 self.nodes = {} 550 self.flattened = None 551 552 def add(self, sort, node): 553 self.names[sort] = _tensor_name_base(node.name) 554 self.nodes[sort] = node 555 556 def flatten_nodes(self): 557 """Return a list of all the node protos in aggregation sorted order.""" 558 if not self.flattened: 559 self.flattened = [None] * len(self.nodes) 560 for idx, node in _six.iteritems(self.nodes): 561 self.flattened[idx] = node 562 for n in self.nodes: 563 if n is None: 564 raise RuntimeError("Aggregate was missing argument.") 565 if self.aggregation == OpHint.AGGREGATE_FIRST: 566 self.flattened = self.flattened[:1] 567 elif self.aggregation == OpHint.AGGREGATE_LAST: 568 self.flattened = self.flattened[-1:] 569 elif self.aggregation == OpHint.AGGREGATE_STACK: 570 pass 571 else: 572 raise ValueError( 573 "Invalid aggregation type %r specified" % self.aggregation) 574 return self.flattened 575 576 def flatten(self): 577 """Return a list of all node names in aggregation sorted sorter.""" 578 return [_tensor_name_base(x.name) for x in self.flatten_nodes()] 579 580 def aggregate_and_return_name_for_input(self, out_graphdef): 581 """This adds the nodes to out_graphdef and returns an aggregated output. 582 583 In particular, if you have 4 inputs to a hint stub, this will be the 584 node that you can use as an output. I.e. you have 4 timesteps from a 585 static rnn, then a fused UnidriecitonalLSTM will expect 1 input with 586 all 4 time steps. So here we make a pack and return the output name of 587 that pack. 588 589 Args: 590 out_graphdef: A graphdef that is ready to have this input added. 591 592 Returns: 593 The name of a pack that aggregates this node. 594 """ 595 flattened = self.flatten_nodes() 596 if len(flattened) == 1: 597 return _tensor_name_base(flattened[0].name) 598 else: 599 new_node = _node_def_pb2.NodeDef() 600 new_node.op = "Pack" 601 new_node.name = "OpHintStack-%s" % flattened[0].name 602 new_node.attr["N"].i = len(flattened) 603 new_node.attr["T"].type = flattened[0].attr["T"].type 604 for discrete in flattened: 605 new_node.input.append(_tensor_name_base(discrete.name)) 606 out_graphdef.node.extend([new_node]) 607 return new_node.name 608 609 def aggregate_and_return_name_for_output(self, fused_op_name, output_index, 610 out_graphdef): 611 """This adds to `out_graphdef` all the unaggregated outputs. 612 613 I.e. we are outputting from a fused stub, but we need to make it compatible 614 with the unfused original graph so we insert an unpack. Ideally in a later 615 stage the unpack -> pack sequences will be removed. 616 617 Args: 618 fused_op_name: The name of the stub we are in the process of fusing. 619 output_index: The output output_index this object represents. 620 out_graphdef: The graphdef we are in the process of buildings 621 622 Returns: 623 The type of the aggregated output (so we can finish building the stub 624 op). 625 """ 626 flattened = self.flatten_nodes() 627 if len(flattened) == 1: 628 temp_op = _LiteSingleOperand(flattened[0]) 629 return temp_op.aggregate_and_return_name_for_output( 630 fused_op_name, output_index, out_graphdef) 631 else: 632 stack_node = _node_def_pb2.NodeDef() 633 stack_node.op = "Unpack" 634 stack_node.name = "OpHintUnstack-%s" % flattened[0].name 635 stack_node.attr["num"].i = len(flattened) 636 output_type = flattened[0].attr["T"].type 637 stack_node.attr["T"].type = output_type 638 stack_node.input.append(_tensorflow_output_name( 639 fused_op_name, output_index)) 640 out_graphdef.node.extend([stack_node]) 641 642 for idx, discrete in enumerate(flattened): 643 output_node = _copy.deepcopy(discrete) 644 del output_node.input[:] 645 output_node.input.append(_tensorflow_output_name(stack_node.name, idx)) 646 out_graphdef.node.extend([output_node]) 647 648 return output_type 649 650 def __str__(self): 651 s = "\t\t\tAGGREGATE %s\n" % self.aggregation 652 for sort, val in self.names.iteritems(): 653 s += "\t\t\t%d: %s\n" % (sort, val) 654 return s 655 656 657class _LiteFuncCall(object): 658 """Represent a TensorFlow Lite custom function. 659 660 This is uses to accumulate found hints in the graphdef into a single 661 conceptual unit. 662 663 Attributes: 664 inputs: inputs to the op (hash from index # to argument) 665 outputs: outputs to the op (hash from index # to argument) 666 function_name: the tflite custom op name to use 667 uuid: a unique call id for this particular call (i.e. 668 multiple function calls would have the same function_name but different 669 uuids. 670 params: A param name to key value for op constant data. I.e. for 671 axis on a reduction, strides on a convolution, etc. 672 level: Level of the OpHint. 673 children_inputs_mappings: If the Ophint has children, children inputs 674 mappings indicate how their inputs & outputs are mapped. 675 """ 676 677 def __init__(self): 678 self.inputs = {} 679 self.outputs = {} 680 self.function_name = None 681 self.uuid = None 682 self.params = {} 683 self.level = -1 684 self.children_inputs_mappings = {} 685 686 def flattened_inputs_and_outputs(self): 687 """Return a list of inputs and outputs in a flattened format. 688 689 Returns: 690 Tuple of (inputs, outputs). where input and output i a list of names. 691 """ 692 def _flatten(input_or_output_dict): 693 flattened_items = [] 694 for item in input_or_output_dict.values(): 695 flattened_items.extend(item.flatten()) 696 return flattened_items 697 698 return _flatten(self.inputs), _flatten(self.outputs) 699 700 def __str__(self): 701 def format_args(items): 702 s = "" 703 for idx, item in items.iteritems(): 704 s += ("\t\t%d:\n" % idx) + str(item) 705 return s 706 707 inputs_str = "\tInputs\n" + format_args(self.inputs) 708 outputs_str = "\tOutputs\n" + format_args(self.outputs) 709 710 return ( 711 "tflite function %s call %s level %d " 712 "\n\tinputs:\n\t\t%s\n\toutputs:\n\t\t%s" % 713 (self.function_name, self.uuid, self.level, inputs_str, outputs_str)) 714 715 716def _find_all_hints_in_nodes(nodes): 717 """Look at the all the input nodes and return a list of LiteFuncCall objs. 718 719 Args: 720 nodes: A TensorFlow graph_def to look for LiteFuncCalls. 721 722 Returns: 723 a list of `LifeFuncCall` objects in the form 724 725 """ 726 func_calls = _collections.defaultdict(_LiteFuncCall) 727 728 for node in nodes: 729 attr = node.attr 730 # This is an op hint if it has a FUNCTION_UUID_ATTR, otherwise skip 731 uuid = attr[OpHint.FUNCTION_UUID_ATTR].s 732 if (OpHint.FUNCTION_UUID_ATTR not in attr 733 or not attr[OpHint.FUNCTION_UUID_ATTR].s): 734 continue 735 736 # Start building function 737 call_def = func_calls[uuid] 738 call_def.uuid = uuid 739 call_def.function_name = attr[OpHint.FUNCTION_NAME_ATTR].s 740 call_def.level = attr[OpHint.FUNCTION_LEVEL_ATTR].i 741 # Get sorting and aggregation information 742 743 sort = (attr[OpHint.FUNCTION_SORT_INDEX_ATTR].i 744 if OpHint.FUNCTION_SORT_INDEX_ATTR in attr else None) 745 if sort == -1: sort = None 746 aggregation = None 747 if OpHint.FUNCTION_AGGREGATE_ATTR in attr: 748 aggregation = _compat.as_text(attr[OpHint.FUNCTION_AGGREGATE_ATTR].s) 749 750 if OpHint.CHILDREN_INPUTS_MAPPINGS in attr: 751 call_def.children_inputs_mappings = _json.loads( 752 _compat.as_text(attr[OpHint.CHILDREN_INPUTS_MAPPINGS].s)) 753 754 # Add the input or output 755 def put_operand(stuff, index, sort, operand, aggregation): 756 """Add a given index into the function structure.""" 757 if sort is None: 758 stuff[index] = _LiteSingleOperand(operand) 759 else: 760 if index not in stuff: 761 stuff[index] = _LiteAggregateOperand(aggregation) 762 stuff[index].add(sort, operand) 763 764 if OpHint.FUNCTION_INPUT_INDEX_ATTR in attr: 765 put_operand(call_def.inputs, attr[OpHint.FUNCTION_INPUT_INDEX_ATTR].i, 766 sort, node, aggregation) 767 if OpHint.FUNCTION_OUTPUT_INDEX_ATTR in attr: 768 put_operand(call_def.outputs, attr[OpHint.FUNCTION_OUTPUT_INDEX_ATTR].i, 769 sort, node, aggregation) 770 771 # Remember attributes 772 for a in attr: 773 if a.startswith("_tflite_attr_"): 774 call_def.params[a.replace("_tflite_attr_,", "")] = attr[a].tensor 775 776 return func_calls 777 778 779def _extract_topology_sequence_mapping(nodes): 780 return dict( 781 (_tensor_name_base(node.name), idx) for idx, node in enumerate(nodes)) 782 783 784def _find_children_hints_in_while_loop(function_def, nodes_mapping): 785 """Find children hints and all nodes inside the while loop. 786 787 Args: 788 function_def: Function def of the while loop. 789 nodes_mapping: While loop input_arg : real node name. 790 791 Returns: 792 Ordered children hints and all re-mapped nodes inside the while loop. 793 """ 794 new_nodes = [] 795 796 # Make nodes inside function def inputs point to the real nodes. 797 for node in function_def.node_def: 798 for i in range(len(node.input)): 799 if node.input[i] in nodes_mapping: 800 node.input[i] = nodes_mapping[node.input[i]] 801 new_nodes.append(_copy.deepcopy(node)) 802 name_to_seq_num = _extract_topology_sequence_mapping(function_def.node_def) 803 children_hints = _find_all_hints_in_nodes(new_nodes) 804 children_hints_q = [] 805 # Ordered by the outputs. 806 for hint in _six.itervalues(children_hints): 807 _, output_names = hint.flattened_inputs_and_outputs() 808 seq = name_to_seq_num[output_names[0]] 809 for output_name in output_names: 810 seq = min(seq, name_to_seq_num[output_name]) 811 children_hints_q.append((seq, hint)) 812 children_hints_q.sort(key=lambda tup: tup[0]) 813 ordered_children_hints = [x[1] for x in children_hints_q] 814 return ordered_children_hints, new_nodes 815 816 817def _find_children_hints(call, graph_def): 818 """Find all children hints. 819 820 For a given OpHint, we find all children hints inside it, we also copy all the 821 nodes inside function defs (if applicable) to the original graph_def, they are 822 returned in a list as well. 823 824 Args: 825 call: Parent OpHint that contains children ophints. 826 graph_def: Original graph def. 827 828 Returns: 829 Ordered children hints inside the parent ophint; new graph def that contains 830 nodes inside function defs (if applicable); nodes inside function defs. 831 """ 832 name_to_input_name, _, _ = _extract_graph_summary(graph_def) 833 input_names, output_names = call.flattened_inputs_and_outputs() 834 835 reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) 836 reachable_by_output = _bfs_for_reachable_nodes(output_names, 837 name_to_input_name) 838 output_nodes_set = set(output_names) 839 children_hints = [] 840 out = _graph_pb2.GraphDef() 841 out.library.CopyFrom(graph_def.library) 842 out.versions.CopyFrom(graph_def.versions) 843 function_def_nodes = set() 844 for node in graph_def.node: 845 out.node.extend([_copy.deepcopy(node)]) 846 n = _tensor_name_base(node.name) 847 if n in reachable_by_output: 848 if n not in reachable_by_input and n not in output_nodes_set: 849 # special handle for while loop function def. 850 if node.op == "While": 851 body_name = node.attr["body"].func.name 852 inputs_outside_loop = node.input 853 for function_def in graph_def.library.function: 854 if function_def.signature.name == body_name: 855 function_inputs = function_def.signature.input_arg 856 assert len(inputs_outside_loop) == len(function_inputs) 857 nodes_mapping = {} 858 for i in range(len(function_inputs)): 859 nodes_mapping[function_inputs[i].name] = inputs_outside_loop[i] 860 # TODO(b/123050804): Consider use grappler. 861 (children_hints_in_loop, 862 new_nodes) = _find_children_hints_in_while_loop( 863 function_def, nodes_mapping) 864 function_def_nodes.update([x.name for x in new_nodes]) 865 children_hints.extend(children_hints_in_loop) 866 out.node.extend(new_nodes) 867 868 return children_hints, out, function_def_nodes 869 870 871def _tensor_name_base(full_tensor_name): 872 """Removes the device assignment code from a tensor. 873 874 e.g. _tensor_name_base("foo:3") => "foo" 875 876 Args: 877 full_tensor_name: A tensor name that is annotated with a device placement 878 (this is what tensor flow introspection gives). 879 Returns: 880 A name without any device assignment. 881 """ 882 if full_tensor_name.startswith("^"): 883 return full_tensor_name[1:] 884 return full_tensor_name.split(":")[0] 885 886 887def _tensorflow_output_name(tensor_name, output_index): 888 return tensor_name if output_index == 0 else "%s:%d" % (tensor_name, 889 output_index) 890 891 892# TODO(aselle): This should be converted to grappler in the future. 893def _check_subgraph_closed(n, reachable_by_input, input_nodes_set, 894 name_to_input_name): 895 """Checks to make sure node only connects to predecessor graph through inputs. 896 897 Args: 898 n: Node to check 899 reachable_by_input: Nodes that are reachable by all inputs of subgraph 900 input_nodes_set: The set of nodes that are "inputs". 901 name_to_input_name: Maps from name to the list of inputs. 902 903 Raises: 904 TypeError: If the given node uses items past inputs directly. 905 """ 906 next_to_visit = [n] 907 visited = set() 908 while next_to_visit: 909 current_node = next_to_visit.pop() 910 visited.add(current_node) 911 if (current_node in reachable_by_input 912 and current_node not in input_nodes_set): 913 raise TypeError( 914 "Node %s uses input %s not in input_nodes." % (n, current_node)) 915 if current_node not in input_nodes_set: 916 next_to_visit += [ 917 input_node for input_node in name_to_input_name[current_node] 918 if input_node not in visited 919 ] 920 921 922# TODO(aselle): This should be converted to grappler in the future. 923def _convert_single_op_hint_to_stub(call, 924 graph_def, 925 function_def_nodes=None, 926 is_last_run=True): 927 """Given a graph_def, converts `call` into a stub and returns a new graph_def. 928 929 Args: 930 call: A single function call to be converted. 931 graph_def: A graph_def to use as input (that has call obviously). 932 function_def_nodes: Nodes inside the function def those are not connected to 933 the graph. 934 is_last_run: Whether it is the last run for a given pass (for OpHint has 935 children). 936 937 Returns: 938 A new transformed graph-def that has call as a stub (single op). 939 940 Note: after this process, the graph_def can no longer be loaded into 941 the tensorflow runtime, so all future manipulations are done in graph_def 942 level. 943 """ 944 if function_def_nodes is None: 945 function_def_nodes = set() 946 name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( 947 graph_def) 948 input_names, output_names = call.flattened_inputs_and_outputs() 949 950 reachable_by_input = _bfs_for_reachable_nodes(input_names, name_to_input_name) 951 reachable_by_output = _bfs_for_reachable_nodes(output_names, 952 name_to_input_name) 953 output_nodes_set = set(output_names) 954 nodes_after_fuse = [] 955 nodes_deleted_by_fuse = set() 956 # Classify each node. We want to keep everything reachable by input, but 957 # we don't know if things that are not reachable by output or input (things 958 # after fusing). 959 for node in graph_def.node: 960 n = _tensor_name_base(node.name) 961 if n in reachable_by_output: 962 if n not in reachable_by_input and n not in output_nodes_set: 963 nodes_deleted_by_fuse.add(n) 964 elif n not in reachable_by_input and n not in function_def_nodes: 965 # n is a node that after all the fusings, so keep it. 966 nodes_after_fuse.append(n) 967 else: 968 # In the last run, n is a node that is randomly in the graph but not 969 # connected to the chain of dependencies, we will delete n, otherwise 970 # we keep them. 971 if not is_last_run: 972 nodes_after_fuse.append(n) 973 974 # Make a new graphdef with all the pre-input and input nodes 975 out = _graph_pb2.GraphDef() 976 reachable_by_input_sorted = sorted( 977 list(reachable_by_input), key=lambda n: name_to_seq_num[n]) 978 for node in reachable_by_input_sorted: 979 out.node.extend([_copy.deepcopy(name_to_node[node])]) 980 981 # Create any stacks to aggregate arguments into to a single input 982 # i.e. for static_rnn's. 983 # TODO(aselle): Check that the inputs are complete i.e. 0 to n-1 984 sorted_input_indices = list(call.inputs.keys()) 985 sorted_input_indices.sort() 986 sorted_output_indices = list(call.outputs.keys()) 987 sorted_output_indices.sort() 988 new_node = _node_def_pb2.NodeDef() 989 # Delegate to each operand to produce the proper new input for this stub node. 990 # In particular, an aggregate input will now be a Pack of some previously 991 # non-fused things. 992 for input_index in sorted_input_indices: 993 inputs = call.inputs[input_index] 994 input_name = inputs.aggregate_and_return_name_for_input(out) 995 new_node.input.append(input_name) 996 new_node.attr[OpHint.TFLITE_INPUT_INDICES].list.i.extend(sorted_input_indices) 997 998 # Create the function 999 new_node.op = call.function_name 1000 new_node.name = call.uuid 1001 out.node.extend([new_node]) 1002 1003 # Now call each output argument to give them a chance to make the proper 1004 # output type and add it to our new_node. 1005 output_dtypes = [] 1006 for output_index in sorted_output_indices: 1007 output = call.outputs[output_index] 1008 output_dtype = ( 1009 output.aggregate_and_return_name_for_output(new_node.name, output_index, 1010 out)) 1011 output_dtypes.append(output_dtype) 1012 new_node.attr["_output_types"].list.type[:] = output_dtypes 1013 # TODO(aselle): what is right here? 1014 new_node.attr["_output_quantized"].b = False 1015 1016 # Add post output nodes that do not depend on the outputs 1017 for n in nodes_after_fuse: 1018 should_keep = True 1019 for input_name in name_to_input_name[n]: 1020 if input_name in nodes_deleted_by_fuse: 1021 should_keep = False 1022 if should_keep: 1023 out.node.extend([_copy.deepcopy(name_to_node[n])]) 1024 1025 # Misc. graph_def data that needs copying. 1026 out.library.CopyFrom(graph_def.library) 1027 out.versions.CopyFrom(graph_def.versions) 1028 1029 return out 1030 1031 1032# TODO(aselle): This should be converted to grappler in the future. 1033def _remove_one_redundant_stack_unstack(in_graph_def): 1034 """Removes a stack->unstack pattern from in_graph_def in a returned graph. 1035 1036 Args: 1037 in_graph_def: Graph def to use as input. 1038 Returns: 1039 Simplified tuple (graph_def, changed_something) where changed_something 1040 is true if anything was done. 1041 """ 1042 name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary( 1043 in_graph_def) 1044 del name_to_seq_num 1045 1046 # TODO(aselle): Make this not hardcoded. 1047 do_generic_pack_unpack = True 1048 1049 out = _graph_pb2.GraphDef() 1050 out.library.CopyFrom(in_graph_def.library) 1051 out.versions.CopyFrom(in_graph_def.versions) 1052 for n in in_graph_def.node: 1053 node_name = _tensor_name_base(n.name) 1054 if not node_name.startswith("OpHintStack") and not n.op.startswith("Pack"): 1055 continue 1056 next_to_visit = [node_name] 1057 visited = set() 1058 1059 unpack_nodes = set() 1060 pack_node = node_name 1061 1062 # Find a pattern of unstack connected to a stack (with identities 1063 # in between. 1064 matches_pattern = True 1065 is_hint_created_stack = False 1066 while next_to_visit: 1067 current_node_name = next_to_visit[0] 1068 visited.add(current_node_name) 1069 del next_to_visit[0] 1070 node = name_to_node[current_node_name] 1071 is_op_hint_stack = node.name.startswith("OpHintStack") 1072 is_op_hint_unstack = node.name.startswith("OpHintUnstack") 1073 if (node.op == "Identity" or is_op_hint_stack 1074 or (do_generic_pack_unpack and node.op == "Pack")): 1075 is_hint_created_stack |= is_op_hint_stack 1076 next_to_visit += [ 1077 input_node for input_node in name_to_input_name[current_node_name] 1078 if input_node not in visited 1079 ] 1080 elif (is_op_hint_unstack 1081 or (do_generic_pack_unpack and node.op == "Unpack")): 1082 unpack_nodes.add(node.name) 1083 is_hint_created_stack &= is_op_hint_unstack 1084 else: 1085 matches_pattern = False 1086 break 1087 visited.add(node.name) 1088 1089 if matches_pattern and len(unpack_nodes) == 1: 1090 pack_node = node_name 1091 1092 # Check to see if anyone depends on the intermediate identity or the 1093 # Unstacked form 1094 no_external_dependency = True 1095 for other_n in in_graph_def.node: 1096 if other_n.name in visited: continue 1097 for input_tensor in name_to_input_name[other_n.name]: 1098 input_op = _tensor_name_base(input_tensor) 1099 if input_op in visited and input_op != pack_node: 1100 no_external_dependency = False 1101 # Proceed with the substitution if the stack/unstack pair was created 1102 # through hints, or that it was not, but nobody is consuming things 1103 # between the stack and unstack. 1104 if is_hint_created_stack or no_external_dependency: 1105 end = unpack_nodes.pop() 1106 end_input = name_to_node[end].input[0] 1107 # All nodes that depend on the final stack need to be redone to use 1108 for other_n in in_graph_def.node: 1109 node_name = _tensor_name_base(other_n.name) 1110 if node_name not in visited: 1111 new_node = _copy.deepcopy(other_n) 1112 new_node.input[:] = [ 1113 (end_input if stripped == pack_node else 1114 non_stripped) for stripped, non_stripped in zip( 1115 name_to_input_name[node_name], new_node.input[:]) 1116 ] 1117 out.node.extend([new_node]) 1118 return out, True 1119 return in_graph_def, False 1120 1121 1122def _remove_redundant_stack_unstack(graph_def): 1123 curr = graph_def 1124 del graph_def 1125 changed_stuff = True 1126 while changed_stuff: 1127 curr, changed_stuff = _remove_one_redundant_stack_unstack(curr) 1128 return curr 1129 1130 1131def _get_correct_mapping(original_index, nodes): 1132 # Special handle for the index is -1 case. 1133 # If it is -1, return the last index. 1134 if original_index == -1: 1135 node_indices = nodes.keys() 1136 node_indices = sorted(node_indices) 1137 return node_indices[-1] 1138 else: 1139 return original_index 1140 return original_index 1141 1142 1143def _convert_op_hints_to_stubs_helper( 1144 graph_def, write_callback=lambda sess, graph_def: None): 1145 """Converts a graph_def to a new graph_def where all op hints are stubbed. 1146 1147 Args: 1148 graph_def: A graph def that we should convert. 1149 write_callback: A function pointer that can be used to write intermediate 1150 steps of graph transformation (optional). 1151 Returns: 1152 A new stubbed graph_def. 1153 """ 1154 hints = _find_all_hints_in_nodes(graph_def.node) 1155 1156 hints_q = [] 1157 for hint in _six.itervalues(hints): 1158 hints_q.append((hint.level, hint.uuid)) 1159 1160 hints_q.sort(key=lambda tup: tup[0]) 1161 for i in range(len(hints_q) - 1, -1, -1): 1162 level, hint_uuid = hints_q[i] 1163 1164 curr_graph_def = graph_def 1165 del graph_def # prevent using graph_def again (common source of error) 1166 for i in range(len(hints_q) - 1, -1, -1): 1167 level, hint_uuid = hints_q[i] 1168 if level >= 2: 1169 children_hints, curr_graph_def, function_def_nodes = _find_children_hints( 1170 hints[hint_uuid], curr_graph_def) 1171 # pylint: disable=superfluous-parens 1172 assert (len(children_hints) > 0) # pylint: disable=g-explicit-length-test 1173 # pylint: enable=superfluous-parens 1174 1175 # Re-wire the children hints inputs/outputs, so latter child's inputs 1176 # connect to previous child node's outputs. 1177 children_inputs_mappings = hints[hint_uuid].children_inputs_mappings 1178 for j in range(len(children_hints)): 1179 child_hint = children_hints[j] 1180 if j == 0: 1181 for mapping in children_inputs_mappings["parent_first_child_input"]: 1182 parent_input_index = _get_correct_mapping( 1183 mapping["parent_ophint_input_index"], hints[hint_uuid].inputs) 1184 child_input_index = _get_correct_mapping( 1185 mapping["first_child_ophint_input_index"], child_hint.inputs) 1186 child_hint.inputs[child_input_index] = hints[hint_uuid].inputs[ 1187 parent_input_index] 1188 else: 1189 for mapping in children_inputs_mappings[ 1190 "internal_children_input_output"]: 1191 input_index = _get_correct_mapping(mapping["child_input_index"], 1192 child_hint.inputs) 1193 output_index = _get_correct_mapping(mapping["child_output_index"], 1194 children_hints[j - 1].outputs) 1195 child_hint.inputs[input_index] = children_hints[ 1196 j - 1].outputs[output_index] 1197 if j == len(children_hints) - 1: 1198 for mapping in children_inputs_mappings["parent_last_child_output"]: 1199 parent_output_index = _get_correct_mapping( 1200 mapping["parent_output_index"], hints[hint_uuid].outputs) 1201 child_output_index = _get_correct_mapping( 1202 mapping["child_output_index"], child_hint.outputs) 1203 child_hint.outputs[child_output_index] = hints[hint_uuid].outputs[ 1204 parent_output_index] 1205 1206 for j in range(len(children_hints)): 1207 child_hint = children_hints[j] 1208 curr_graph_def = _convert_single_op_hint_to_stub( 1209 child_hint, curr_graph_def, function_def_nodes, 1210 j == len(children_hints) - 1) 1211 else: 1212 curr_graph_def = _convert_single_op_hint_to_stub(hints[hint_uuid], 1213 curr_graph_def) 1214 write_callback(curr_graph_def, "initial") 1215 # The stubbing process can create stacks/unstacks in the case of LSTMs 1216 # remove them. 1217 curr_graph_def = _remove_redundant_stack_unstack(curr_graph_def) 1218 return curr_graph_def 1219 1220 1221def find_all_hinted_output_nodes(session=None, graph_def=None): 1222 """Find all Ophints output nodes in the graph. 1223 1224 This is used to get all the output nodes those are ophinted, it is important 1225 for operation like convert_variables_to_constants keep all ophints structure. 1226 Note: only one of session or graph_def should be used, not both. 1227 1228 Args: 1229 session: A TensorFlow session that contains the graph to convert. 1230 graph_def: A graph def that we should convert. 1231 1232 Returns: 1233 A list of OpHints output nodes. 1234 Raises: 1235 ValueError: If both session and graph_def are provided. 1236 """ 1237 if session is not None and graph_def is not None: 1238 raise ValueError("Provide only one of session and graph_def.") 1239 hinted_outputs_nodes = [] 1240 if session is not None: 1241 hints = _find_all_hints_in_nodes(session.graph_def.node) 1242 elif graph_def is not None: 1243 hints = _find_all_hints_in_nodes(graph_def.node) 1244 for hint in _six.itervalues(hints): 1245 _, ouput_nodes = hint.flattened_inputs_and_outputs() 1246 hinted_outputs_nodes.extend(ouput_nodes) 1247 return hinted_outputs_nodes 1248 1249 1250@_tf_export("lite.experimental.convert_op_hints_to_stubs") 1251def convert_op_hints_to_stubs(session=None, 1252 graph_def=None, 1253 write_callback=lambda graph_def, comments: None): 1254 """Converts a graphdef with LiteOp hints into stub operations. 1255 1256 This is used to prepare for toco conversion of complex intrinsic usages. 1257 Note: only one of session or graph_def should be used, not both. 1258 1259 Args: 1260 session: A TensorFlow session that contains the graph to convert. 1261 graph_def: A graph def that we should convert. 1262 write_callback: A function pointer that can be used to write intermediate 1263 steps of graph transformation (optional). 1264 Returns: 1265 A new graphdef with all ops contained in OpHints being replaced by 1266 a single op call with the right parameters. 1267 Raises: 1268 ValueError: If both session and graph_def are provided. 1269 """ 1270 1271 if session is not None and graph_def is not None: 1272 raise ValueError("Provide only one of session and graph_def.") 1273 1274 if session is not None: 1275 return _convert_op_hints_to_stubs_helper(session.graph_def, write_callback) 1276 elif graph_def is not None: 1277 return _convert_op_hints_to_stubs_helper(graph_def, write_callback) 1278 else: 1279 raise ValueError("Must specify session or graph_def as input.") 1280 1281 1282_allowed_symbols = [ 1283 "OpHint", "convert_op_hints_to_stubs", "convert_op_hints_to_stubs_new", 1284 "find_all_hinted_output_nodes" 1285] 1286remove_undocumented(__name__, _allowed_symbols) 1287