# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compiled parallel-for loop."""
# pylint: disable=missing-docstring,g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import string
import sys
import traceback

import numpy as np
import six

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.core.framework import full_type_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_image_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_list_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import handle_data_util
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import manip_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity


# TODO(agarwal): remove flag.
flags.DEFINE_bool(
    "op_conversion_fallback_to_while_loop", True,
    "DEPRECATED: Flag is ignored.")


def _variant_handle_data(t):
  """Fetches handle data for a variant tensor `t`, or None if unavailable."""
  handle_data = resource_variable_ops.get_eager_safe_handle_data(t)
  if not handle_data.is_set:
    return None
  return handle_data.shape_and_type


def _variant_type_id(t):
  """Returns the full_type_pb2 type of `t`, or None if it is not available."""
  if t.dtype != dtypes.variant:
    return None
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can
    # make this an error instead of assuming TensorLists have handle data.
    return None  # Presumed not a TensorList/Optional
  return shapes_and_types[0].type.type_id


_INTERNAL_STACKING_TYPE_IDS = (
    full_type_pb2.TFT_ARRAY,
    full_type_pb2.TFT_OPTIONAL)


def _is_variant_with_internal_stacking(t):
  """Identifies variant tensors which pfor always maintains as scalars.

  For these, the pfor tensor is recorded as "stacked" if the content of the
  variant tensor (e.g. the elements of a TensorList) are all stacked.

  Args:
    t: A tensor to identify.
  Returns:
    True if `t` is a TensorList/Optional, False if not, None if unknown.
  """
  type_id = _variant_type_id(t)
  return type_id in _INTERNAL_STACKING_TYPE_IDS


def _parse_variant_shapes_and_types(t):
  """Extracts shape and dtype information from a variant tensor `t`."""
  shapes_and_types = _variant_handle_data(t)
  if shapes_and_types is None or not shapes_and_types:
    raise ValueError("Required handle data not set for {!r}".format(t))
  if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
    return shapes_and_types
  else:
    if shapes_and_types[0].type.type_id == full_type_pb2.TFT_UNSET:
      return shapes_and_types
    else:
      raise ValueError(
          "Attempted to stack a variant-dtype tensor with no type set ({!r})"
          .format(t))


def _stack(t, length):
  """Stacks `t` `length` times."""
  # Note that this stacking may currently be triggered, for example, when a
  # loop invariant tensor with dtype variant is input to a while_loop which then
  # produces a loop dependent output. Simply stacking the variants may not be
  # suitable since operations on stacked handles may expect a vectorized version
  # of the variant.
  if t.dtype == dtypes.variant:
    shapes_and_types = _parse_variant_shapes_and_types(t)
    if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY:
      if len(shapes_and_types) != 1:
        raise ValueError(
            "Expected handle data of length 1, got {!r} of length {}"
            .format(shapes_and_types, len(shapes_and_types)))
      return wrap(
          _stack_tensor_list(t, shapes_and_types[0].dtype, length),
          True)
    else:
      raise ValueError(
          ("Attempted to stack an unhandled variant-dtype tensor of "
           "type {!r} ({!r})").format(shapes_and_types[0].type, t))
  ones = array_ops.ones_like(array_ops.shape(t))
  ones = array_ops.reshape(ones, [-1])
  length = array_ops.reshape(length, [-1])
  multiples = array_ops.concat([length, ones], 0)
  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
  return wrap(t, True)


# The following stateful ops can be safely called once, and with the same
# signature as the unconverted version, if their inputs are loop invariant.
# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
# plan is to map each read/write in the loop_fn to a corresponding merged
# read/write in the converted graph. Writes need to be mergeable (e.g.
# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
# loop_fn, doing a one-to-one conversion will simulate executing such
# instructions in lock-step across all iterations.
passthrough_stateful_ops = set([
    "VariableV2",
    "VarHandleOp",
    "VariableShape",
    "ReadVariableOp",
    "StackV2",
    "TensorArrayWriteV3",
    "TensorArrayReadV3",
    "TensorArraySizeV3",
])


# Ops which we will treat like stateful for the purpose of vectorization.
# Typically this is used to force pfor converters to run for these ops.
force_stateful_ops = set([
    # We vectorize this since we need to change the element shape set on the
    # list.
    "TensorListReserve",
])


def _is_stateful_pfor_op(op):
  if isinstance(op, WhileOp):
    return op.is_stateful
  if op.type == "Const":
    # Const doesn't have an op_def.
    return False
  if op.type in passthrough_stateful_ops:
    return False
  if op.type in force_stateful_ops:
    return True
  assert hasattr(op, "op_def") and op.op_def is not None, op
  return op.op_def.is_stateful


# pylint: disable=protected-access
class WhileOp(object):
  """Object for storing state for converting the outputs of a while_loop."""

  def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config):
    """Initializer.

    Args:
      exit_node: A tensor output from the while_loop.
      pfor_ops: list of ops inside the current pfor loop.
      fallback_to_while_loop: If True, fall back to a while loop when conversion
        of an op is not supported.
      pfor_config: PForConfig object used while constructing loop body.
    """
    self._fallback_to_while_loop = fallback_to_while_loop
    self._pfor_config = pfor_config
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    assert isinstance(exit_node, ops.Tensor)
    self._while_context = exit_node.op._get_control_flow_context()
    assert isinstance(self._while_context, control_flow_ops.WhileContext)
    self._context_name = self._while_context.name
    self._condition = self._while_context.pivot.op.inputs[0]
    # Parts of an external while_loop could be created inside a pfor loop.
    # However for the purpose here, we declare such loops to be external. Also
    # note that we check if the condition was created inside or outside to
    # determine if the while_loop was first created inside or outside.
    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
    if self._is_inside_loop:
      for e in self._while_context.loop_exits:
        assert self.op_is_inside_loop(e.op)

    # Note the code below tries to reverse engineer an existing while_loop graph
    # by assuming the following pattern of nodes.
    #
    #          NextIteration <---- Body <--- Enter
    #              |                 ^
    #              V              ___| Y
    #    Enter -> Merge -> Switch___
    #                       ^       | N
    #                       |       V
    #                   LoopCond    Exit

    # Note that elements in the lists below correspond one-to-one with each
    # other. i.e. these lists are the same size, and the i_th entry corresponds
    # to different Operations/Tensors of a single cycle as illustrated above.
    # List of Switch ops (ops.Operation) that feed into an Exit Node.
    self._exit_switches = []
    # List of inputs (ops.Tensor) to NextIteration.
    self._body_outputs = []
    # List of list of control inputs of the NextIteration nodes.
    self._next_iter_control_inputs = []
    # List of Merge ops (ops.Operation).
    self._enter_merges = []
    # List of output (ops.Tensor) of Exit nodes.
    self._outputs = []

    # List of Enter Tensors.
    # There are two types of Enter nodes:
    # - The Enter nodes that are used in the `loop_vars` argument to
    # `while_loop` (see
    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
    # these Enter nodes immediately below by tracing backwards from the Exit
    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
    # diagram above. This allows us to have a 1:1 correspondence between the
    # self._outputs and the first elements in self._enters.
    # - The Enter nodes that are used only by the body. They don't appear in the
    # `loop_vars` and are not returned from the `while_loop`. In Python code,
    # they are usually captured by the body lambda. We collect them below by
    # iterating over all the ops in the graph. They are appended to the end of
    # self._enters or self._direct_enters, and don't correspond to any outputs
    # in self._outputs. Note that we keep the resource/variant Enter nodes in
    # self._direct_enters and the constructed while_loop's body uses them
    # directly as opposed to passing them as loop variables. This is done
    # because the while_body cannot partition the resource/variant Tensors, so
    # it has to leave them unchanged.
    self._enters = []
    self._direct_enters = []

    for e in self._while_context.loop_exits:
      self._outputs.append(e.op.outputs[0])
      switch = e.op.inputs[0].op
      assert switch.type == "Switch", switch
      self._exit_switches.append(switch)
      merge = switch.inputs[0].op
      assert merge.type == "Merge", merge
      self._enter_merges.append(merge)
      enter = merge.inputs[0].op
      assert enter.type == "Enter", enter
      self._enters.append(enter.outputs[0])
      next_iter = merge.inputs[1].op
      assert next_iter.type == "NextIteration", next_iter
      self._body_outputs.append(next_iter.inputs[0])
      self._next_iter_control_inputs.append(next_iter.control_inputs)

    # Collect all the Enter nodes that are not part of `loop_vars`, the second
    # category described above.
    # Also track whether the loop body has any stateful ops.
    self._is_stateful = False
    for op in ops.get_default_graph().get_operations():
      # TODO(agarwal): make sure this works with nested case.
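      # Each op is attributed to a while_loop via its control flow context; we
      # only consider ops whose context name matches this loop's context.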
      control_flow_context = op._get_control_flow_context()
      if control_flow_context is None:
        continue
      if control_flow_context.name == self._context_name:
        self._is_stateful |= _is_stateful_pfor_op(op)
        if op.type == "Enter":
          output = op.outputs[0]
          if output not in self._enters:
            if output.dtype in (dtypes.resource, dtypes.variant):
              if output not in self._direct_enters:
                self._direct_enters.append(output)
            else:
              self._enters.append(output)

  def __str__(self):
    """String representation."""
    return "while_loop(%s)" % self.name

  @property
  def inputs(self):
    """Input to all the Enter nodes."""
    return [x.op.inputs[0] for x in self._enters + self._direct_enters]

  @property
  def control_inputs(self):
    """Control input to all the Enter nodes."""
    control_inputs = []
    for x in self._enters + self._direct_enters:
      control_inputs.extend(x.op.control_inputs)
    return control_inputs

  @property
  def outputs(self):
    """Outputs of all the Exit nodes."""
    return self._outputs

  @property
  def name(self):
    """Context name for the while loop."""
    return self._context_name

  @property
  def is_inside_loop(self):
    """Returns true if the while_loop was created inside the pfor."""
    return self._is_inside_loop

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different python
    # objects representing the same Operation node.
    return op._id in self._pfor_op_ids

  @property
  def is_stateful(self):
    return self._is_stateful

  @property
  def pfor_converter(self):
    """Return a converter for the while loop."""
    return self

  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
                 inputs_stacked):
    """Create a PFor object for converting parts of the while_loop.

    Args:
      parent_pfor: PFor object being used for converting the while_loop.
      indices: int32 Tensor of ids for the iterations that are still active
        (i.e. did not exit the while_loop).
      cond_stacked: True if the while_loop condition is stacked.
      inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note
        that these Tensors are a subset of the loop variables for the generated
        while_loop.
      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
        indicating if the value is stacked or not.

    Returns:
      A PFor instance. The instance is initialized by adding conversion mappings
      of nodes that will be external to the conversion that the returned
      instance will be used for. e.g. Enter nodes as well as Merge and Switch
      outputs are mapped to converted values.
    """
    num_outputs = len(self._outputs)
    assert len(inputs) == len(self._enters)
    assert len(inputs_stacked) == len(self._enters)
    loop_var = parent_pfor.loop_var
    loop_len = array_ops.size(indices)
    pfor = PFor(
        loop_var,
        loop_len,
        pfor_ops=self._pfor_ops,
        all_indices=indices,
        all_indices_partitioned=cond_stacked,
        fallback_to_while_loop=self._fallback_to_while_loop,
        pfor_config=self._pfor_config)
    # Map all inputs of Enter nodes in self._direct_enters to their converted
    # values.
    for enter in self._direct_enters:
      enter_input = enter.op.inputs[0]
      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
          enter_input)
      # Since these are resources / variants, they should be unstacked.
      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
      pfor._add_conversion(enter, wrap(converted_enter, False))

    # Map all Enter nodes to the inputs.
    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
      pfor._add_conversion(enter, wrap(inp, stacked))
    # Map outputs of Switch and Merge.
    for i in range(num_outputs):
      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
      merge = self._enter_merges[i]
      pfor._add_conversion(merge.outputs[0], wrapped_inp)
      # Note that second output of Merge is typically not used, except possibly
      # as a control dependency. To avoid trying to output the correct value, we
      # employ a hack here. We output a dummy invalid value with an incorrect
      # dtype. This will allow control dependency to work but if using it as an
      # input, it should typically lead to errors during graph construction due
      # to dtype mismatch.
      # TODO(agarwal): Check in the original graph to see if there are any
      # consumers of this Tensor that use it as an input.
      pfor._add_conversion(merge.outputs[1],
                           wrap(constant_op.constant(-1.0), False))
      switch = self._exit_switches[i]
      # Don't need to worry about switch.output[0] which will feed to Exit node.
      pfor._add_conversion(switch.outputs[1], wrapped_inp)
    return pfor

  def _convert_enter(self, parent_pfor, enter):
    """Converts an Enter node."""
    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
    control_inputs = []
    for x in enter.op.control_inputs:
      converted = parent_pfor._convert_helper(x)
      if not isinstance(converted, ops.Operation):
        converted = converted.t
      control_inputs.append(converted)
    if control_inputs:
      with ops.control_dependencies(control_inputs):
        inp = array_ops.identity(inp)
    return inp, stacked

  def _maybe_stacked(self, cache, inp):
    """Heuristic to figure out if converting `inp` leads to a stacked value.

    Args:
      cache: map from Tensor to boolean indicating stacked/unstacked.
      inp: input Tensor.

    Returns:
      True if `inp` could get stacked. If the function returns False, the
      converted value should be guaranteed to be unstacked. If returning True,
      it may or may not be stacked.
    """
    if inp in cache:
      return cache[inp]
    if not self.op_is_inside_loop(inp.op):
      return False
    op = inp.op
    output = False
    if op.type in [
        "Shape",
        "Rank",
        "ShapeN",
        "ZerosLike",
        "TensorArrayV3",
        "TensorArraySizeV3",
    ]:
      output = False
    elif _is_stateful_pfor_op(op):
      # This may be fairly aggressive.
      output = True
    elif op.type == "Exit":
      # This may be fairly aggressive.
      output = True
    else:
      for t in op.inputs:
        if self._maybe_stacked(cache, t):
          output = True
          break
    cache[inp] = output
    return output

  def _create_init_values(self, pfor_input):
    """Create arguments passed to converted while_loop."""
    with ops.name_scope("while_init"):
      loop_len_vector = pfor_input.pfor.loop_len_vector
      loop_len = loop_len_vector[0]
      num_outputs = len(self._outputs)

      inputs = []
      maybe_stacked_cache = {}
      # Convert all the Enters. Need to do this before checking for stacking
      # below.
      for i, enter in enumerate(self._enters):
        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
        inputs.append(inp)
        maybe_stacked_cache[enter] = stacked
        # Since this enter node is part of the `loop_vars`, it corresponds to an
        # output and its preceding switch. We mark this switch's output with the
        # same stackedness, to act as the base case for the logic below. Below,
        # we will be going through the body figuring out which inputs might need
        # to be stacked and which inputs can safely remain unstacked.
        if i < num_outputs:
          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked

      # Shape invariants for init_values corresponding to self._enters.
      input_shape_invariants = []
      # TensorArrays for outputs of converted while loop
      output_tas = []
      # Shape invariants for output TensorArrays.
      ta_shape_invariants = []
      # List of booleans indicating stackedness of inputs, i.e. tensors
      # corresponding to self._enters.
      inputs_stacked = []
      for i, inp in enumerate(inputs):
        enter = self._enters[i]
        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
        # Note that even when an input is unstacked, the body could make it
        # stacked. We use a heuristic below to figure out if the body may be
        # making it stacked.
        if i < num_outputs:
          body_output = self._body_outputs[i]
          if enter.op in self._pfor_ops:
            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
                                                      body_output)
          else:
            # If constructed outside of pfor loop, then the output would not be
            # stacked.
            body_output_stacked = False
          if body_output_stacked and not inp_stacked:
            inp = _stack(inp, loop_len_vector).t
            inputs[i] = inp
            inp_stacked = True
          # TODO(agarwal): other attributes for the TensorArray ?
          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
          ta_shape_invariants.append(tensor_shape.TensorShape(None))

        inputs_stacked.append(inp_stacked)
        input_shape_invariants.append(tensor_shape.TensorShape(None))

      # See documentation for __call__ for the structure of init_values.
      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
      # TODO(agarwal): try stricter shape invariants
      shape_invariants = (
          [tensor_shape.TensorShape(None),
           tensor_shape.TensorShape(None)] + input_shape_invariants +
          ta_shape_invariants)

      return init_values, inputs_stacked, shape_invariants

  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
    """Handles case when condition is unstacked.

    Note that all iterations end together. So we don't need to partition the
    inputs. When all iterations are done, we write the inputs to the
    TensorArrays. Note that we only write to index 0 of output_tas. Since all
    iterations end together, they can all be output together.
    """
    not_all_done = array_ops.reshape(conditions, [])
    new_output_tas = []
    # pylint: disable=cell-var-from-loop
    for i, out_ta in enumerate(output_tas):
      inp = inputs[i]
      new_output_tas.append(
          control_flow_ops.cond(not_all_done, lambda: out_ta,
                                lambda: out_ta.write(0, inp)))
    # pylint: enable=cell-var-from-loop
    return not_all_done, indices, inputs, new_output_tas

  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
                            output_tas):
    num_outputs = len(self._outputs)
    # Compute if all iterations are done.
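    # `conditions` holds one boolean per active pfor iteration; reduce_any is
    # True as long as at least one iteration still needs to run.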
    not_all_done = math_ops.reduce_any(conditions)
    conditions_int = math_ops.cast(conditions, dtypes.int32)
    # Partition the indices.
    done_indices, new_indices = data_flow_ops.dynamic_partition(
        indices, conditions_int, 2)

    new_inputs = []
    new_output_tas = []
    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
      # Partition the inputs.
      if stacked:
        done_inp, new_inp = data_flow_ops.dynamic_partition(
            inp, conditions_int, 2)
      else:
        # TODO(agarwal): avoid this stacking. See TODO earlier in
        # _process_cond_unstacked.
        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
        new_inp = inp
      new_inputs.append(new_inp)
      # For iterations that are done, write them to TensorArrays.
      if i < num_outputs:
        out_ta = output_tas[i]
        # Note that done_indices can be empty. done_inp should also be empty in
        # that case.
        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
    return not_all_done, new_indices, new_inputs, new_output_tas

  def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
                    new_inputs, not_all_done):
    """Convert the body function."""

    def true_fn(control_inputs, body_pfor, body_output, stacked):
      """Converts the body function for all but last iteration.

      This essentially converts body_output. Additionally, it needs to handle
      any control dependencies on the NextIteration node. So it creates another
      Identity node with the converted dependencies.
      """
      converted_control_inp = []
      for x in control_inputs:
        for t in x.outputs:
          converted_control_inp.append(body_pfor._convert_helper(t).t)
      if stacked:
        # Note convert always does the stacking.
        output = body_pfor.convert(body_output)
      else:
        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
        assert convert_stacked == stacked, body_output
      with ops.control_dependencies(converted_control_inp):
        return array_ops.identity(output)

    body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
                                new_inputs, inputs_stacked)
    new_outputs = []

    for i, (body_output,
            stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
      control_inp = self._next_iter_control_inputs[i]
      out_dtype = body_output.dtype
      # Note that we want to run the body only if not all pfor iterations are
      # done. If all are done, we return empty tensors since these values will
      # not be used. Notice that the value returned by the loop is based on
      # TensorArrays and not directly on these returned values.
      # pylint: disable=cell-var-from-loop
      new_output = control_flow_ops.cond(
          not_all_done,
          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
          lambda: constant_op.constant([], dtype=out_dtype))
      # pylint: enable=cell-var-from-loop
      new_outputs.append(new_output)
    return new_outputs

  def __call__(self, pfor_input):
    """Converter for the while_loop.

    The conversion of a while_loop is another while_loop.

    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the id of the iterations that are not
      done.
    args: Remaining arguments. These can be divided into 3 categories:
      - First set of arguments are the tensors that correspond to the initial
        elements of self._enters.
        These are the elements that appear in the original while loop's
        `loop_vars`.
      - The second set of arguments are the tensors that correspond to the
        remaining elements of self._enters. These are the tensors that directly
        enter the original while loop body.
      - Finally, the last set of arguments are TensorArrays. These TensorArrays
        correspond to the outputs of the original while_loop, i.e. to the
        elements in self._outputs. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the
        i'th iteration of pfor. Note that elements can be written into these
        tensor arrays in any order, depending on when the corresponding pfor
        iteration is done.
    If the original while_loop had `k` tensors in its `loop_vars` and its body
    directly captured `m` tensors, the `args` will contain `2 * k + m` values.

    In each iteration, the while_loop body recomputes the condition for all
    active pfor iterations to see which of them are now done. It then partitions
    all the inputs and passes them along to the converted body. Values for all
    the iterations that are done are written to TensorArrays indexed by the pfor
    iteration number. When all iterations are done, the TensorArrays are stacked
    to get the final value.

    Args:
      pfor_input: A PForInput object corresponding to the output of any Exit
        node from this while loop.

    Returns:
      List of converted outputs.
    """
    # Create init_values that will be passed to the while_loop.
    init_values, inputs_stacked, shape_invariants = self._create_init_values(
        pfor_input)
    # Note that we use a list as a hack since we need the nested function body
    # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
    # variables.
    cond_is_stacked = [None]

    def cond(not_all_done, *_):
      return not_all_done

    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
      num_enters = len(self._enters)
      inputs = args[:num_enters]
      output_tas = args[num_enters:]
      # TODO(agarwal): see which outputs have consumers and only populate the
      # TensorArrays corresponding to those. Or do those paths get trimmed out
      # from inside the while_loop body?
      assert len(inputs) >= len(output_tas)
      assert len(inputs) == len(inputs_stacked)

      # Convert condition
      with ops.name_scope("while_cond"):
        # Note that we set cond_stacked to True here. At this point we don't
        # know if it could be loop invariant, hence the conservative value is
        # to assume stacked.
        cond_pfor = self._init_pfor(
            pfor_input.pfor,
            indices,
            cond_stacked=True,
            inputs=inputs,
            inputs_stacked=inputs_stacked)
        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
        cond_is_stacked[0] = cond_stacked

      # Recompute the new condition, write outputs of done iterations, and
      # partition the inputs if needed.
      if not cond_stacked:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_unstacked(conditions, indices,
                                                        inputs, output_tas)
      else:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_stacked(conditions, indices,
                                                      inputs, inputs_stacked,
                                                      output_tas)

      # Convert body
      with ops.name_scope("while_body"):
        # Compute the outputs from the body.
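        # The converted body is only run for iterations that are still active;
        # its outputs become the first num_outputs loop variables below.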
        new_outputs = self._process_body(pfor_input, inputs_stacked,
                                         new_indices, cond_stacked, new_inputs,
                                         not_all_done)

      # Note that the first num_outputs new values of inputs are computed using
      # the body. Rest of them were direct Enters into the condition/body and
      # the partitioning done earlier is sufficient to give the new value.
      num_outputs = len(self._outputs)
      new_args = ([not_all_done, new_indices] + new_outputs +
                  list(new_inputs[num_outputs:]) + new_output_tas)
      return tuple(new_args)

    while_outputs = control_flow_ops.while_loop(
        cond, body, init_values, shape_invariants=shape_invariants)
    output_tas = while_outputs[-len(self._outputs):]
    outputs = []
    assert cond_is_stacked[0] is not None
    for inp_stacked, ta in zip(inputs_stacked, output_tas):
      if cond_is_stacked[0]:
        outputs.append(wrap(ta.stack(), True))
      else:
        # Note that if while_loop condition is unstacked, all iterations exit at
        # the same time and we wrote those outputs in index 0 of the tensor
        # array.
        outputs.append(wrap(ta.read(0), inp_stacked))
    return outputs


class ConversionNotImplementedError(Exception):
  pass


class _PforInput(object):
  """Input object passed to registered pfor converters."""

  __slots__ = ["pfor", "_op", "_inputs"]

  def __init__(self, pfor, op, inputs):
    """Creates a _PforInput object.

    Args:
      pfor: PFor converter object.
      op: the Operation object that is being converted.
      inputs: list of WrappedTensor objects representing converted values of the
        inputs of `op`.
    """
    self.pfor = pfor
    self._op = op
    self._inputs = inputs

  def stack_inputs(self, stack_indices=None, tile_variants=False):
    """Stacks unstacked inputs at `stack_indices`.

    Args:
      stack_indices: indices of inputs at which stacking is done. If None,
        stacking is done at all indices.
      tile_variants: If True, affected indices which have a variant dtype will
        be tiled after this operation to match the expected shape of a
        vectorized tensor. Variants generally need to be un-tiled when they are
        inputs to operations and tiled when returned.
    """
    if stack_indices is None:
      stack_indices = range(len(self._inputs))
    length = self.pfor.loop_len_vector
    for i in stack_indices:
      inp = self._inputs[i]
      is_variant = inp.t.dtype == dtypes.variant
      if not inp.is_stacked:
        self._inputs[i] = _stack(inp.t, length)
        if tile_variants and is_variant:
          self._inputs[i] = wrap(
              _tile_variant_with_length(self._inputs[i].t, length), True)
      elif not tile_variants and is_variant:
        self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True)

  def expanddim_inputs_for_broadcast(self):
    """Reshapes stacked inputs to prepare them for broadcast.

    Since stacked inputs have an extra leading dimension, automatic broadcasting
    rules could incorrectly try to expand dimensions before that leading
    dimension. To avoid that, we reshape these stacked inputs to the maximum
    rank they will need to be broadcasted to.
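
    For example, a stacked input whose converted tensor has shape (N, 3), when
    broadcast against one of shape (N, 2, 3), is reshaped to (N, 1, 3) so that
    the loop dimension stays leftmost.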
    """
    if not self._inputs:
      return

    # Find max rank
    def _get_rank(x):
      rank = array_ops.rank(x.t)
      if not x.is_stacked:
        rank += 1
      return rank

    ranks = [_get_rank(x) for x in self._inputs]
    max_rank = ranks[0]
    for rank in ranks[1:]:
      max_rank = math_ops.maximum(rank, max_rank)

    for i, inp in enumerate(self._inputs):
      if inp.is_stacked:
        shape = array_ops.shape(inp.t)
        rank_diff = array_ops.reshape(max_rank - ranks[i], [1])
        ones = array_ops.tile([1], rank_diff)
        new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0)
        self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True)

  @property
  def inputs(self):
    return self._inputs

  @property
  def num_inputs(self):
    return len(self._inputs)

  def input(self, index):
    assert len(self._inputs) > index, (index, self._inputs)
    return self._inputs[index]

  def stacked_input(self, index):
    t, is_stacked, _ = self.input(index)
    if not is_stacked:
      op_type = self.op_type
      op_def = getattr(self._op, "op_def", None)
      if op_def is None:
        input_name = "at index %d" % index
      else:
        input_name = "\"%s\"" % op_def.input_arg[index].name
      raise ConversionNotImplementedError(
          "Input %s of op \"%s\" expected to be not loop invariant" %
          (input_name, op_type))
    return t

  def unstacked_input(self, index):
    t, is_stacked, _ = self.input(index)
    if is_stacked:
      op_type = self.op_type
      op_def = getattr(self._op, "op_def", None)
      if op_def is None:
        input_name = "at index %d" % index
      else:
        input_name = "\"%s\"" % op_def.input_arg[index].name
      raise ConversionNotImplementedError(
          "Input %s of op \"%s\" expected to be loop invariant" %
          (input_name, op_type))
    return t

  @property
  def op(self):
    return self._op

  @property
  def op_type(self):
    return self._op.type

  def get_attr(self, attr):
    return self._op.get_attr(attr)

  @property
  def outputs(self):
    return self._op.outputs

  def output(self, index):
    assert index < len(self._op.outputs)
    return self._op.outputs[index]


_pfor_converter_registry = {}


class RegisterPFor(object):
  """Utility to register converters for pfor.

  Usage:
  @RegisterPFor(foo_op_type)
  def _foo_converter(pfor_input):
    ...

  The above will register conversion function `_foo_converter` for handling
  conversion of `foo_op_type`. These converters are called during vectorization
  of a `pfor` loop body. For each operation node in this loop body,
  the vectorization process will call the converter corresponding to the
  operation type of the node.

  During conversion, the registered function will be called with a single
  argument `pfor_input`, of type `PForInput`, which will contain state needed
  for the conversion. When the converter is called for a node, all its inputs
  should already have been converted and these converted values are stored in
  `pfor_input.inputs`. This registered function should output a list of
  WrappedTensor objects with the same length as the number of outputs of the
  node being converted. If the node had zero outputs, then it should return an
  ops.Operation object.
  These new sets of nodes should implement the functionality of running that
  operation for the number of iterations specified by
  `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each
  iteration are picked from `pfor_input.inputs`.

  One tricky aspect of the conversion process is keeping track of, and
  leveraging loop invariance of computation. Each converted input is a
  WrappedTensor which indicates whether the input was loop invariant or not. If
  the converted value is loop invariant, its rank should match the rank of the
  corresponding tensor in the loop body, else its rank is larger by 1. The
  converter should look at the loop invariance of the inputs and generate new
  nodes based on that. Note that the converter will not be called if all inputs
  are loop invariant and the operation is not stateful. The converter should
  determine if its own output is loop invariant and `wrap` its output
  accordingly.

  Example:

  Here, the converter is trying to convert a Reshape node in the loop body. This
  node will have two inputs: the tensor to reshape, and the new shape. The
  example here only handles the case where the shape is loop invariant.

  @RegisterPFor("Reshape")
  def _convert_reshape(pfor_input):
    # We assume that input is not loop invariant. Call to `stacked_input`
    # asserts that and returns the converted value. This value will have a rank
    # larger by 1 compared to the rank of the input in the loop body.
    t = pfor_input.stacked_input(0)

    # We assume that shape input is loop invariant. Call to `unstacked_input`
    # asserts that and returns the converted value.
    shape = pfor_input.unstacked_input(1)

    # We compute `new_shape` by prepending the number of iterations to the
    # original shape.
    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
                                 axis=0)

    # The vectorized output involves reshaping the converted input `t` using
    # `new_shape`.
    new_output = array_ops.reshape(t, new_shape)

    # The converted output is marked as not loop invariant using the call to
    # wrap.
    return wrap(new_output, True)
  """

  def __init__(self, op_type):
    """Creates an object to register a converter for op with type `op_type`."""
    self.op_type = op_type

  def __call__(self, converter):
    name = self.op_type
    assert name not in _pfor_converter_registry, "Re-registering %s " % name
    _pfor_converter_registry[name] = converter
    return converter


class RegisterPForWithArgs(RegisterPFor):
  """Utility to register converters for pfor.

  Usage:
  @RegisterPForWithArgs(foo_op_type, foo=value, ....)
  def _foo_converter(pfor_input, foo=None, ....):
    ...

  See RegisterPFor for details on the conversion function.
  `RegisterPForWithArgs` allows binding extra arguments to the
  conversion function at registration time.
  """

  def __init__(self, op_type, *args, **kw_args):
    super(RegisterPForWithArgs, self).__init__(op_type)
    self._args = args
    self._kw_args = kw_args

  def __call__(self, converter):

    def _f(pfor_input):
      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)

    super(RegisterPForWithArgs, self).__call__(_f)
    return converter


# TODO(agarwal): call raw_ops instead of calling these low level routines.
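# Illustrative sketch of how the registration utilities above are typically
# used. This is not a registration performed here; the converter name and the
# choice of "Floor" are hypothetical:
#
#   @RegisterPForWithArgs("Floor", op_func=math_ops.floor)
#   def _convert_floor(pfor_input, op_type, op_func):
#     del op_type
#     return wrap(op_func(pfor_input.stacked_input(0)), True)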
def _create_op(op_type, inputs, op_dtypes, attrs=None):
  """Utility to create an op."""
  op = ops.get_default_graph().create_op(
      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
  flat_attrs = []
  # The tape expects an alternating flat list of names and attribute values.
  for a in attrs:
    flat_attrs.append(str(a))
    flat_attrs.append(op.get_attr(str(a)))
  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
  return op


WrappedTensor = collections.namedtuple("WrappedTensor",
                                       ["t", "is_stacked", "is_sparse_stacked"])
"""Wrapper around the result of a Tensor conversion.

The additional fields are useful for keeping track of the conversion state as
data flows through the ops in the loop body. For every op whose output is a
Tensor, its converter should return either a WrappedTensor or a list of
WrappedTensors.

Args:
  t: The converted tensor
  is_stacked: True if the tensor is stacked, i.e. represents the results of all
    the iterations of the loop, where each row i of the tensor corresponds to
    that op's output on iteration i of the loop. False if the tensor is not
    stacked, i.e. represents the result of the op for a single iteration of
    the loop, where the result does not vary between iterations.
  is_sparse_stacked: True if the tensor corresponds to a component tensor
    (indices, values, or dense_shape) of a sparse tensor, and has been logically
    stacked via a sparse conversion.
"""


def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
  """Helper to create a WrappedTensor object."""
  assert isinstance(is_stacked, bool)
  assert isinstance(is_sparse_stacked, bool)
  assert isinstance(tensor, ops.Tensor)
  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
                                               "stacked via a sparse "
                                               "conversion, it must also be "
                                               "stacked.")
  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)


def _wrap_and_tile_variants(tensor, length):
  if tensor.dtype == dtypes.variant:
    tensor = _tile_variant_with_length(tensor, length)
  return wrap(tensor)


def _fallback_converter(pfor_input, warn=True):
  if warn:
    logging.warning("Using a while_loop for converting %s", pfor_input.op_type)
  output_dtypes = [x.dtype for x in pfor_input.outputs]
  iters = pfor_input.pfor.loop_len_vector[0]

  def while_body(i, *ta_list):
    """Body of while loop."""
    inputs = [
        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
    ]
    op_outputs = _create_op(
        pfor_input.op_type,
        inputs,
        output_dtypes,
        attrs=pfor_input.op.node_def.attr).outputs

    outputs = []
    # TODO(agarwal): Add tf.debugging asserts to check that the shapes across
    # the different iterations are the same.
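    # Each per-iteration output is written with a leading unit dimension so that
    # concatenating the TensorArray entries later yields the stacked result.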
    for out, ta in zip(op_outputs, ta_list):
      assert isinstance(out, ops.Tensor)
      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
    return tuple([i + 1] + outputs)

  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters, while_body, [0] +
      [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
      ])[1:]
  return tuple([wrap(ta.concat(), True) for ta in ta_list])


class PForConfig(object):
  """A configuration object used to communicate with loop body function."""

  def __init__(self):
    # This may be set to the number of iterations.
    self._maybe_iters = None
    # Map from reduction node, created by `reduce`, to the bundle of reduction
    # function and arguments.
    self._reduce_map = {}

  def _has_reductions(self):
    """True if some reductions were performed by the loop body."""
    return len(self._reduce_map)

  def _set_iters(self, iters):
    """Set number of pfor iterations."""
    if isinstance(iters, ops.Tensor):
      iters = tensor_util.constant_value(iters)
    self._maybe_iters = iters

  def reduce(self, fn, *args):
    """Performs reduction `fn` on `args` vectorized across pfor iterations.

    Note that `fn` is traced once inside the loop function context. Hence any
    captures or side-effects will happen in that context. Call to the traced
    version of `fn` happens during the construction of the vectorized code.

    Note that this currently may not work inside a control flow construct.
    Args:
      fn: a reduction function. It will be called with arguments that have the
        same structure as *args but with individual values whose rank may be
        higher by 1 since they represent loop invariant vectorized versions of
        the corresponding Tensors in *args.
      *args: unvectorized Tensors.

    Returns:
      The result of running `fn` on the vectorized versions of `*args`. These
      outputs will be available as loop invariant values to all the iterations.
    """
    assert not context.executing_eagerly()
    # Creates a concrete function that will be used for reduction.
    tensor_specs = []
    for arg in args:
      if not isinstance(arg, ops.Tensor):
        raise ValueError("Got a non-Tensor argument %s in reduce" % arg)
      batched_shape = tensor_shape.TensorShape([self._maybe_iters
                                               ]).concatenate(arg.shape)
      tensor_specs.append(
          tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
    concrete_function = def_function.function(fn).get_concrete_function(
        *tensor_specs)

    # Creates PlaceholderWithDefault and IdentityN nodes corresponding to the
    # reduction.
    pl_outputs = []
    with ops.control_dependencies(args):
      for output in concrete_function.outputs:
        if not isinstance(output, ops.Tensor):
          raise ValueError("Got a non-Tensor output %s while running reduce" %
                           output)
        # Note that we use placeholder_with_default just to make XLA happy since
        # it does not like placeholder ops.
        if output.shape.is_fully_defined():
          dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
          pl_outputs.append(
              array_ops.placeholder_with_default(dummy, shape=output.shape))
        else:
          # TODO(agarwal): support case when under XLA and output.shape is not
          # fully defined.
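          # With a partially unknown shape we fall back to a plain placeholder,
          # which XLA does not accept; see the TODO above.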
          pl_outputs.append(
              array_ops.placeholder(output.dtype, shape=output.shape))

      reduction_op = array_ops.identity_n(pl_outputs)[0].op
    self._reduce_map[reduction_op] = (concrete_function, args)
    if len(reduction_op.outputs) == 1:
      return reduction_op.outputs[0]
    else:
      return tuple(reduction_op.outputs)

  # TODO(agarwal): handle reductions inside control flow constructs.
  def reduce_concat(self, x):
    """Performs a concat reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has rank one higher than `x`. The value is the vectorized
      version of `x`, i.e. stacking the value of `x` across different pfor
      iterations.
    """
    return self.reduce(lambda y: y, x)

  def reduce_mean(self, x):
    """Performs a mean reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the mean of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)

  def reduce_sum(self, x):
    """Performs a sum reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the sum of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)

  def _lookup_reduction(self, t):
    """Looks up Tensor `t` in the reduction map."""
    assert isinstance(t, ops.Tensor), t
    return self._reduce_map.get(t.op)


class PFor(object):
  """Implementation of rewrite of parallel-for loops.

  This class takes a DAG or a set of DAGs representing the body of a
  parallel-for loop, and adds new operations to the graph that implement
  functionality equivalent to running that loop body for a specified number of
  iterations. This new set of nodes may or may not use a TensorFlow loop
  construct.

  The process of conversion does not delete or change any existing operations.
  It only adds operations that efficiently implement the equivalent
  functionality. We refer to the added ops as "converted ops".

  The conversion process uses a simple greedy heuristic. It walks the loop body
  and tries to express the functionality of running each node in a loop with a
  new set of nodes. When converting an op several cases are possible:
  - The op is not inside the loop body. Hence it can be used as is.
  - The op does not depend on the iteration number and is stateless. In this
    case, it can be used as is.
  - The op is not stateful, and depends on iteration number only through control
    dependencies. In this case, we can create a single op with same inputs and
    attributes, but with "converted" control dependencies.
  - The op is not stateful, and all its inputs are loop invariant. In this
    case, similar to above, we can create a single op with same inputs and
    attributes, but with "converted" control dependencies.
  - The op is stateful or at least one of the inputs is not loop invariant.
    In this case, we run the registered converter for that op to create a set of
    converted ops. All nodes in the set will have converted control dependencies
    corresponding to control dependencies of the original op. If the op returned
    multiple outputs, "converted outputs" could be produced by different ops in
    this set.
  """

  def __init__(self,
               loop_var,
               loop_len,
               pfor_ops,
               fallback_to_while_loop,
               all_indices=None,
               all_indices_partitioned=False,
               pfor_config=None):
    """Creates an object to rewrite a parallel-for loop.

    Args:
      loop_var: ops.Tensor output of a Placeholder operation. The value should
        be an int32 scalar representing the loop iteration number.
      loop_len: A scalar or scalar Tensor representing the number of iterations
        the loop is run for.
      pfor_ops: List of all ops inside the loop body.
      fallback_to_while_loop: If True, on failure to vectorize an op, a while
        loop is used to sequentially execute that op.
      all_indices: If not None, an int32 vector with size `loop_len`
        representing the iteration ids that are still active. These values
        should be unique and sorted. However they may not be contiguous. This is
        typically the case when inside a control flow construct which has
        partitioned the indices of the iterations that are being converted.
      all_indices_partitioned: If True, this object is being constructed from a
        control flow construct where not all the pfor iterations are guaranteed
        to be active.
      pfor_config: PForConfig object used while constructing the loop body.
    """
    assert isinstance(loop_var, ops.Tensor)
    assert loop_var.op.type == "PlaceholderWithDefault"
    self._loop_var = loop_var
    loop_len_value = tensor_util.constant_value(loop_len)
    if loop_len_value is not None:
      loop_len = loop_len_value
    self._loop_len_vector = array_ops.reshape(loop_len, [1])
    self._all_indices_partitioned = all_indices_partitioned
    if all_indices_partitioned:
      assert all_indices is not None
    self.all_indices = (
        math_ops.range(loop_len) if all_indices is None else all_indices)

    self._conversion_map = object_identity.ObjectIdentityDictionary()
    self._conversion_map[loop_var] = wrap(self.all_indices, True)
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    self._fallback_to_while_loop = fallback_to_while_loop
    self._pfor_config = pfor_config

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different python
    # objects representing the same Operation node.
    return op._id in self._pfor_op_ids

  def _convert_sparse(self, y):
    """Returns the converted value corresponding to SparseTensor y.

    For SparseTensors, instead of stacking the component tensors separately,
    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
    rank) respectively for indices, values, and dense_shape (where N is the loop
    length and m is the number of sparse tensor values per loop iter), we want
    to logically stack the SparseTensors, to create a SparseTensor whose
    components are size (N * m, rank + 1), (N * m, ), and (rank + 1,)
    respectively.
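    For example, with N = 2 iterations, m = 3 values per iteration, and rank 2,
    the logically stacked components have shapes (6, 3), (6,), and (3,).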

    Here, we try to get the conversion of each component tensor.
    If the tensors are stacked via a sparse conversion, return the resulting
    SparseTensor composed of the converted components. Otherwise, the component
    tensors are either unstacked or stacked naively. In the latter case, we
    unstack the component tensors to reform loop_len SparseTensor elements,
    then correctly batch them.

    The unstacked tensors must have the same rank. Each dimension of each
    SparseTensor will expand to be the largest among all SparseTensor elements
    for that dimension. For example, if there are N SparseTensors of rank 3
    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i),
    the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).

    Args:
      y: A tf.sparse.SparseTensor.

    Returns:
      A tf.sparse.SparseTensor that is the converted value corresponding to y.
    """
    outputs = [
        self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape)
    ]
    assert all(isinstance(o, WrappedTensor) for o in outputs)

    if all(w.is_sparse_stacked for w in outputs):
      return sparse_tensor.SparseTensor(*[w.t for w in outputs])

    assert not any(w.is_sparse_stacked for w in outputs), (
        "Error converting SparseTensor. All components should be logically "
        "stacked, or none.")

    # If component tensors were not sparsely stacked, they are either unstacked
    # or stacked without knowledge that they are components of sparse tensors.
    # In this case, we have to restack them.
    return self._restack_sparse_tensor_logically(
        *[self._unwrap_or_tile(w) for w in outputs])

  def _restack_sparse_tensor_logically(self, indices, values, shape):
    sparse_tensor_rank = indices.get_shape().dims[-1].value
    if sparse_tensor_rank is not None:
      sparse_tensor_rank += 1

    def fn(args):
      res = gen_sparse_ops.serialize_sparse(
          args[0], args[1], args[2], out_type=dtypes.variant)
      return res

    # Applies a map function to the component tensors to serialize each
    # sparse tensor element and batch them all, then deserializes the batch.
    # TODO(rachelim): Try to do this without map_fn -- add the right offsets
    # to shape and indices tensors instead.
    result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant)
    return sparse_ops.deserialize_sparse(
        result, dtype=values.dtype, rank=sparse_tensor_rank)

  def _unwrap_or_tile(self, wrapped_tensor):
    """Given a wrapped tensor, unwrap if stacked; otherwise, tile it."""
    output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked
    if is_stacked:
      return output
    else:
      return _stack(output, self._loop_len_vector).t

  def convert(self, y):
    """Returns the converted value corresponding to y.

    Args:
      y: A ops.Tensor or a ops.Operation object. If the latter, y should not
        have any outputs.

    Returns:
      If y does not need to be converted, it returns y as is. Else it returns
      the "converted value" corresponding to y.
1415 """ 1416 if y is None: 1417 return None 1418 if isinstance(y, sparse_tensor.SparseTensor): 1419 return self._convert_sparse(y) 1420 assert isinstance(y, (ops.Tensor, ops.Operation)), y 1421 output = self._convert_helper(y) 1422 if isinstance(output, WrappedTensor): 1423 assert isinstance(y, ops.Tensor) 1424 return self._unwrap_or_tile(output) 1425 else: 1426 assert isinstance(y, ops.Operation) 1427 assert not y.outputs 1428 assert isinstance(output, ops.Operation) 1429 return output 1430 1431 def _was_converted(self, t): 1432 """True if t is not a conversion of itself.""" 1433 converted_t = self._conversion_map[t] 1434 return converted_t.t is not t 1435 1436 def _add_conversion(self, old_output, new_output): 1437 assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output 1438 assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output 1439 self._conversion_map[old_output] = new_output 1440 1441 def _convert_reduction(self, y): 1442 # Handle reductions. 1443 if self._pfor_config is None or isinstance(y, ops.Operation): 1444 return None 1445 reduction = self._pfor_config._lookup_reduction(y) 1446 if reduction is None: 1447 return None 1448 (reduction_fn, reduction_args) = reduction 1449 batched_args = [] 1450 for reduction_arg in reduction_args: 1451 assert isinstance(reduction_arg, ops.Tensor), reduction_arg 1452 # Tensor being reduced should already be converted due to a control 1453 # dependency on the created placeholder. 1454 # Note that in cases where reduction_arg is in an outer context, one 1455 # needs to locate the corresponding Enter node and use that to lookup 1456 # the conversion. 1457 # TODO(agarwal): handle reductions inside control flow constructs. 1458 assert reduction_arg in self._conversion_map, ( 1459 "Unable to handle reduction of %s, possibly as it was used " 1460 "inside a control flow construct. Note that reductions across " 1461 "pfor iterations are currently not supported inside control flow " 1462 "constructs." % reduction_arg) 1463 batched_arg = self._conversion_map[reduction_arg] 1464 batched_args.append(self._unwrap_or_tile(batched_arg)) 1465 outputs = reduction_fn(*batched_args) 1466 return [wrap(output, False) for output in nest.flatten(outputs)] 1467 1468 def _convert_helper(self, op_or_tensor): 1469 stack = collections.deque([op_or_tensor]) 1470 while stack: 1471 y = stack[0] 1472 if y in self._conversion_map: 1473 assert isinstance(self._conversion_map[y], 1474 (WrappedTensor, ops.Operation)) 1475 stack.popleft() 1476 continue 1477 if isinstance(y, ops.Operation): 1478 assert not y.outputs, ( 1479 "We only support converting Operation objects with no outputs. " 1480 "Got %s", y) 1481 y_op = y 1482 else: 1483 assert isinstance(y, ops.Tensor), y 1484 y_op = y.op 1485 1486 is_while_loop = y_op.type == "Exit" 1487 if is_while_loop: 1488 while_op = WhileOp( 1489 y, pfor_ops=self._pfor_ops, 1490 fallback_to_while_loop=self.fallback_to_while_loop, 1491 pfor_config=self._pfor_config) 1492 is_inside_loop = while_op.is_inside_loop 1493 # If all nodes in the while_loop graph were created inside the pfor, we 1494 # treat the whole loop subgraph as a single op (y_op) and try to convert 1495 # it. For while_loops that are created completely or partially outside, 1496 # we treat them as external and should be able to simply return the Exit 1497 # node output as is without needing any conversion. Note that for 1498 # while_loops that are partially constructed inside, we assume they will 1499 # be loop invariant. 
If that is not the case, it will create runtime 1500 # errors since the converted graph would depend on the self._loop_var 1501 # placeholder. 1502 if is_inside_loop: 1503 y_op = while_op 1504 else: 1505 is_inside_loop = self.op_is_inside_loop(y_op) 1506 1507 # If this op was not created inside the loop body, we will return as is. 1508 # 1. Convert inputs and control inputs. 1509 1510 def _add_to_stack(x): 1511 if x not in self._conversion_map: 1512 stack.appendleft(x) 1513 return True 1514 else: 1515 return False 1516 1517 if is_inside_loop: 1518 added_to_stack = False 1519 for inp in y_op.inputs: 1520 added_to_stack |= _add_to_stack(inp) 1521 for cinp in y_op.control_inputs: 1522 if cinp.outputs: 1523 for t in cinp.outputs: 1524 added_to_stack |= _add_to_stack(t) 1525 else: 1526 added_to_stack |= _add_to_stack(cinp) 1527 if added_to_stack: 1528 continue 1529 1530 converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] 1531 some_input_converted = any(self._was_converted(x) for x in y_op.inputs) 1532 some_input_stacked = any(x.is_stacked for x in converted_inputs) 1533 1534 converted_control_ops = set() 1535 some_control_input_converted = False 1536 for cinp in y_op.control_inputs: 1537 if cinp.outputs: 1538 for t in cinp.outputs: 1539 converted_t = self._conversion_map[t] 1540 if self._was_converted(t): 1541 some_control_input_converted = True 1542 converted_control_ops.add(converted_t.t.op) 1543 else: 1544 converted_cinp = self._conversion_map[cinp] 1545 assert isinstance(converted_cinp, ops.Operation) 1546 if converted_cinp != cinp: 1547 some_control_input_converted = True 1548 converted_control_ops.add(converted_cinp) 1549 converted_control_ops = list(converted_control_ops) 1550 is_stateful = _is_stateful_pfor_op(y_op) 1551 else: 1552 converted_inputs = [] 1553 converted_control_ops = [] 1554 logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, 1555 converted_inputs, converted_control_ops) 1556 1557 # 2. Convert y_op 1558 # If converting a while_loop, we let the while_loop convertor deal with 1559 # putting the control dependencies appropriately. 1560 control_dependencies = [] if is_while_loop else converted_control_ops 1561 with ops.control_dependencies(control_dependencies), ops.name_scope( 1562 y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op): 1563 # Op is a placeholder for a reduction. 1564 reduce_output = self._convert_reduction(y) 1565 if reduce_output is not None: 1566 new_outputs = reduce_output 1567 # None of the inputs and control inputs were converted. 1568 elif ((not is_inside_loop or 1569 (not is_stateful and not some_input_converted and 1570 not some_control_input_converted)) and 1571 y.graph == ops.get_default_graph()): 1572 if y is y_op: 1573 assert not isinstance(y_op, WhileOp) 1574 new_outputs = y_op 1575 else: 1576 new_outputs = [wrap(x, False) for x in y_op.outputs] 1577 elif not (is_stateful or is_while_loop or some_input_stacked): 1578 # All inputs are unstacked or unconverted but some control inputs are 1579 # converted. 1580 # TODO(rachelim): Handle the case where some inputs are sparsely 1581 # stacked (i.e. 
any(x.is_sparse_stacked for x in converted_inputs)) 1582 new_op = _create_op(y_op.type, [x.t for x in converted_inputs], 1583 [x.dtype for x in y_op.outputs], 1584 y_op.node_def.attr) 1585 if y is y_op: 1586 new_outputs = new_op 1587 else: 1588 new_outputs = [] 1589 for old_output, new_output in zip(y_op.outputs, new_op.outputs): 1590 handle_data_util.copy_handle_data(old_output, new_output) 1591 new_outputs.append(wrap(new_output, False)) 1592 else: 1593 # Either some inputs are not loop invariant or op is stateful. 1594 if hasattr(y_op, "pfor_converter"): 1595 converter = y_op.pfor_converter 1596 else: 1597 converter = _pfor_converter_registry.get(y_op.type, None) 1598 if converter is None: 1599 has_variant_outputs = any(x.dtype == dtypes.variant for x in 1600 y_op.outputs) 1601 has_vectorized_variant_inputs = any( 1602 _is_variant_with_internal_stacking(x) for x in 1603 y_op.inputs) 1604 if (self._fallback_to_while_loop and not has_variant_outputs 1605 and not has_vectorized_variant_inputs): 1606 converter = _fallback_converter 1607 else: 1608 message = ("No pfor vectorization defined for %s\n" 1609 "%s\n" 1610 "inputs: %s. " % 1611 (y_op.type, y_op, converted_inputs)) 1612 if not self._fallback_to_while_loop: 1613 message += ("Consider enabling the fallback_to_while_loop " 1614 "option to pfor, which may run slower.") 1615 raise ValueError(message) 1616 # TODO(rachelim): Handle the case where some inputs are sparsely 1617 # stacked. We should only call the converter if it supports handling 1618 # those inputs. 1619 pfor_inputs = _PforInput(self, y_op, converted_inputs) 1620 try: 1621 try: 1622 new_outputs = converter(pfor_inputs) 1623 except ConversionNotImplementedError as e: 1624 has_vectorized_variant_inputs = any( 1625 _is_variant_with_internal_stacking(x) for x in 1626 y_op.inputs) 1627 if (self._fallback_to_while_loop 1628 and not has_vectorized_variant_inputs): 1629 new_outputs = _fallback_converter(pfor_inputs) 1630 else: 1631 six.reraise(ValueError, ValueError(str(e)), sys.exc_info()[2]) 1632 except Exception as e: # pylint: disable=broad-except 1633 logging.error( 1634 "Got error while pfor was converting op %s" 1635 "with inputs %s\n, converted inputs %s\n" 1636 "%s\n" 1637 "Here are the pfor conversion stack traces:", y_op, 1638 y_op.inputs[:], pfor_inputs.inputs, str(e)) 1639 original_op = y_op 1640 while isinstance(original_op, ops.Operation): 1641 logging.error( 1642 "%s\ncreated at:\n %s", original_op, 1643 " ".join(traceback.format_list(original_op.traceback))) 1644 original_op = original_op._original_op 1645 six.reraise(e.__class__, e, sys.exc_info()[2]) 1646 1647 if isinstance(new_outputs, WrappedTensor): 1648 new_outputs = [new_outputs] 1649 assert isinstance(new_outputs, 1650 (list, tuple, ops.Operation)), new_outputs 1651 logging.vlog(2, "converted %s %s", y_op, new_outputs) 1652 1653 # Insert into self._conversion_map 1654 if y is y_op: 1655 assert isinstance(new_outputs, ops.Operation) 1656 self._add_conversion(y_op, new_outputs) 1657 else: 1658 assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs, 1659 new_outputs) 1660 for old_output, new_output in zip(y_op.outputs, new_outputs): 1661 assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) 1662 assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op) 1663 # Set shape for converted output. 
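          # For example, if the loop has a statically known length of 8 and
          # the original output had shape [3, 4], the stacked converted output
          # is annotated with shape [8, 3, 4]. Variant tensors with internal
          # stacking (e.g. TensorLists) keep a scalar shape, since their
          # stacking happens inside the variant.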
1664 output_shape = old_output.shape 1665 if not new_output.is_sparse_stacked: 1666 if new_output.is_stacked: 1667 loop_len = tensor_util.constant_value(self.loop_len_vector) 1668 if loop_len is None: 1669 batch_dim = tensor_shape.TensorShape([None]) 1670 else: 1671 batch_dim = tensor_shape.TensorShape(loop_len) 1672 output_shape = batch_dim.concatenate(output_shape) 1673 if _is_variant_with_internal_stacking(new_output.t): 1674 new_output.t.set_shape([]) 1675 else: 1676 new_output.t.set_shape(output_shape) 1677 self._add_conversion(old_output, new_output) 1678 stack.popleft() 1679 1680 return self._conversion_map[op_or_tensor] 1681 1682 @property 1683 def loop_len_vector(self): 1684 """Returns a single element vector whose value is number of iterations.""" 1685 return self._loop_len_vector 1686 1687 @property 1688 def loop_var(self): 1689 """Returns placeholder loop variable.""" 1690 return self._loop_var 1691 1692 @property 1693 def pfor_ops(self): 1694 return self._pfor_ops 1695 1696 @property 1697 def pfor_config(self): 1698 return self._pfor_config 1699 1700 @property 1701 def all_indices_partitioned(self): 1702 """all_indices_partitioned property. 1703 1704 Returns: 1705 True if we are inside a control flow construct and not all pfor iterations 1706 may be active. 1707 """ 1708 return self._all_indices_partitioned 1709 1710 @property 1711 def fallback_to_while_loop(self): 1712 return self._fallback_to_while_loop 1713 1714 1715# The code below defines converters for different operations. Please see comment 1716# for RegisterPFor to see how converters should be defined. 1717 1718 1719# image_ops 1720 1721 1722@RegisterPFor("AdjustContrastv2") 1723def _convert_adjust_contrastv2(pfor_input): 1724 images = pfor_input.stacked_input(0) 1725 contrast_factor = pfor_input.unstacked_input(1) 1726 return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True) 1727 1728 1729@RegisterPFor("AdjustHue") 1730def _convert_adjust_hue(pfor_input): 1731 images = pfor_input.stacked_input(0) 1732 delta = pfor_input.unstacked_input(1) 1733 return wrap(gen_image_ops.adjust_hue(images, delta), True) 1734 1735 1736@RegisterPFor("AdjustSaturation") 1737def _convert_adjust_saturation(pfor_input): 1738 images = pfor_input.stacked_input(0) 1739 scale = pfor_input.unstacked_input(1) 1740 return wrap(gen_image_ops.adjust_saturation(images, scale), True) 1741 1742 1743# nn_ops 1744 1745 1746def _flatten_first_two_dims(x): 1747 """Merges first two dimensions.""" 1748 old_shape = array_ops.shape(x) 1749 new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) 1750 return array_ops.reshape(x, new_shape) 1751 1752 1753def _unflatten_first_dim(x, first_dim): 1754 """Splits first dimension into [first_dim, -1].""" 1755 old_shape = array_ops.shape(x) 1756 new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) 1757 return array_ops.reshape(x, new_shape) 1758 1759 1760def _inputs_with_flattening(pfor_input, input_indices): 1761 """Stacks and flattens first dim of inputs at indices `input_indices`.""" 1762 if input_indices is None: 1763 input_indices = [] 1764 pfor_input.stack_inputs(stack_indices=input_indices) 1765 inputs = [] 1766 for i in range(pfor_input.num_inputs): 1767 if i in input_indices: 1768 inp = pfor_input.stacked_input(i) 1769 inp = _flatten_first_two_dims(inp) 1770 else: 1771 inp = pfor_input.unstacked_input(i) 1772 inputs.append(inp) 1773 return inputs 1774 1775 1776@RegisterPForWithArgs("Conv2D", dims=[0]) 1777@RegisterPForWithArgs("DepthToSpace", dims=[0]) 
1778@RegisterPForWithArgs("AvgPool", dims=[0]) 1779@RegisterPForWithArgs("AvgPool3D", dims=[0]) 1780@RegisterPForWithArgs("MaxPool", dims=[0]) 1781@RegisterPForWithArgs("MaxPoolV2", dims=[0]) 1782@RegisterPForWithArgs("MaxPool3D", dims=[0]) 1783@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2]) 1784@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) 1785@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2]) 1786@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2]) 1787@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2]) 1788@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2]) 1789@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1790@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1791@RegisterPForWithArgs("SpaceToDepth", dims=[0]) 1792def _convert_flatten_batch(pfor_input, op_type, dims): 1793 del op_type 1794 inputs = _inputs_with_flattening(pfor_input, dims) 1795 outputs = _create_op( 1796 pfor_input.op_type, 1797 inputs, [x.dtype for x in pfor_input.outputs], 1798 attrs=pfor_input.op.node_def.attr).outputs 1799 n = pfor_input.pfor.loop_len_vector 1800 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1801 return [wrap(x, True) for x in outputs] 1802 1803 1804_channel_flatten_input_cache = {} 1805 1806 1807@RegisterPFor("BatchToSpaceND") 1808def _convert_batch_to_space_nd(pfor_input): 1809 inp = pfor_input.stacked_input(0) 1810 block_shape = pfor_input.unstacked_input(1) 1811 crops = pfor_input.unstacked_input(2) 1812 1813 inp_shape = array_ops.shape(inp) 1814 n = pfor_input.pfor.loop_len_vector 1815 1816 # Reshape and transpose to move the vectorization axis inside the axes that 1817 # will move to space. 1818 # Reshape to 4D and transpose 1819 block_size = math_ops.reduce_prod(block_shape) 1820 new_shape = [n[0], block_size, inp_shape[1] // block_size, -1] 1821 inp = array_ops.reshape(inp, new_shape) 1822 inp = array_ops.transpose(inp, [1, 0, 2, 3]) 1823 # Reshape back to merge the block, vectorization and batch dimension, and 1824 # restore the other dimensions. 1825 new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0) 1826 inp = array_ops.reshape(inp, new_shape) 1827 # Call batch_to_space and then split the new batch axis. 1828 output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops) 1829 output = _unflatten_first_dim(output, n) 1830 return wrap(output, True) 1831 1832 1833@RegisterPFor("SpaceToBatchND") 1834def _convert_space_to_batch_nd(pfor_input): 1835 inp = pfor_input.stacked_input(0) 1836 block_shape = pfor_input.unstacked_input(1) 1837 paddings = pfor_input.unstacked_input(2) 1838 1839 n = pfor_input.pfor.loop_len_vector 1840 inp_shape = array_ops.shape(inp) 1841 inp = _flatten_first_two_dims(inp) 1842 output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings) 1843 output_shape = array_ops.shape(output) 1844 block_size = math_ops.reduce_prod(block_shape) 1845 new_shape = [block_size, n[0], -1] 1846 output = array_ops.reshape(output, new_shape) 1847 output = array_ops.transpose(output, [1, 0, 2]) 1848 new_shape = array_ops.concat( 1849 [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0) 1850 output = array_ops.reshape(output, new_shape) 1851 return wrap(output, True) 1852 1853 1854def _channel_flatten_input(x, data_format): 1855 """Merge the stack dimension with the channel dimension. 1856 1857 If S is pfor's stacking dimension, then, 1858 - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose 1859 should be cheap. 
1860 - for SNHWC, we transpose to NHWSC. 1861 We then merge the S and C dimension. 1862 1863 Args: 1864 x: ops.Tensor to transform. 1865 data_format: "NCHW" or "NHWC". 1866 1867 Returns: 1868 A 3-element tuple with the transformed value, along with the shape for 1869 reshape and order for transpose required to transform back. 1870 """ 1871 1872 graph = ops.get_default_graph() 1873 cache_key = (graph, x.ref(), data_format) 1874 if cache_key not in _channel_flatten_input_cache: 1875 x_shape = array_ops.shape(x) 1876 if data_format == b"NCHW": 1877 order = [1, 0, 2, 3, 4] 1878 shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) 1879 reverse_order = order 1880 else: 1881 order = [1, 2, 3, 0, 4] 1882 shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) 1883 reverse_order = [3, 0, 1, 2, 4] 1884 # Move S dimension next to C dimension. 1885 x = array_ops.transpose(x, order) 1886 reverse_shape = array_ops.shape(x) 1887 # Reshape to merge the S and C dimension. 1888 x = array_ops.reshape(x, shape) 1889 outputs = x, reverse_order, reverse_shape 1890 _channel_flatten_input_cache[cache_key] = outputs 1891 else: 1892 outputs = _channel_flatten_input_cache[cache_key] 1893 return outputs 1894 1895 1896# Note that with training=True, running FusedBatchNormV3 on individual examples 1897# is very different from running FusedBatchNormV3 on a batch of those examples. 1898# This is because, for the latter case, the operation can be considered as first 1899# computing the mean and variance over all the examples and then using these 1900# to scale all those examples. This creates a data dependency between these 1901# different "iterations" since the inputs to the scaling step depends on the 1902# statistics coming from all these inputs. 1903# As with other kernels, the conversion here effectively runs the kernel 1904# independently for each iteration, and returns outputs by stacking outputs from 1905# each of those iterations. 1906@RegisterPFor("FusedBatchNormV3") 1907def _convert_fused_batch_norm(pfor_input): 1908 is_training = pfor_input.get_attr("is_training") 1909 # When BatchNorm is used with training=False, mean and variance are provided 1910 # externally and used as is by the op. Thus, we can merge the S and N 1911 # dimensions as we do for regular operations. 1912 # When BatchNorm is used with training=True, mean and variance are computed 1913 # for each channel across the batch dimension (first one). If we merge S and N 1914 # dimensions, mean and variances will be computed over a larger set. So, we 1915 # merge the S and C dimensions instead. 1916 if not is_training: 1917 # We return zeros for batch_mean and batch_variance output. Note that CPU 1918 # and GPU seem to have different behavior for those two outputs. CPU outputs 1919 # zero because these values are not used during inference. GPU outputs 1920 # something, probably real means and variances. 1921 inputs = _inputs_with_flattening(pfor_input, [0]) 1922 outputs = _create_op( 1923 pfor_input.op_type, 1924 inputs, [x.dtype for x in pfor_input.outputs], 1925 attrs=pfor_input.op.node_def.attr).outputs 1926 y = outputs[0] 1927 n = pfor_input.pfor.loop_len_vector 1928 y = _unflatten_first_dim(y, n) 1929 mean = pfor_input.unstacked_input(3) 1930 zeros = array_ops.zeros_like(mean) 1931 return [wrap(y, True)] + [wrap(zeros, False)] * 5 1932 1933 pfor_input.stack_inputs() 1934 data_format = pfor_input.get_attr("data_format") 1935 # We merge the first dimension with the "C" dimension, run FusedBatchNormV3, 1936 # and then transpose back. 
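  # For example (NHWC), a stacked input of shape [S, N, H, W, C] is transposed
  # and reshaped by _channel_flatten_input to [N, H, W, S * C], so the kernel
  # computes batch statistics separately for each of the S * C channels, i.e.
  # independently per pfor iteration; the reshape and transpose below then
  # recover the stacked layout [S, N, H, W, C].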
1937 x = pfor_input.stacked_input(0) 1938 x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) 1939 # Note that we stack all the other inputs as well so that they are the same 1940 # size as the new size of the channel dimension. 1941 inputs = [x] + [ 1942 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1943 for i in range(1, pfor_input.num_inputs) 1944 ] 1945 outputs = _create_op( 1946 pfor_input.op_type, 1947 inputs, [x.dtype for x in pfor_input.outputs], 1948 attrs=pfor_input.op.node_def.attr).outputs 1949 y = outputs[0] 1950 y = array_ops.reshape(y, reverse_shape) 1951 y = array_ops.transpose(y, reverse_order) 1952 n = pfor_input.pfor.loop_len_vector 1953 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1954 outputs = [y] + outputs 1955 return [wrap(x, True) for x in outputs] 1956 1957 1958@RegisterPFor("FusedBatchNormGradV3") 1959def _convert_fused_batch_norm_grad(pfor_input): 1960 pfor_input.stack_inputs() 1961 data_format = pfor_input.get_attr("data_format") 1962 y_backprop = pfor_input.stacked_input(0) 1963 y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) 1964 x = pfor_input.stacked_input(1) 1965 x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) 1966 inputs = [y_backprop, x] + [ 1967 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1968 for i in range(2, pfor_input.num_inputs) 1969 ] 1970 outputs = _create_op( 1971 pfor_input.op_type, 1972 inputs, [x.dtype for x in pfor_input.outputs], 1973 attrs=pfor_input.op.node_def.attr).outputs 1974 x_backprop = outputs[0] 1975 x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) 1976 x_backprop = array_ops.transpose(x_backprop, x_reverse_order) 1977 n = pfor_input.pfor.loop_len_vector 1978 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1979 outputs = [x_backprop] + outputs 1980 return [wrap(output, True) for output in outputs] 1981 1982 1983@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) 1984@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) 1985@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0) 1986def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, 1987 shape_dim): 1988 del op_type 1989 inputs = _inputs_with_flattening(pfor_input, flatten_dims) 1990 n = pfor_input.pfor.loop_len_vector 1991 # Adjust the `input_sizes` input. 1992 ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1], 1993 dtype=n.dtype) 1994 inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) 1995 outputs = _create_op( 1996 pfor_input.op_type, 1997 inputs, [x.dtype for x in pfor_input.outputs], 1998 attrs=pfor_input.op.node_def.attr).outputs 1999 outputs = [_unflatten_first_dim(x, n) for x in outputs] 2000 return [wrap(x, True) for x in outputs] 2001 2002 2003@RegisterPFor("Conv2DBackpropFilter") 2004def _convert_conv2d_backprop_filter(pfor_input): 2005 pfor_input.stack_inputs(stack_indices=[2]) 2006 inputs, inputs_stacked, _ = pfor_input.input(0) 2007 filter_sizes = pfor_input.unstacked_input(1) 2008 grads = pfor_input.stacked_input(2) 2009 strides = pfor_input.get_attr("strides") 2010 padding = pfor_input.get_attr("padding") 2011 use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") 2012 data_format = pfor_input.get_attr("data_format") 2013 dilations = pfor_input.get_attr("dilations") 2014 if inputs_stacked: 2015 # TODO(agarwal): Implement this efficiently. 2016 logging.warning("Conv2DBackpropFilter uses a while_loop. 
Fix that!") 2017 2018 def while_body(i, ta): 2019 inp_i = inputs[i, ...] 2020 grad_i = grads[i, ...] 2021 output = nn_ops.conv2d_backprop_filter( 2022 inp_i, 2023 filter_sizes, 2024 grad_i, 2025 strides=strides, 2026 padding=padding, 2027 use_cudnn_on_gpu=use_cudnn_on_gpu, 2028 data_format=data_format, 2029 dilations=dilations) 2030 return i + 1, ta.write(i, array_ops.expand_dims(output, 0)) 2031 2032 n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) 2033 _, ta = control_flow_ops.while_loop( 2034 lambda i, ta: i < n, while_body, 2035 (0, tensor_array_ops.TensorArray(inputs.dtype, n))) 2036 output = ta.concat() 2037 return wrap(output, True) 2038 else: 2039 # We merge the stack dimension with the channel dimension of the gradients 2040 # and pretend we had a larger filter (see change to filter_sizes below). 2041 # Once the filter backprop is computed, we reshape and transpose back 2042 # appropriately. 2043 grads, _, _ = _channel_flatten_input(grads, data_format) 2044 n = pfor_input.pfor.loop_len_vector 2045 old_filter_sizes = filter_sizes 2046 filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) 2047 output = nn_ops.conv2d_backprop_filter( 2048 inputs, 2049 filter_sizes, 2050 grads, 2051 strides=strides, 2052 padding=padding, 2053 use_cudnn_on_gpu=use_cudnn_on_gpu, 2054 data_format=data_format, 2055 dilations=dilations) 2056 new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) 2057 output = array_ops.reshape(output, new_filter_shape) 2058 output = array_ops.transpose(output, [3, 0, 1, 2, 4]) 2059 return wrap(output, True) 2060 2061 2062def _flatten_with_inner_dim(x, dim, x_rank): 2063 """Merges the first dim with the specified dim.""" 2064 shape = array_ops.shape(x) 2065 x = array_ops.transpose(x, 2066 list(range(1, dim)) + [0] + list(range(dim, x_rank))) 2067 2068 if dim < x_rank - 1: 2069 new_shape_pieces = [shape[1:dim], [-1], shape[dim + 1:]] 2070 else: 2071 new_shape_pieces = [shape[1:dim], [-1]] 2072 new_shape = array_ops.concat(new_shape_pieces, axis=0) 2073 return array_ops.reshape(x, new_shape) 2074 2075 2076def _unflatten_with_inner_dim(x, dim, x_rank, stack_size): 2077 """Undoes _flatten_with_inner_dim.""" 2078 shape = array_ops.shape(x) 2079 if dim < x_rank - 1: 2080 new_shape_pieces = [shape[:dim], [stack_size], [-1], shape[dim + 1:]] 2081 else: 2082 new_shape_pieces = [shape[:dim], [stack_size], [-1]] 2083 new_shape = array_ops.concat(new_shape_pieces, axis=0) 2084 x = array_ops.reshape(x, new_shape) 2085 dims_permutation = [dim] + list(range(dim)) + list(range(dim + 1, x_rank + 1)) 2086 return array_ops.transpose(x, dims_permutation) 2087 2088 2089@RegisterPFor("DepthwiseConv2dNative") 2090def _convert_depthwise_conv2d_native(pfor_input): 2091 # Kernel can be vectorized, so folding to batch dimension does not work. We 2092 # instead fold into the channel dimension because it is parallel. 
2093 stack_size = pfor_input.pfor.loop_len_vector[0] 2094 data_format = pfor_input.get_attr("data_format") 2095 c_dim = 1 if data_format == b"NCHW" else 3 2096 t = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) 2097 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) 2098 conv = _create_op( 2099 "DepthwiseConv2dNative", [t, kernel], 2100 [x.dtype for x in pfor_input.outputs], 2101 attrs=pfor_input.op.node_def.attr).outputs[0] 2102 return wrap(_unflatten_with_inner_dim(conv, c_dim, 4, stack_size), True) 2103 2104 2105@RegisterPFor("DepthwiseConv2dNativeBackpropInput") 2106def _convert_depthwise_conv2d_native_backprop_input(pfor_input): 2107 stack_size = pfor_input.pfor.loop_len_vector[0] 2108 input_sizes = pfor_input.unstacked_input(0) 2109 data_format = pfor_input.get_attr("data_format") 2110 c_dim = 1 if data_format == b"NCHW" else 3 2111 input_sizes_mutipliers = [ 2112 constant_op.constant([1] * c_dim, dtype=dtypes.int32), [stack_size] 2113 ] 2114 if c_dim < 3: 2115 input_sizes_mutipliers += [ 2116 constant_op.constant([1] * (3 - c_dim), dtype=dtypes.int32) 2117 ] 2118 input_sizes *= array_ops.concat(input_sizes_mutipliers, axis=0) 2119 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) 2120 out_backprop = _flatten_with_inner_dim( 2121 pfor_input.stacked_input(2), c_dim + 1, 5) 2122 result = _create_op( 2123 "DepthwiseConv2dNativeBackpropInput", [input_sizes, kernel, out_backprop], 2124 [x.dtype for x in pfor_input.outputs], 2125 attrs=pfor_input.op.node_def.attr).outputs[0] 2126 return wrap(_unflatten_with_inner_dim(result, c_dim, 4, stack_size), True) 2127 2128 2129@RegisterPFor("DepthwiseConv2dNativeBackpropFilter") 2130def _convert_depthwise_conv2d_native_backprop_filter(pfor_input): 2131 stack_size = pfor_input.pfor.loop_len_vector[0] 2132 data_format = pfor_input.get_attr("data_format") 2133 c_dim = 1 if data_format == b"NCHW" else 3 2134 inputs = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) 2135 filter_sizes = pfor_input.unstacked_input(1) 2136 filter_sizes_multipliers = [ 2137 constant_op.constant([1, 1], dtype=dtypes.int32), [stack_size], 2138 constant_op.constant([1], dtype=dtypes.int32) 2139 ] 2140 filter_sizes *= array_ops.concat(filter_sizes_multipliers, axis=0) 2141 out_backprop = _flatten_with_inner_dim( 2142 pfor_input.stacked_input(2), c_dim + 1, 5) 2143 result = _create_op( 2144 "DepthwiseConv2dNativeBackpropFilter", 2145 [inputs, filter_sizes, out_backprop], 2146 [x.dtype for x in pfor_input.outputs], 2147 attrs=pfor_input.op.node_def.attr).outputs[0] 2148 return wrap(_unflatten_with_inner_dim(result, 2, 4, stack_size), True) 2149 2150 2151@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax) 2152@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax) 2153def _convert_softmax(pfor_input, op_type, op_func): 2154 del op_type 2155 return wrap(op_func(pfor_input.stacked_input(0)), True) 2156 2157 2158# array_ops 2159 2160 2161@RegisterPForWithArgs("Identity", array_ops.identity) 2162@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) 2163@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag) 2164@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) 2165@RegisterPForWithArgs("_EagerConst", array_ops.identity) 2166def _convert_identity(pfor_input, op_type, op_func): 2167 del op_type 2168 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 2169 2170 2171@RegisterPFor("IdentityN") 2172def _convert_identity_n(pfor_input): 2173 outputs = array_ops.identity_n([x.t 
for x in pfor_input.inputs]) 2174 return [ 2175 wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs) 2176 ] 2177 2178 2179@RegisterPFor("Reshape") 2180def _convert_reshape(pfor_input): 2181 t = pfor_input.stacked_input(0) 2182 shape = pfor_input.unstacked_input(1) 2183 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2184 return wrap(array_ops.reshape(t, new_shape), True) 2185 2186 2187@RegisterPFor("Fill") 2188def _convert_fill(pfor_input): 2189 dims = pfor_input.unstacked_input(0) 2190 value = pfor_input.stacked_input(1) 2191 # Expand the rank of `value` 2192 new_shape = array_ops.concat( 2193 [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)], 2194 axis=0) 2195 value = array_ops.reshape(value, new_shape) 2196 # Compute the new output shape 2197 new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0) 2198 # Broadcast 2199 return wrap(array_ops.broadcast_to(value, new_dims), True) 2200 2201 2202@RegisterPFor("BroadcastTo") 2203def _convert_broadcast_to(pfor_input): 2204 t = pfor_input.stacked_input(0) 2205 shape = pfor_input.unstacked_input(1) 2206 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2207 2208 # Expand dims of stacked t to broadcast against the new shape. 2209 # TODO(davmre): consider factoring out common code with 2210 # `expanddim_inputs_for_broadcast`, which has similar logic but with 2211 # implicit shapes (of input Tensors) rather than explicit shapes. 2212 rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t) 2213 ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1])) 2214 t_shape = array_ops.shape(t) 2215 t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0) 2216 2217 return wrap( 2218 array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape), 2219 True) 2220 2221 2222@RegisterPFor("ExpandDims") 2223def _convert_expanddims(pfor_input): 2224 t = pfor_input.stacked_input(0) 2225 dim = pfor_input.unstacked_input(1) 2226 dim += math_ops.cast(dim >= 0, dim.dtype) 2227 return wrap(array_ops.expand_dims(t, axis=dim), True) 2228 2229 2230@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound) 2231@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound) 2232def _convert_searchsorted(pfor_input, _, op_func): 2233 pfor_input.stack_inputs() 2234 sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0)) 2235 values = _flatten_first_two_dims(pfor_input.stacked_input(1)) 2236 out_type = pfor_input.get_attr("out_type") 2237 output = op_func(sorted_inputs, values, out_type) 2238 return wrap( 2239 _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True) 2240 2241 2242@RegisterPFor("MatrixBandPart") 2243def _convert_matrix_band_part(pfor_input): 2244 t = pfor_input.stacked_input(0) 2245 num_lower = pfor_input.unstacked_input(1) 2246 num_upper = pfor_input.unstacked_input(2) 2247 return wrap( 2248 array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper), 2249 True) 2250 2251 2252@RegisterPFor("MatrixSetDiag") 2253def _convert_matrix_set_diag(pfor_input): 2254 pfor_input.stack_inputs() 2255 t = pfor_input.stacked_input(0) 2256 diag = pfor_input.stacked_input(1) 2257 return wrap(array_ops.matrix_set_diag(t, diag), True) 2258 2259 2260# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3. 2261# The input orders defined in the OpKernel and the actual python API are 2262# different (for compatibility with V1), so we cannot use _convert_identity. 
2263# v2 is not compatible with v3 and is never exposed on the public API. 2264@RegisterPFor("MatrixDiagV2") 2265@RegisterPFor("MatrixDiagV3") 2266def _convert_matrix_diag_v2(pfor_input): 2267 params = { 2268 "diagonal": pfor_input.stacked_input(0), 2269 "k": pfor_input.unstacked_input(1), 2270 "num_rows": pfor_input.unstacked_input(2), 2271 "num_cols": pfor_input.unstacked_input(3), 2272 "padding_value": pfor_input.unstacked_input(4) 2273 } 2274 if pfor_input.op_type == "MatrixDiagV2": 2275 return wrap(array_ops.matrix_diag_v2(**params), True) 2276 params["align"] = pfor_input.get_attr("align") 2277 return wrap(array_ops.matrix_diag(**params), True) 2278 2279 2280@RegisterPFor("Diag") 2281def _convert_diag(pfor_input): 2282 diag = pfor_input.stacked_input(0) 2283 if diag.shape.ndims == 2: 2284 # We can use matrix_diag. 2285 return wrap(array_ops.matrix_diag(diag), True) 2286 else: 2287 # It is not clear if we can do better than a while loop here with existing 2288 # kernels. 2289 return _fallback_converter(pfor_input, warn=False) 2290 2291 2292# See notes for MatrixDiagV2 2293@RegisterPFor("MatrixDiagPartV2") 2294@RegisterPFor("MatrixDiagPartV3") 2295def _convert_matrix_diag_part_v2(pfor_input): 2296 params = { 2297 "input": pfor_input.stacked_input(0), 2298 "k": pfor_input.unstacked_input(1), 2299 "padding_value": pfor_input.unstacked_input(2) 2300 } 2301 if pfor_input.op_type == "MatrixDiagPartV2": 2302 return wrap(array_ops.matrix_diag_part_v2(**params), True) 2303 params["align"] = pfor_input.get_attr("align") 2304 return wrap(array_ops.matrix_diag_part(**params), True) 2305 2306 2307# See notes for MatrixDiagV2 2308@RegisterPFor("MatrixSetDiagV2") 2309@RegisterPFor("MatrixSetDiagV3") 2310def _convert_matrix_set_diag_v2(pfor_input): 2311 pfor_input.stack_inputs([0, 1]) 2312 params = { 2313 "input": pfor_input.stacked_input(0), 2314 "diagonal": pfor_input.stacked_input(1), 2315 "k": pfor_input.unstacked_input(2) 2316 } 2317 if pfor_input.op_type == "MatrixSetDiagV2": 2318 return wrap(array_ops.matrix_set_diag_v2(**params), True) 2319 params["align"] = pfor_input.get_attr("align") 2320 return wrap(array_ops.matrix_set_diag(**params), True) 2321 2322 2323@RegisterPFor("DiagPart") 2324def _convert_diag_part(pfor_input): 2325 inp = pfor_input.stacked_input(0) 2326 if inp.shape.ndims == 3: 2327 # We can use matrix_diag_part. 2328 return wrap(array_ops.matrix_diag_part(inp), True) 2329 else: 2330 # It is not clear if we can do better than a while loop here with existing 2331 # kernels. 
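    # The fallback runs DiagPart once per iteration inside a while_loop and
    # stacks the per-iteration results; warn=False since taking this path is
    # expected here.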
2332 return _fallback_converter(pfor_input, warn=False) 2333 2334 2335@RegisterPFor("OneHot") 2336def _convert_one_hot(pfor_input): 2337 indices = pfor_input.stacked_input(0) 2338 depth = pfor_input.unstacked_input(1) 2339 on_value = pfor_input.unstacked_input(2) 2340 off_value = pfor_input.unstacked_input(3) 2341 axis = pfor_input.get_attr("axis") 2342 if axis >= 0: 2343 axis += 1 2344 return wrap( 2345 array_ops.one_hot(indices, depth, on_value, off_value, axis), True) 2346 2347 2348@RegisterPFor("Slice") 2349def _convert_slice(pfor_input): 2350 t = pfor_input.stacked_input(0) 2351 begin = pfor_input.unstacked_input(1) 2352 size = pfor_input.unstacked_input(2) 2353 begin = array_ops.concat([[0], begin], axis=0) 2354 size = array_ops.concat([[-1], size], axis=0) 2355 return wrap(array_ops.slice(t, begin, size), True) 2356 2357 2358@RegisterPFor("Tile") 2359def _convert_tile(pfor_input): 2360 t = pfor_input.stacked_input(0) 2361 multiples = pfor_input.unstacked_input(1) 2362 multiples = array_ops.concat([[1], multiples], 0) 2363 return wrap(array_ops.tile(t, multiples), True) 2364 2365 2366@RegisterPFor("Pack") 2367def _convert_pack(pfor_input): 2368 pfor_input.stack_inputs() 2369 axis = pfor_input.get_attr("axis") 2370 if axis >= 0: 2371 axis += 1 2372 return wrap( 2373 array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True) 2374 2375 2376@RegisterPFor("Unpack") 2377def _convert_unpack(pfor_input): 2378 value = pfor_input.stacked_input(0) 2379 axis = pfor_input.get_attr("axis") 2380 if axis >= 0: 2381 axis += 1 2382 num = pfor_input.get_attr("num") 2383 return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)] 2384 2385 2386@RegisterPFor("Pad") 2387def _convert_pad(pfor_input): 2388 t = pfor_input.stacked_input(0) 2389 paddings = pfor_input.unstacked_input(1) 2390 paddings = array_ops.concat([[[0, 0]], paddings], 0) 2391 return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) 2392 2393 2394@RegisterPFor("Split") 2395def _convert_split(pfor_input): 2396 split_dim = pfor_input.unstacked_input(0) 2397 t = pfor_input.stacked_input(1) 2398 num_split = pfor_input.get_attr("num_split") 2399 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2400 return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] 2401 2402 2403@RegisterPFor("SplitV") 2404def _convert_split_v(pfor_input): 2405 t = pfor_input.stacked_input(0) 2406 splits = pfor_input.unstacked_input(1) 2407 split_dim = pfor_input.unstacked_input(2) 2408 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2409 return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] 2410 2411 2412@RegisterPFor("Squeeze") 2413def _convert_squeeze(pfor_input): 2414 t = pfor_input.stacked_input(0) 2415 squeeze_dims = pfor_input.get_attr("squeeze_dims") 2416 squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims] 2417 return wrap(array_ops.squeeze(t, axis=squeeze_dims), True) 2418 2419 2420@RegisterPFor("ReverseV2") 2421def _convert_reverse(pfor_input): 2422 value = pfor_input.stacked_input(0) 2423 axis = pfor_input.unstacked_input(1) 2424 new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis) 2425 return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True) 2426 2427 2428@RegisterPForWithArgs("Transpose", gen_array_ops.transpose) 2429@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose) 2430def _convert_transpose(pfor_input, _, op_func): 2431 t = pfor_input.stacked_input(0) 2432 perm = pfor_input.unstacked_input(1) 2433 new_perm = 
array_ops.concat([[0], perm + 1], axis=0) 2434 return wrap(op_func(t, new_perm), True) 2435 2436 2437@RegisterPFor("ZerosLike") 2438def _convert_zeroslike(pfor_input): 2439 t = pfor_input.stacked_input(0) 2440 shape = array_ops.shape(t)[1:] 2441 return wrap(array_ops.zeros(shape, dtype=t.dtype), False) 2442 2443 2444@RegisterPFor("Gather") 2445@RegisterPFor("GatherV2") 2446def _convert_gather(pfor_input): 2447 param, param_stacked, _ = pfor_input.input(0) 2448 indices, indices_stacked, _ = pfor_input.input(1) 2449 batch_dims = pfor_input.get_attr("batch_dims") 2450 2451 op_type = pfor_input.op_type 2452 if op_type == "Gather": 2453 validate_indices = pfor_input.get_attr("validate_indices") 2454 axis = 0 2455 else: 2456 validate_indices = None 2457 # Assume we will never have a Tensor with rank > 2**32. 2458 axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32) 2459 axis_value = tensor_util.constant_value(axis) 2460 if axis_value is not None: 2461 axis = axis_value 2462 if indices_stacked and not param_stacked: 2463 if indices is pfor_input.pfor.all_indices and axis == 0: 2464 param_shape0 = tensor_shape.dimension_value(param.shape[0]) 2465 indices_shape0 = tensor_shape.dimension_value(indices.shape[0]) 2466 if param_shape0 is not None and indices_shape0 == param_shape0: 2467 # Note that with loops and conditionals, indices may not be contiguous. 2468 # However they will be sorted and unique. So if the shape matches, then 2469 # it must be picking up all the rows of param. 2470 return wrap(param, True) 2471 2472 if batch_dims != 0: 2473 # Convert `batch_dims` to its positive equivalent if necessary. 2474 batch_dims_pos = batch_dims 2475 if batch_dims < 0: 2476 batch_dims_pos += array_ops.rank(indices) 2477 # In order to maintain 2478 # indices.shape[:batch_dims] == params.shape[:batch_dims] 2479 # with stacked indices, we move the first dimension of `indices` to the 2480 # `batch_dims + 1`th position. The (non-batch) index dimensions will be 2481 # inserted into the shape of `output` at the `axis` dimension, which is 2482 # then transposed to the front (below). 2483 order = array_ops.concat([ 2484 math_ops.range(1, batch_dims_pos + 1), 2485 [0], 2486 math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0) 2487 indices = array_ops.transpose(indices, order) 2488 2489 output = array_ops.gather( 2490 param, indices, validate_indices=validate_indices, axis=axis, 2491 batch_dims=batch_dims) 2492 if axis != 0: 2493 axis = control_flow_ops.cond(axis < 0, 2494 lambda: axis + array_ops.rank(param), 2495 lambda: axis) 2496 order = array_ops.concat( 2497 [[axis], 2498 math_ops.range(axis), 2499 math_ops.range(axis + 1, array_ops.rank(output))], 2500 axis=0) 2501 output = control_flow_ops.cond( 2502 math_ops.equal(axis, 0), lambda: output, 2503 lambda: array_ops.transpose(output, order)) 2504 return wrap(output, True) 2505 if param_stacked: 2506 pfor_input.stack_inputs(stack_indices=[1]) 2507 indices = pfor_input.stacked_input(1) 2508 2509 output = array_ops.gather( 2510 param, indices, 2511 axis=array_ops.where(axis >= 0, axis + 1, axis), 2512 batch_dims=(batch_dims + 1 if batch_dims >= 0 else batch_dims)) 2513 return wrap(output, True) 2514 2515 2516@RegisterPFor("GatherNd") 2517def _convert_gather_nd(pfor_input): 2518 # TODO(jmenick): Add support for unstacked params. 
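  # With params and indices both stacked, the conversion is a single batched
  # gather_nd with batch_dims=1. For example (hypothetical shapes), params of
  # shape [S, 4, 5] and indices of shape [S, 3, 2] produce a stacked result of
  # shape [S, 3].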
2519 pfor_input.stack_inputs(stack_indices=[1]) 2520 params = pfor_input.stacked_input(0) 2521 indices = pfor_input.stacked_input(1) 2522 stacked_result = array_ops.gather_nd(params, indices, batch_dims=1) 2523 return wrap(stacked_result, True) 2524 2525 2526@RegisterPFor("ConcatV2") 2527def _convert_concatv2(pfor_input): 2528 n = pfor_input.num_inputs 2529 pfor_input.stack_inputs(stack_indices=range(n - 1)) 2530 axis = pfor_input.unstacked_input(n - 1) 2531 axis += math_ops.cast(axis >= 0, axis.dtype) 2532 return wrap( 2533 array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), 2534 True) 2535 2536 2537@RegisterPFor("StridedSlice") 2538def _convert_strided_slice(pfor_input): 2539 inp = pfor_input.stacked_input(0) 2540 begin = pfor_input.unstacked_input(1) 2541 end = pfor_input.unstacked_input(2) 2542 strides = pfor_input.unstacked_input(3) 2543 begin_mask = pfor_input.get_attr("begin_mask") 2544 end_mask = pfor_input.get_attr("end_mask") 2545 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 2546 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2547 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2548 2549 begin = array_ops.concat([[0], begin], axis=0) 2550 end = array_ops.concat([[0], end], axis=0) 2551 strides = array_ops.concat([[1], strides], axis=0) 2552 begin_mask = begin_mask << 1 | 1 2553 end_mask = end_mask << 1 | 1 2554 ellipsis_mask <<= 1 2555 new_axis_mask <<= 1 2556 shrink_axis_mask <<= 1 2557 return wrap( 2558 array_ops.strided_slice( 2559 inp, 2560 begin, 2561 end, 2562 strides, 2563 begin_mask=begin_mask, 2564 end_mask=end_mask, 2565 ellipsis_mask=ellipsis_mask, 2566 new_axis_mask=new_axis_mask, 2567 shrink_axis_mask=shrink_axis_mask), True) 2568 2569 2570@RegisterPFor("StridedSliceGrad") 2571def _convert_strided_slice_grad(pfor_input): 2572 shape = pfor_input.unstacked_input(0) 2573 begin = pfor_input.unstacked_input(1) 2574 end = pfor_input.unstacked_input(2) 2575 strides = pfor_input.unstacked_input(3) 2576 dy = pfor_input.stacked_input(4) 2577 begin_mask = pfor_input.get_attr("begin_mask") 2578 end_mask = pfor_input.get_attr("end_mask") 2579 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 2580 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2581 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2582 2583 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2584 begin = array_ops.concat([[0], begin], axis=0) 2585 end = array_ops.concat([[0], end], axis=0) 2586 strides = array_ops.concat([[1], strides], axis=0) 2587 begin_mask = begin_mask << 1 | 1 2588 end_mask = end_mask << 1 | 1 2589 ellipsis_mask <<= 1 2590 new_axis_mask <<= 1 2591 shrink_axis_mask <<= 1 2592 return wrap( 2593 array_ops.strided_slice_grad( 2594 shape, 2595 begin, 2596 end, 2597 strides, 2598 dy, 2599 begin_mask=begin_mask, 2600 end_mask=end_mask, 2601 ellipsis_mask=ellipsis_mask, 2602 new_axis_mask=new_axis_mask, 2603 shrink_axis_mask=shrink_axis_mask), True) 2604 2605 2606@RegisterPFor("CheckNumerics") 2607def _convert_check_numerics(pfor_input): 2608 t = pfor_input.stacked_input(0) 2609 message = pfor_input.get_attr("message") 2610 return wrap(gen_array_ops.check_numerics(t, message), True) 2611 2612 2613@RegisterPFor("EnsureShape") 2614def _convert_ensure_shape(pfor_input): 2615 t = pfor_input.stacked_input(0) 2616 shape = tensor_shape.TensorShape(pfor_input.get_attr("shape")) 2617 return wrap(gen_array_ops.ensure_shape(t, [None] + shape), True) 2618 2619 2620# manip_ops 2621 2622 2623@RegisterPFor("Roll") 2624def 
_convert_roll(pfor_input): 2625 t = pfor_input.stacked_input(0) 2626 shift = pfor_input.unstacked_input(1) 2627 axis = pfor_input.unstacked_input(2) 2628 return wrap(manip_ops.roll(t, shift, axis + 1), True) 2629 2630 2631# math_ops 2632 2633 2634@RegisterPFor("MatMul") 2635def _convert_matmul(pfor_input): 2636 # TODO(agarwal): Check if tiling is faster than two transposes. 2637 a, a_stacked, _ = pfor_input.input(0) 2638 b, b_stacked, _ = pfor_input.input(1) 2639 tr_a = pfor_input.get_attr("transpose_a") 2640 tr_b = pfor_input.get_attr("transpose_b") 2641 if a_stacked and b_stacked: 2642 output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) 2643 return output 2644 elif a_stacked: 2645 if tr_a: 2646 a = array_ops.transpose(a, [0, 2, 1]) 2647 if a.shape.is_fully_defined(): 2648 x, y, z = a.shape 2649 else: 2650 x, y, z = [ 2651 array_ops.reshape(i, []) 2652 for i in array_ops.split(array_ops.shape(a), 3) 2653 ] 2654 a = array_ops.reshape(a, [x * y, z]) 2655 prod = math_ops.matmul(a, b, transpose_b=tr_b) 2656 return wrap(array_ops.reshape(prod, [x, y, -1]), True) 2657 else: 2658 assert b_stacked 2659 if tr_b: 2660 perm = [2, 0, 1] 2661 b = array_ops.transpose(b, perm) 2662 else: 2663 # As an optimization, if one of the first two dimensions is 1, then we can 2664 # reshape instead of transpose. 2665 # TODO(agarwal): This check can be done inside Transpose kernel. 2666 b_shape = array_ops.shape(b) 2667 min_dim = math_ops.minimum(b_shape[0], b_shape[1]) 2668 perm = array_ops.where( 2669 math_ops.equal(min_dim, 1), [0, 1, 2], [1, 0, 2]) 2670 new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]]) 2671 b = array_ops.transpose(b, perm) 2672 b = array_ops.reshape(b, new_shape) 2673 2674 if b.shape.is_fully_defined(): 2675 x, y, z = b.shape 2676 else: 2677 x, y, z = [ 2678 array_ops.reshape(i, []) 2679 for i in array_ops.split(array_ops.shape(b), 3) 2680 ] 2681 b = array_ops.reshape(b, [x, y * z]) 2682 prod = math_ops.matmul(a, b, transpose_a=tr_a) 2683 prod = array_ops.reshape(prod, [-1, y, z]) 2684 prod = array_ops.transpose(prod, [1, 0, 2]) 2685 return wrap(prod, True) 2686 2687 2688# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window 2689# is met. 2690@RegisterPFor("BatchMatMul") 2691def _convert_batch_mat_mul(pfor_input): 2692 # TODO(agarwal): There may be a more efficient way to do this instead of 2693 # stacking the inputs. 
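  # Both operands are stacked and their loop and batch dimensions merged: e.g.
  # with adj_x and adj_y False, x of shape [S, B, M, K] and y of shape
  # [S, B, K, N] become [S * B, M, K] and [S * B, K, N]; the matmul output
  # [S * B, M, N] is then unflattened back to [S, B, M, N].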
2694 pfor_input.stack_inputs() 2695 x = pfor_input.stacked_input(0) 2696 y = pfor_input.stacked_input(1) 2697 adj_x = pfor_input.get_attr("adj_x") 2698 adj_y = pfor_input.get_attr("adj_y") 2699 2700 x = _flatten_first_two_dims(x) 2701 y = _flatten_first_two_dims(y) 2702 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2703 output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) 2704 return wrap(output, True) 2705 2706 2707@RegisterPFor("BatchMatMulV2") 2708def _convert_batch_mat_mul_v2(pfor_input): 2709 pfor_input.expanddim_inputs_for_broadcast() 2710 x = pfor_input.input(0)[0] 2711 y = pfor_input.input(1)[0] 2712 adj_x = pfor_input.get_attr("adj_x") 2713 adj_y = pfor_input.get_attr("adj_y") 2714 2715 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2716 return wrap(output, True) 2717 2718 2719@RegisterPForWithArgs("Sum", math_ops.reduce_sum) 2720@RegisterPForWithArgs("Prod", math_ops.reduce_prod) 2721@RegisterPForWithArgs("Max", math_ops.reduce_max) 2722@RegisterPForWithArgs("Min", math_ops.reduce_min) 2723@RegisterPForWithArgs("Mean", math_ops.reduce_mean) 2724@RegisterPForWithArgs("All", math_ops.reduce_all) 2725@RegisterPForWithArgs("Any", math_ops.reduce_any) 2726def _convert_reduction(pfor_input, _, op_func): 2727 t = pfor_input.stacked_input(0) 2728 indices = pfor_input.unstacked_input(1) 2729 # Shift positive indices by one to account for the extra dimension. 2730 indices += math_ops.cast(indices >= 0, indices.dtype) 2731 keep_dims = pfor_input.get_attr("keep_dims") 2732 return wrap(op_func(t, indices, keepdims=keep_dims), True) 2733 2734 2735@RegisterPForWithArgs("ArgMax", math_ops.argmax) 2736@RegisterPForWithArgs("ArgMin", math_ops.argmin) 2737def _convert_argmax_argmin(pfor_input, _, op_func): 2738 t = pfor_input.stacked_input(0) 2739 dimension = pfor_input.unstacked_input(1) 2740 dimension += math_ops.cast(dimension >= 0, dimension.dtype) 2741 output_type = pfor_input.get_attr("output_type") 2742 return wrap(op_func(t, axis=dimension, output_type=output_type), True) 2743 2744 2745@RegisterPFor("Bucketize") 2746def _convert_bucketize(pfor_input): 2747 t = pfor_input.stacked_input(0) 2748 boundaries = pfor_input.get_attr("boundaries") 2749 return wrap(math_ops.bucketize(t, boundaries), True) 2750 2751 2752@RegisterPFor("ClipByValue") 2753def _convert_clip_by_value(pfor_input): 2754 t = pfor_input.stacked_input(0) 2755 clip_value_min = pfor_input.unstacked_input(1) 2756 clip_value_max = pfor_input.unstacked_input(2) 2757 return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max), 2758 True) 2759 2760 2761@RegisterPForWithArgs("Cumsum", math_ops.cumsum) 2762@RegisterPForWithArgs("Cumprod", math_ops.cumprod) 2763def _convert_cumfoo(pfor_input, _, op_func): 2764 t = pfor_input.stacked_input(0) 2765 axis = pfor_input.unstacked_input(1) 2766 # Shift positive indices by one to account for the extra dimension. 2767 axis += math_ops.cast(axis >= 0, axis.dtype) 2768 exclusive = pfor_input.get_attr("exclusive") 2769 reverse = pfor_input.get_attr("reverse") 2770 return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) 2771 2772 2773@RegisterPFor("BiasAdd") 2774def _convert_biasadd(pfor_input): 2775 t, t_stacked, _ = pfor_input.input(0) 2776 bias, bias_stacked, _ = pfor_input.input(1) 2777 data_format = pfor_input.get_attr("data_format").decode() 2778 if bias_stacked: 2779 # BiasAdd only supports 1-D biases, so cast bias to match value and use Add. 
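    # The stacked bias is added to the stacked value via broadcasting. For
    # NCHW, the reshape below moves the trailing channel dimension of the
    # broadcast-expanded bias in front of the two spatial dimensions so that
    # it lines up with the value's channel axis.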
2780 pfor_input.expanddim_inputs_for_broadcast() 2781 t, _, _ = pfor_input.input(0) 2782 bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype) 2783 if compat.as_bytes(data_format) == b"NCHW": 2784 b_shape = array_ops.shape(bias) 2785 new_b_shape = array_ops.concat( 2786 [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0) 2787 bias = array_ops.reshape(bias, new_b_shape) 2788 return wrap(math_ops.add(t, bias), True) 2789 else: 2790 assert t_stacked, "At least one input to BiasAdd should be loop variant." 2791 if compat.as_bytes(data_format) == b"NCHW": 2792 shape = array_ops.shape(t) 2793 flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) 2794 t = array_ops.reshape(t, flattened_shape) 2795 t = nn_ops.bias_add(t, bias, data_format="NCHW") 2796 t = array_ops.reshape(t, shape) 2797 return wrap(t, True) 2798 return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) 2799 2800 2801@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum) 2802@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max) 2803@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min) 2804@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod) 2805def _convert_unsortedsegmentsum(pfor_input, _, op_func): 2806 pfor_input.stack_inputs([0, 1]) 2807 data = pfor_input.stacked_input(0) 2808 segment_ids = pfor_input.stacked_input(1) 2809 # TODO(agarwal): handle stacked? 2810 num_segments = pfor_input.unstacked_input(2) 2811 if segment_ids.dtype != num_segments.dtype: 2812 segment_ids = math_ops.cast(segment_ids, dtypes.int64) 2813 num_segments = math_ops.cast(num_segments, dtypes.int64) 2814 dtype = segment_ids.dtype 2815 segment_shape = array_ops.shape(segment_ids, out_type=dtype) 2816 n = segment_shape[0] 2817 ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:] 2818 segment_offset = num_segments * math_ops.range(n, dtype=dtype) 2819 segment_offset = array_ops.reshape(segment_offset, 2820 array_ops.concat([[n], ones], axis=0)) 2821 segment_ids += segment_offset 2822 num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast( 2823 n, dtypes.int64) 2824 output = op_func(data, segment_ids, num_segments) 2825 new_output_shape = array_ops.concat( 2826 [[n, -1], array_ops.shape(output)[1:]], axis=0) 2827 output = array_ops.reshape(output, new_output_shape) 2828 return wrap(output, True) 2829 2830 2831def _flatten_array_with_offset(ids, offset_delta, num_rows): 2832 """Flattens a rank 2 tensor, adding an offset to each row.""" 2833 # Note that if `ids` is rank 1, it is broadcast to rank 2. 
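  # For example, ids=[[0, 1], [0, 2]] with offset_delta=3 and num_rows=2 gives
  # offsets [[0], [3]] and a flattened result of [0, 1, 3, 5].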
2834 offset_delta = math_ops.cast(offset_delta, ids.dtype) 2835 n = math_ops.cast(num_rows, dtype=ids.dtype) 2836 offsets = math_ops.range( 2837 start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype) 2838 offsets = array_ops.expand_dims(offsets, -1) 2839 ids += offsets 2840 return array_ops.reshape(ids, [-1]) 2841 2842 2843@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2) 2844@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2) 2845@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2) 2846@RegisterPForWithArgs("SparseSegmentSumWithNumSegments", 2847 math_ops.sparse_segment_sum_v2) 2848@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments", 2849 math_ops.sparse_segment_mean_v2) 2850@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments", 2851 math_ops.sparse_segment_sqrt_n_v2) 2852def _convert_sparse_segment(pfor_input, _, op_func): 2853 _, segment_ids_stacked, _ = pfor_input.input(2) 2854 if segment_ids_stacked: 2855 pfor_input.stack_inputs([1]) 2856 data, data_stacked, _ = pfor_input.input(0) 2857 indices, _, _ = pfor_input.input(1) 2858 num_inputs = len(pfor_input.inputs) 2859 assert num_inputs in (3, 4) 2860 if num_inputs == 3: 2861 # `segment_ids` needs to be unstacked since otherwise output sizes could 2862 # differ across pfor iterations. 2863 segment_ids = pfor_input.unstacked_input(2) 2864 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2865 else: 2866 segment_ids, _, _ = pfor_input.input(2) 2867 num_segments = pfor_input.unstacked_input(3) 2868 2869 n = pfor_input.pfor.loop_len_vector[0] 2870 if data_stacked: 2871 indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n) 2872 data = _flatten_first_two_dims(data) 2873 else: 2874 indices = array_ops.reshape(indices, [-1]) 2875 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2876 2877 if num_inputs == 3: 2878 num_segments = None 2879 else: 2880 num_segments *= n 2881 output = op_func(data, indices, segment_ids, num_segments=num_segments) 2882 output = _unflatten_first_dim(output, [n]) 2883 return wrap(output, True) 2884 2885 2886@RegisterPForWithArgs("SparseSegmentSumGrad", math_ops.sparse_segment_sum_grad) 2887@RegisterPForWithArgs("SparseSegmentMeanGrad", 2888 math_ops.sparse_segment_mean_grad) 2889@RegisterPForWithArgs("SparseSegmentSqrtNGrad", 2890 math_ops.sparse_segment_sqrt_n_grad) 2891def _convert_sparse_segment_grad(pfor_input, _, op_func): 2892 grad = pfor_input.stacked_input(0) 2893 indices = pfor_input.unstacked_input(1) 2894 segment_ids = pfor_input.unstacked_input(2) 2895 dim0 = pfor_input.unstacked_input(3) 2896 2897 n = pfor_input.pfor.loop_len_vector[0] 2898 indices = _flatten_array_with_offset(indices, dim0, n) 2899 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2900 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2901 grad = _flatten_first_two_dims(grad) 2902 dim0 *= n 2903 output = op_func(grad, indices, segment_ids, dim0) 2904 output = _unflatten_first_dim(output, [n]) 2905 return wrap(output, True) 2906 2907 2908@RegisterPFor("Cast") 2909def _convert_cast(pfor_input): 2910 inp = pfor_input.stacked_input(0) 2911 dtype = pfor_input.get_attr("DstT") 2912 return wrap(math_ops.cast(inp, dtype), True) 2913 2914 2915@RegisterPFor("Abs") 2916@RegisterPFor("Acos") 2917@RegisterPFor("Acosh") 2918@RegisterPFor("Add") 2919@RegisterPFor("AddV2") 2920@RegisterPFor("Angle") 2921@RegisterPFor("Asin") 2922@RegisterPFor("Asinh") 
2923@RegisterPFor("Atan") 2924@RegisterPFor("Atan2") 2925@RegisterPFor("Atanh") 2926@RegisterPFor("BesselI0") 2927@RegisterPFor("BesselI1") 2928@RegisterPFor("BesselI0e") 2929@RegisterPFor("BesselI1e") 2930@RegisterPFor("BesselK0") 2931@RegisterPFor("BesselK1") 2932@RegisterPFor("BesselK0e") 2933@RegisterPFor("BesselK1e") 2934@RegisterPFor("BesselJ0") 2935@RegisterPFor("BesselJ1") 2936@RegisterPFor("BesselY0") 2937@RegisterPFor("BesselY1") 2938@RegisterPFor("BitwiseAnd") 2939@RegisterPFor("BitwiseOr") 2940@RegisterPFor("BitwiseXor") 2941@RegisterPFor("Ceil") 2942@RegisterPFor("Complex") 2943@RegisterPFor("ComplexAbs") 2944@RegisterPFor("Conj") 2945@RegisterPFor("Cos") 2946@RegisterPFor("Cosh") 2947@RegisterPFor("Dawsn") 2948@RegisterPFor("Digamma") 2949@RegisterPFor("Div") 2950@RegisterPFor("DivNoNan") 2951@RegisterPFor("Elu") 2952@RegisterPFor("Erf") 2953@RegisterPFor("Erfc") 2954@RegisterPFor("Erfinv") 2955@RegisterPFor("Exp") 2956@RegisterPFor("Expint") 2957@RegisterPFor("Expm1") 2958@RegisterPFor("Floor") 2959@RegisterPFor("FloorDiv") 2960@RegisterPFor("FloorMod") 2961@RegisterPFor("FresnelCos") 2962@RegisterPFor("FresnelSin") 2963@RegisterPFor("Greater") 2964@RegisterPFor("GreaterEqual") 2965@RegisterPFor("Igamma") 2966@RegisterPFor("IgammaGradA") 2967@RegisterPFor("Igammac") 2968@RegisterPFor("Imag") 2969@RegisterPFor("Inv") 2970@RegisterPFor("Invert") 2971@RegisterPFor("IsFinite") 2972@RegisterPFor("IsInf") 2973@RegisterPFor("IsNan") 2974@RegisterPFor("LeftShift") 2975@RegisterPFor("Less") 2976@RegisterPFor("LessEqual") 2977@RegisterPFor("Lgamma") 2978@RegisterPFor("Log") 2979@RegisterPFor("Log1p") 2980@RegisterPFor("LogicalAnd") 2981@RegisterPFor("LogicalNot") 2982@RegisterPFor("LogicalOr") 2983@RegisterPFor("LogicalXor") 2984@RegisterPFor("Maximum") 2985@RegisterPFor("Minimum") 2986@RegisterPFor("Mod") 2987@RegisterPFor("Mul") 2988@RegisterPFor("MulNoNan") 2989@RegisterPFor("Ndtri") 2990@RegisterPFor("Neg") 2991@RegisterPFor("Polygamma") 2992@RegisterPFor("Pow") 2993@RegisterPFor("Real") 2994@RegisterPFor("RealDiv") 2995@RegisterPFor("Reciprocal") 2996@RegisterPFor("Relu") 2997@RegisterPFor("Relu6") 2998@RegisterPFor("RightShift") 2999@RegisterPFor("Rint") 3000@RegisterPFor("Round") 3001@RegisterPFor("Rsqrt") 3002@RegisterPFor("Selu") 3003@RegisterPFor("Sigmoid") 3004@RegisterPFor("Sign") 3005@RegisterPFor("Sin") 3006@RegisterPFor("Sinh") 3007@RegisterPFor("Softplus") 3008@RegisterPFor("Softsign") 3009@RegisterPFor("Spence") 3010@RegisterPFor("Sqrt") 3011@RegisterPFor("Square") 3012@RegisterPFor("SquaredDifference") 3013@RegisterPFor("Sub") 3014@RegisterPFor("Tan") 3015@RegisterPFor("Tanh") 3016@RegisterPFor("TruncateDiv") 3017@RegisterPFor("TruncateMod") 3018@RegisterPFor("Xdivy") 3019@RegisterPFor("Xlogy") 3020@RegisterPFor("Xlog1py") 3021@RegisterPFor("Zeta") 3022def _convert_cwise(pfor_input): 3023 if pfor_input.num_inputs > 1: 3024 pfor_input.expanddim_inputs_for_broadcast() 3025 3026 out = _create_op( 3027 pfor_input.op_type, [x.t for x in pfor_input.inputs], 3028 [x.dtype for x in pfor_input.outputs], 3029 attrs=pfor_input.op.node_def.attr).outputs 3030 assert len(out) == 1 3031 out = out[0] 3032 3033 op_output = wrap(out, True) 3034 return op_output 3035 3036 3037@RegisterPFor("XlaSharding") 3038def _convert_xla_sharding(pfor_input): 3039 t = pfor_input.stacked_input(0) 3040 sharding = pfor_input.get_attr("sharding") 3041 return wrap(xla.sharding(t, sharding=sharding), True) 3042 3043 3044@RegisterPFor("LeakyRelu") 3045def _convert_leaky_relu(pfor_input): 3046 t = 
pfor_input.stacked_input(0) 3047 alpha = pfor_input.get_attr("alpha") 3048 return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True) 3049 3050 3051@RegisterPFor("Equal") 3052def _convert_equal(pfor_input): 3053 pfor_input.expanddim_inputs_for_broadcast() 3054 x = pfor_input.input(0)[0] 3055 y = pfor_input.input(1)[0] 3056 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 3057 return wrap(gen_math_ops.equal( 3058 x, y, incompatible_shape_error=incompatible_shape_error), True) 3059 3060 3061@RegisterPFor("NotEqual") 3062def _convert_not_equal(pfor_input): 3063 pfor_input.expanddim_inputs_for_broadcast() 3064 x = pfor_input.input(0)[0] 3065 y = pfor_input.input(1)[0] 3066 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 3067 return wrap(gen_math_ops.not_equal( 3068 x, y, incompatible_shape_error=incompatible_shape_error), True) 3069 3070 3071@RegisterPFor("ApproximateEqual") 3072def _convert_approximate_equal(pfor_input): 3073 pfor_input.expanddim_inputs_for_broadcast() 3074 x = pfor_input.input(0)[0] 3075 y = pfor_input.input(1)[0] 3076 tolerance = pfor_input.get_attr("tolerance") 3077 return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True) 3078 3079 3080@RegisterPFor("Shape") 3081def _convert_shape(pfor_input): 3082 out_type = pfor_input.get_attr("out_type") 3083 return wrap( 3084 array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], 3085 False) 3086 3087 3088@RegisterPFor("ShapeN") 3089def _convert_shape_n(pfor_input): 3090 out_type = pfor_input.get_attr("out_type") 3091 shapes = [ 3092 array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape( 3093 x, out_type=out_type) for x, stacked, _ in pfor_input.inputs 3094 ] 3095 return [wrap(x, False) for x in shapes] 3096 3097 3098@RegisterPFor("Size") 3099def _convert_size(pfor_input): 3100 out_type = pfor_input.get_attr("out_type") 3101 n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) 3102 return wrap( 3103 array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, 3104 False) 3105 3106 3107@RegisterPFor("Rank") 3108def _convert_rank(pfor_input): 3109 return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) 3110 3111 3112@RegisterPFor("AddN") 3113def _convert_addn(pfor_input): 3114 # AddN does not support broadcasting. 3115 pfor_input.stack_inputs(tile_variants=False) 3116 return _wrap_and_tile_variants( 3117 math_ops.add_n([x.t for x in pfor_input.inputs]), 3118 pfor_input.pfor.loop_len_vector) 3119 3120 3121@RegisterPFor("Cross") 3122def _convert_cross(pfor_input): 3123 pfor_input.stack_inputs() 3124 a = pfor_input.stacked_input(0) 3125 b = pfor_input.stacked_input(1) 3126 return wrap(math_ops.cross(a, b), True) 3127 3128 3129@RegisterPFor("BiasAddGrad") 3130def _convert_biasaddgrad(pfor_input): 3131 grad = pfor_input.stacked_input(0) 3132 fmt = pfor_input.get_attr("data_format") 3133 if fmt == b"NCHW": 3134 output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) 3135 else: 3136 grad_shape = array_ops.shape(grad) 3137 last_dim_shape = grad_shape[-1] 3138 first_dim_shape = grad_shape[0] 3139 output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) 3140 output = math_ops.reduce_sum(output, axis=[1], keepdims=False) 3141 return wrap(output, True) 3142 3143 3144# Some required ops are not exposed under the tf namespace. Hence relying on 3145# _create_op to create them. 
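# For example, when vectorizing TanhGrad with loop length n, both inputs are
# stacked to shape [n] + per-iteration shape (loop-invariant inputs get tiled),
# and a single TanhGrad op is created on the stacked tensors; since the op is
# elementwise, the leading loop dimension simply passes through.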
3146@RegisterPForWithArgs("EluGrad") 3147@RegisterPForWithArgs("LeakyReluGrad") 3148@RegisterPForWithArgs("ReciprocalGrad") 3149@RegisterPForWithArgs("Relu6Grad") 3150@RegisterPForWithArgs("ReluGrad") 3151@RegisterPForWithArgs("RsqrtGrad") 3152@RegisterPForWithArgs("SeluGrad") 3153@RegisterPForWithArgs("SigmoidGrad") 3154@RegisterPForWithArgs("SoftplusGrad") 3155@RegisterPForWithArgs("SoftsignGrad") 3156@RegisterPForWithArgs("SqrtGrad") 3157@RegisterPForWithArgs("TanhGrad") 3158def _convert_grads(pfor_input, op_type, *args, **kw_args): 3159 del args 3160 del kw_args 3161 # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we 3162 # have to use tiling here. 3163 pfor_input.stack_inputs() 3164 outputs = _create_op( 3165 op_type, [x.t for x in pfor_input.inputs], 3166 [x.dtype for x in pfor_input.outputs], 3167 attrs=pfor_input.op.node_def.attr).outputs 3168 return [wrap(x, True) for x in outputs] 3169 3170 3171@RegisterPFor("Select") 3172def _convert_select(pfor_input): 3173 pfor_input.stack_inputs() 3174 cond = pfor_input.stacked_input(0) 3175 t = pfor_input.stacked_input(1) 3176 e = pfor_input.stacked_input(2) 3177 cond_rank = array_ops.rank(cond) 3178 cond, t, e = smart_cond.smart_cond( 3179 cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), 3180 lambda: [cond, t, e]) 3181 outputs = _create_op( 3182 pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], 3183 attrs=pfor_input.op.node_def.attr).outputs 3184 n = pfor_input.pfor.loop_len_vector 3185 out = smart_cond.smart_cond(cond_rank > 1, 3186 lambda: _unflatten_first_dim(outputs[0], n), 3187 lambda: outputs[0]) 3188 return [wrap(out, True) for x in outputs] 3189 3190 3191@RegisterPFor("SelectV2") 3192def _convert_selectv2(pfor_input): 3193 pfor_input.expanddim_inputs_for_broadcast() 3194 cond = pfor_input.input(0)[0] 3195 t = pfor_input.input(1)[0] 3196 e = pfor_input.input(2)[0] 3197 out = array_ops.where_v2(cond, t, e) 3198 return wrap(out, True) 3199 3200 3201# random_ops 3202 3203 3204def _transpose_dim_to_front(x, dim): 3205 rank = array_ops.rank(x) 3206 return array_ops.transpose( 3207 x, 3208 perm=array_ops.concat( 3209 [[dim], math_ops.range(0, dim), 3210 math_ops.range(dim + 1, rank)], 3211 axis=0)) 3212 3213 3214@RegisterPForWithArgs("RandomUniform") 3215@RegisterPForWithArgs("RandomUniformInt") 3216@RegisterPForWithArgs("RandomStandardNormal") 3217@RegisterPForWithArgs("TruncatedNormal") 3218def _convert_random(pfor_input, op_type, *args, **kw_args): 3219 del args 3220 del kw_args 3221 inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] 3222 # inputs[0] is "shape" 3223 inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]], 3224 axis=0) 3225 logging.warning( 3226 "Note that %s inside pfor op may not give same output as " 3227 "inside a sequential loop.", op_type) 3228 outputs = _create_op( 3229 op_type, 3230 inputs, [x.dtype for x in pfor_input.outputs], 3231 attrs=pfor_input.op.node_def.attr).outputs 3232 return [wrap(x, True) for x in outputs] 3233 3234 3235@RegisterPFor("RandomGamma") 3236@RegisterPFor("RandomPoissonV2") 3237def _convert_random_with_param(pfor_input): 3238 shape = pfor_input.unstacked_input(0) 3239 # param is lam (Poisson rate) or alpha (Gamma shape). 
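  # For example, for RandomGamma with shape = [3] and a stacked alpha of shape
  # [n, 2], the samples generated below have shape [3, n, 2]; the loop
  # dimension sits at axis size(shape) == 1 and gets transposed to the front,
  # yielding a stacked output of shape [n, 3, 2].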
3240 param, param_stacked, _ = pfor_input.input(1) 3241 logging.warning( 3242 "Note that %s inside pfor op may not give same output as " 3243 "inside a sequential loop.", pfor_input.op_type) 3244 3245 if param_stacked: 3246 samples = _create_op( 3247 pfor_input.op_type, 3248 inputs=[shape, param], 3249 op_dtypes=[x.dtype for x in pfor_input.outputs], 3250 attrs=pfor_input.op.node_def.attr).outputs[0] 3251 loop_dim = array_ops.shape(shape)[0] 3252 stacked_samples = _transpose_dim_to_front(samples, loop_dim) 3253 else: 3254 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 3255 stacked_samples = _create_op( 3256 pfor_input.op_type, 3257 inputs=[shape, param], 3258 op_dtypes=[x.dtype for x in pfor_input.outputs], 3259 attrs=pfor_input.op.node_def.attr).outputs[0] 3260 3261 return wrap(stacked_samples, True) 3262 3263 3264@RegisterPFor("Multinomial") 3265def _convert_multinomial(pfor_input): 3266 logits, logits_stacked, _ = pfor_input.input(0) 3267 num_samples = pfor_input.unstacked_input(1) 3268 seed = pfor_input.get_attr("seed") 3269 seed2 = pfor_input.get_attr("seed2") 3270 output_dtype = pfor_input.get_attr("output_dtype") 3271 logging.warning( 3272 "Note that Multinomial inside pfor op may not give same output as " 3273 "inside a sequential loop.") 3274 3275 n = pfor_input.pfor.loop_len_vector[0] 3276 if logits_stacked: 3277 flattened_logits = _flatten_first_two_dims(logits) 3278 samples = gen_random_ops.multinomial( 3279 flattened_logits, 3280 num_samples, 3281 seed=seed, 3282 seed2=seed2, 3283 output_dtype=output_dtype) 3284 stacked_samples = _unflatten_first_dim(samples, [n]) 3285 else: 3286 samples = gen_random_ops.multinomial( 3287 logits, 3288 num_samples * n, 3289 seed=seed, 3290 seed2=seed2, 3291 output_dtype=output_dtype) 3292 stacked_samples = array_ops.transpose( 3293 array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2]) 3294 3295 return wrap(stacked_samples, True) 3296 3297 3298@RegisterPFor("StatelessMultinomial") 3299@RegisterPFor("StatelessParameterizedTruncatedNormal") 3300@RegisterPFor("StatelessRandomBinomial") 3301@RegisterPFor("StatelessRandomGammaV2") 3302@RegisterPFor("StatelessRandomNormal") 3303@RegisterPFor("StatelessRandomPoisson") 3304@RegisterPFor("StatelessRandomUniform") 3305@RegisterPFor("StatelessRandomUniformInt") 3306@RegisterPFor("StatelessRandomUniformFullInt") 3307@RegisterPFor("StatelessTruncatedNormal") 3308def _convert_stateless_multinomial(pfor_input): 3309 # Unlike stateful random ops, for stateless ones we want better 3310 # reproducibility based on seed. Hence we don't want to use a similar strategy 3311 # as used for stateful ones where we generate a possibly different set of 3312 # random numbers under vectorization. 3313 # Unfortunately, the kernels currently are not necessarily setup to do this 3314 # efficiently and hence we fallback to a sequential loop for vectorization. 3315 return _fallback_converter(pfor_input, warn=False) 3316 3317 3318# linalg_ops 3319 3320 3321@RegisterPForWithArgs("XlaEinsum") 3322@RegisterPForWithArgs("Einsum") 3323def _convert_einsum(pfor_input, op_type): 3324 # Einsum may have either 1 or 2 inputs. 3325 inputs, input_stacked, _ = zip(*[ 3326 pfor_input.input(i) 3327 for i in range(pfor_input.num_inputs)]) 3328 3329 # Parse the einsum equation. 3330 equation = pfor_input.get_attr("equation").decode("utf-8") 3331 input_expr, output_expr = equation.split("->") 3332 input_exprs = input_expr.split(",") 3333 3334 # Pick a placeholder symbol to use for the new axis. 
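  # For example, with equation "ij,jk->ik" and both inputs stacked, an unused
  # symbol such as "a" is prepended to each stacked operand and to the output,
  # giving "aij,ajk->aik", so that the loop dimension is carried through the
  # einsum.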
3335 chosen_symbol = None 3336 for s in string.ascii_letters: 3337 if s in equation: 3338 continue 3339 else: 3340 chosen_symbol = s 3341 break 3342 3343 if chosen_symbol is None: 3344 raise ValueError("Could not figure out what symbol to use for new axis.") 3345 3346 assert any(input_stacked) 3347 for i in range(len(inputs)): 3348 if input_stacked[i]: 3349 input_exprs[i] = "{}{}".format(chosen_symbol, input_exprs[i]) 3350 output_expr = "{}{}".format(chosen_symbol, output_expr) 3351 3352 new_equation = "{}->{}".format(",".join(input_exprs), output_expr) 3353 3354 if op_type == "XlaEinsum": 3355 if len(inputs) == 1: 3356 result = xla.einsum(equation=new_equation, a=inputs[0]) 3357 else: 3358 result = xla.einsum(equation=new_equation, a=inputs[0], b=inputs[1]) 3359 else: 3360 assert op_type == "Einsum" 3361 result = special_math_ops.einsum(new_equation, *inputs) 3362 3363 return wrap(result, True) 3364 3365 3366@RegisterPFor("Cholesky") 3367def _convert_cholesky(pfor_input): 3368 t = pfor_input.stacked_input(0) 3369 return wrap(linalg_ops.cholesky(t), True) 3370 3371 3372@RegisterPFor("LogMatrixDeterminant") 3373def _convert_log_matrix_determinant(pfor_input): 3374 t = pfor_input.stacked_input(0) 3375 return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)] 3376 3377 3378@RegisterPFor("MatrixInverse") 3379def _convert_matrix_inverse(pfor_input): 3380 t = pfor_input.stacked_input(0) 3381 adjoint = pfor_input.get_attr("adjoint") 3382 return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True) 3383 3384 3385@RegisterPFor("MatrixSolve") 3386def _convert_matrix_solve(pfor_input): 3387 pfor_input.stack_inputs() 3388 matrix = pfor_input.stacked_input(0) 3389 rhs = pfor_input.stacked_input(1) 3390 adjoint = pfor_input.get_attr("adjoint") 3391 output = gen_linalg_ops.matrix_solve( 3392 matrix, rhs, adjoint=adjoint) 3393 return wrap(output, True) 3394 3395 3396@RegisterPFor("MatrixTriangularSolve") 3397def _convert_matrix_triangular_solve(pfor_input): 3398 pfor_input.expanddim_inputs_for_broadcast() 3399 matrix = pfor_input.input(0)[0] 3400 rhs = pfor_input.input(1)[0] 3401 lower = pfor_input.get_attr("lower") 3402 adjoint = pfor_input.get_attr("adjoint") 3403 output = linalg_ops.matrix_triangular_solve( 3404 matrix, rhs, lower=lower, adjoint=adjoint) 3405 return wrap(output, True) 3406 3407 3408@RegisterPFor("SelfAdjointEigV2") 3409def _convert_self_adjoint_eig(pfor_input): 3410 t = pfor_input.stacked_input(0) 3411 compute_v = pfor_input.get_attr("compute_v") 3412 e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v) 3413 # If compute_v is False, v will have shape [0]. 3414 return wrap(e, True), wrap(v, compute_v) 3415 3416 3417# logging_ops 3418 3419 3420@RegisterPFor("Assert") 3421def _convert_assert(pfor_input): 3422 cond, cond_stacked, _ = pfor_input.input(0) 3423 if cond_stacked: 3424 cond = math_ops.reduce_all(cond) 3425 3426 data_list = [x.t for x in pfor_input.inputs][1:] 3427 return _create_op( 3428 "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr) 3429 3430 3431@RegisterPFor("Print") 3432def _convert_print(pfor_input): 3433 # Note that we don't stack all the inputs. Hence unstacked values are printed 3434 # once here vs multiple times in a while_loop. 
  pfor_input.stack_inputs([0])
  outputs = _create_op(
      "Print", [x.t for x in pfor_input.inputs],
      [x.dtype for x in pfor_input.outputs],
      attrs=pfor_input.op.node_def.attr).outputs
  return [wrap(x, True) for x in outputs]


@RegisterPFor("PrintV2")
def _convert_print_v2(pfor_input):
  # Print the full input Tensor(s), including the batch dimension if stacked.
  return _create_op(
      "PrintV2", [x.t for x in pfor_input.inputs],
      [x.dtype for x in pfor_input.outputs],
      attrs=pfor_input.op.node_def.attr)


@RegisterPFor("StringFormat")
def _convert_string_format(pfor_input):
  # Format using the full input Tensor(s), including the batch dimension if
  # stacked.
  op = _create_op(
      "StringFormat", [x.t for x in pfor_input.inputs],
      [x.dtype for x in pfor_input.outputs],
      attrs=pfor_input.op.node_def.attr)
  return [wrap(output, False) for output in op.outputs]


# data_flow_ops

# TensorArray conversion is tricky since we don't support arrays of
# TensorArrays. For converting them, we consider two distinct cases:
#
# 1. The array is constructed outside the pfor call, and read/written inside
# the loop.
# This is an easier case since we don't need to make an array of TensorArrays.
# A correctness requirement is that these parallel iterations shouldn't attempt
# to write to the same location. Hence at conversion time we disallow indices
# from being loop-invariant, as that would guarantee a collision. Even if the
# indices are not loop-invariant, they could still conflict, which would
# trigger runtime errors.
#
# 2. The array is constructed and used entirely inside each pfor iteration.
# For simplicity, here we require that the indices used for write/scatter are
# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
# different pfor iterations. We consider two sub-cases:
#
# 2a. Elements written to the array are "stacked".
# To simulate multiple TensorArrays, we may increase the dimension of each
# element of the array, i.e. the i_th row of the j_th entry of the converted
# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
# pfor iteration.
#
# 2b. Elements written to the array are "unstacked".
# In this case we don't increase the dimensions, to avoid redundant tiling.
# Each iteration is trying to write the same value, so we convert that to a
# single write.
#
# Here are some tricks used to implement the above:
# - The TensorArrayV3 constructor encodes the element shape as an attr. Instead
# of trying to trace whether future writes are stacked or unstacked in order to
# set this attr, we set it to correspond to an unknown shape.
# - We use the "flow" output of the different ops to track whether the array
# elements are stacked or unstacked. If a stacked write/scatter is done, we
# make the flow stacked as well.
# - We use some heuristic traversal of the graph to track whether the
# TensorArray handle was created inside or outside the pfor loop.
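# A rough sketch of case 2a above (illustrative only, assuming loop_len == 3):
# user code such as
#
#   def loop_fn(i):
#     ta = tf.TensorArray(tf.float32, size=1)
#     ta = ta.write(0, tf.cast(i, tf.float32))  # loop dependent ("stacked") value
#     return ta.read(0)
#
# is converted so that all iterations share a single TensorArray whose entry 0
# holds a tensor of shape [3] (one value per pfor iteration); the converted
# read then returns that stacked tensor directly, and the flow is marked
# stacked once the stacked write happens.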
3501 3502 3503@RegisterPFor("TensorArrayV3") 3504def _convert_tensor_array_v3(pfor_input): 3505 size = pfor_input.unstacked_input(0) 3506 dtype = pfor_input.get_attr("dtype") 3507 dynamic_size = pfor_input.get_attr("dynamic_size") 3508 clear_after_read = pfor_input.get_attr("clear_after_read") 3509 identical_element_shapes = pfor_input.get_attr("identical_element_shapes") 3510 tensor_array_name = pfor_input.get_attr("tensor_array_name") 3511 handle, flow = data_flow_ops.tensor_array_v3( 3512 size, 3513 dtype=dtype, 3514 # We don't set element shape since we don't know if writes are stacked or 3515 # not yet. 3516 element_shape=None, 3517 dynamic_size=dynamic_size, 3518 clear_after_read=clear_after_read, 3519 identical_element_shapes=identical_element_shapes, 3520 tensor_array_name=tensor_array_name) 3521 # Note we keep flow unstacked for now since we don't know if writes will be 3522 # stacked or not. 3523 return wrap(handle, False), wrap(flow, False) 3524 3525 3526@RegisterPFor("TensorArraySizeV3") 3527def _convert_tensor_array_size_v3(pfor_input): 3528 handle = pfor_input.unstacked_input(0) 3529 flow, flow_stacked, _ = pfor_input.input(1) 3530 if flow_stacked: 3531 flow = _unstack_flow(flow) 3532 size = data_flow_ops.tensor_array_size_v3(handle, flow) 3533 return wrap(size, False) 3534 3535 3536def _handle_inside_pfor(pfor_input, handle): 3537 """Returns True if handle was created inside the pfor loop.""" 3538 # We use some heuristic to find the original TensorArray creation op. 3539 # The logic should handle the common cases (except cond based subgraphs). 3540 # In theory the user could perform different operations on the handle (like 3541 # Reshape, stack multiple handles, etc) which could break this logic. 3542 # TODO(agarwal): handle Switch/Merge. 3543 while handle.op.type in ("Enter", "Identity"): 3544 handle = handle.op.inputs[0] 3545 if handle.op.type not in [ 3546 "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape" 3547 ]: 3548 raise ValueError("Unable to find source for handle %s" % handle) 3549 else: 3550 return pfor_input.pfor.op_is_inside_loop(handle.op) 3551 3552 3553def _unstack_flow(value): 3554 # TODO(agarwal): consider looking if this is a Tile op then get its input. 3555 # This may avoid running the Tile operations. 3556 return array_ops.gather(value, 0) 3557 3558 3559@RegisterPFor("TensorArrayReadV3") 3560def _convert_tensor_array_read_v3(pfor_input): 3561 handle = pfor_input.unstacked_input(0) 3562 index, index_stacked, _ = pfor_input.input(1) 3563 dtype = pfor_input.get_attr("dtype") 3564 flow, flow_stacked, _ = pfor_input.input(2) 3565 if flow_stacked: 3566 flow = _unstack_flow(flow) 3567 3568 is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3569 if is_inside_pfor: 3570 # Note that if we are inside a control flow construct inside the pfor, and 3571 # only some of the iterations are doing the read (i.e. 3572 # `all_indices_partitioned` is True), then the read operation should only 3573 # return values for the currently active pfor iterations (`all_indices` 3574 # below). Hence, whenever the returned value is stacked (i.e. `flow` is 3575 # stacked), we may need to do an extra gather after reading the values. Also 3576 # note that if `is_inside` is false, then values in the tensor array are 3577 # unstacked. So the check is only needed in this branch. 
3578 all_indices = pfor_input.pfor.all_indices 3579 all_indices_partitioned = pfor_input.pfor.all_indices_partitioned 3580 # Note: flow_stacked indicates if values in the TensorArray are stacked or 3581 # not. 3582 if index_stacked: 3583 if flow_stacked: 3584 raise ValueError( 3585 "It looks like TensorArrayReadV3 was called on a TensorArray whose" 3586 " values are not loop-invariant, and the read indices were also" 3587 " not loop invariant. This is currently unsupported.") 3588 value = data_flow_ops.tensor_array_gather_v3( 3589 handle, index, flow, dtype=dtype) 3590 return wrap(value, True) 3591 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 3592 if flow_stacked and all_indices_partitioned: 3593 value = array_ops.gather(value, all_indices) 3594 return wrap(value, flow_stacked) 3595 # Values in the TensorArray should be unstacked (since different iterations 3596 # couldn't write to the same location). So whether output is stacked or not 3597 # depends on index_stacked. 3598 if index_stacked: 3599 value = data_flow_ops.tensor_array_gather_v3( 3600 handle, index, flow, dtype=dtype) 3601 else: 3602 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 3603 return wrap(value, index_stacked) 3604 3605 3606@RegisterPFor("TensorArrayWriteV3") 3607def _convert_tensor_array_write_v3(pfor_input): 3608 handle = pfor_input.unstacked_input(0) 3609 index, index_stacked, _ = pfor_input.input(1) 3610 value, value_stacked, _ = pfor_input.input(2) 3611 flow, flow_stacked, _ = pfor_input.input(3) 3612 if value_stacked and pfor_input.pfor.all_indices_partitioned: 3613 # Looks like we are in a control flow in a pfor where not all iterations are 3614 # active now. We don't allow that since that could lead to different indices 3615 # having different shapes which will be hard to merge later. 3616 raise ValueError("Writing non loop invariant values to TensorArray from " 3617 "inside a while_loop/cond not supported.") 3618 if flow_stacked: 3619 flow = _unstack_flow(flow) 3620 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3621 if is_inside: 3622 if index_stacked: 3623 raise ValueError("Need indices for %s to be loop invariant" % handle) 3624 if not flow_stacked and not value_stacked: 3625 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 3626 return wrap(flow_out, False) 3627 else: 3628 if not value_stacked: 3629 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3630 # TODO(agarwal): Note that if flow is unstacked and value is stacked, then 3631 # this may or may not be a safe situation. flow is unstacked both for a 3632 # freshly created TensorArray, as well as after unstacked values are 3633 # written to it. If it is the latter, then we cannot write a stacked value 3634 # now since that may cause runtime errors due to different shapes in the 3635 # array. At the moment we are not able to handle this gracefully and 3636 # distinguish between the two cases. That would require some heuristic 3637 # traversal of the graph to figure out whether all the writes are 3638 # unstacked or not. 3639 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 3640 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3641 else: 3642 if not index_stacked: 3643 raise ValueError("Need indices for %s to be not loop invariant" % handle) 3644 # Note that even when index_stacked is true, actual values in index may 3645 # still not be unique. 
However that will cause runtime error when executing 3646 # the scatter operation below. 3647 if not value_stacked: 3648 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3649 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) 3650 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3651 3652 3653def _transpose_first_two_dims(value): 3654 # TODO(agarwal): optimize if one of the dims == 1. 3655 value_shape = array_ops.shape(value) 3656 v0 = value_shape[0] 3657 v1 = value_shape[1] 3658 value = array_ops.reshape(value, [v0, v1, -1]) 3659 value = array_ops.transpose(value, [1, 0, 2]) 3660 new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) 3661 return array_ops.reshape(value, new_shape) 3662 3663 3664@RegisterPFor("TensorArrayGatherV3") 3665def _convert_tensor_array_gather_v3(pfor_input): 3666 handle = pfor_input.unstacked_input(0) 3667 indices, indices_stacked, _ = pfor_input.input(1) 3668 indices = array_ops.reshape(indices, [-1]) 3669 flow, flow_stacked, _ = pfor_input.input(2) 3670 if flow_stacked: 3671 flow = _unstack_flow(flow) 3672 dtype = pfor_input.get_attr("dtype") 3673 # TODO(agarwal): support element_shape attr? 3674 3675 n = pfor_input.pfor.loop_len_vector 3676 value = data_flow_ops.tensor_array_gather_v3( 3677 handle, indices, flow, dtype=dtype) 3678 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3679 if is_inside: 3680 # flow_stacked indicates if values in the TensorArray are stacked or not. 3681 if indices_stacked: 3682 if flow_stacked: 3683 raise ValueError( 3684 "It looks like TensorArrayGatherV3 was called on a TensorArray " 3685 "whose values are not loop-invariant, and the indices were also " 3686 "not loop invariant. This is currently unsupported.") 3687 else: 3688 value = _unflatten_first_dim(value, n) 3689 return wrap(value, True) 3690 else: 3691 if flow_stacked: 3692 # Since elements in this array are stacked and `value` was produced by 3693 # gather, its first two dims are "gathered elements" and "stack 3694 # dimension". Our semantics require these two to be flipped. 3695 value = _transpose_first_two_dims(value) 3696 return wrap(value, flow_stacked) 3697 else: 3698 # Values in the TensorArray should be unstacked (since different iterations 3699 # couldn't write to the same location). So whether output is stacked or not 3700 # depends on indices_stacked. 3701 if indices_stacked: 3702 value = _unflatten_first_dim(value, n) 3703 return wrap(value, indices_stacked) 3704 3705 3706@RegisterPFor("TensorArrayScatterV3") 3707def _convert_tensor_array_scatter_v3(pfor_input): 3708 handle = pfor_input.unstacked_input(0) 3709 indices, indices_stacked, _ = pfor_input.input(1) 3710 indices = array_ops.reshape(indices, [-1]) 3711 value, value_stacked, _ = pfor_input.input(2) 3712 flow, flow_stacked, _ = pfor_input.input(3) 3713 3714 if flow_stacked: 3715 flow = _unstack_flow(flow) 3716 3717 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3718 if is_inside: 3719 if indices_stacked: 3720 raise ValueError("Need indices for %s to be loop invariant" % handle) 3721 # Note that flow_stacked indicates if existing values in the array are 3722 # stacked or not. 3723 if not flow_stacked and not value_stacked: 3724 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 3725 flow) 3726 return wrap(flow_out, False) 3727 if not value_stacked: 3728 # TODO(agarwal): tile in the second dimension directly instead of 3729 # transposing below. 
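      # For example, an unstacked value of shape [num_indices, d] is first
      # stacked to [loop_len, num_indices, d] and then transposed to
      # [num_indices, loop_len, d], so that array entry indices[j] receives the
      # per-iteration values for the j-th scattered element.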
3730 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3731 3732 value = _transpose_first_two_dims(value) 3733 # TODO(agarwal): Note that if a previous write was unstacked, flow will be 3734 # unstacked, and a stacked value may be written here which may cause 3735 # runtime error due to different elements having different shape. We do 3736 # not try to prevent that. 3737 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 3738 flow) 3739 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3740 if not indices_stacked: 3741 raise ValueError("Need indices for %s to be not loop invariant" % handle) 3742 if not value_stacked: 3743 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3744 value = _flatten_first_two_dims(value) 3745 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow) 3746 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3747 3748 3749@RegisterPFor("TensorArrayGradV3") 3750def _convert_tensor_array_grad_v3(pfor_input): 3751 handle = pfor_input.unstacked_input(0) 3752 flow, flow_stacked, _ = pfor_input.input(1) 3753 if flow_stacked: 3754 flow = _unstack_flow(flow) 3755 source = pfor_input.get_attr("source") 3756 # TODO(agarwal): For now, we assume that gradients are stacked if the 3757 # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong 3758 # will give runtime error due to incorrect shape being written to the 3759 # accumulator. It is difficult to know in advance if gradients written will be 3760 # stacked or not. Note that flow being stacked is not indicative of the 3761 # gradient being stacked or not. Revisit this later. 3762 shape_to_prepend = pfor_input.pfor.loop_len_vector 3763 grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( 3764 handle=handle, 3765 flow_in=flow, 3766 shape_to_prepend=shape_to_prepend, 3767 source=source) 3768 flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t 3769 return [wrap(grad_handle, False), wrap(flow_out, True)] 3770 3771 3772def _stack_tensor_list_shape(shape, first_dim): 3773 shape_value = tensor_util.constant_value(shape) 3774 # Note that negative values in the shape are used to signify unknown shapes 3775 # and are handled in a special way. 3776 if shape_value is not None: 3777 shape_value = np.asarray(shape_value) 3778 if -1 in shape_value: 3779 return constant_op.constant(-1) 3780 elif not shape_value.size: 3781 return first_dim 3782 else: 3783 shape = array_ops.reshape(shape, [-1]) 3784 return control_flow_ops.cond( 3785 math_ops.reduce_any(shape < 0), 3786 lambda: constant_op.constant(-1), 3787 lambda: array_ops.concat([first_dim, shape], axis=0)) 3788 3789 3790def _tile_variant_with_length(t, length): 3791 """stacks `t` `length` times.""" 3792 if _is_variant_with_internal_stacking(t): 3793 # The content of TensorLists is vectorized, not the variant itself. 3794 return t 3795 original_tensor = t 3796 t.set_shape([]) 3797 t = array_ops.reshape(t, [-1]) 3798 with ops.device("CPU:0"): 3799 result = array_ops.tile(t, length) 3800 # TODO(b/169968286): Should regular shape functions do handle data 3801 # propagation here? 
3802 handle_data_util.copy_handle_data(original_tensor, result) 3803 return result 3804 3805 3806def _tile_variant(t, pfor_input): 3807 """stacks `t` according to its loop context.""" 3808 return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector) 3809 3810 3811def _untile_variant(t): 3812 if _is_variant_with_internal_stacking(t): 3813 # The content of TensorLists is vectorized, not the variant itself. 3814 if not t.shape.is_compatible_with([]): 3815 raise AssertionError( 3816 ("Unexpectedly saw a vectorized variant (e.g. TensorList) with " 3817 "non-scalar shape: {!r}").format(t)) 3818 return t 3819 return array_ops.gather(t, 0) 3820 3821 3822@RegisterPFor("OptionalFromValue") 3823def _convert_optional_from_value(pfor_input): 3824 pfor_input.stack_inputs() 3825 return wrap( 3826 gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]), 3827 True) 3828 3829 3830@RegisterPFor("OptionalGetValue") 3831def _convert_optional_get_value(pfor_input): 3832 handle = pfor_input.stacked_input(0) 3833 output_types = pfor_input.get_attr("output_types") 3834 original_output_shapes = pfor_input.get_attr("output_shapes") 3835 output_shapes = [] 3836 for shape in original_output_shapes: 3837 shape = tensor_shape.TensorShape(shape) 3838 loop_len_shape = tensor_shape.TensorShape( 3839 [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)]) 3840 shape = loop_len_shape.concatenate(shape) 3841 output_shapes.append(shape.as_proto()) 3842 results = gen_dataset_ops.optional_get_value(handle, output_types, 3843 output_shapes) 3844 return [wrap(t, True) for t in results] 3845 3846 3847@RegisterPFor("TensorListReserve") 3848def _convert_tensor_list_reserve(pfor_input): 3849 element_shape = pfor_input.unstacked_input(0) 3850 num_elements = pfor_input.unstacked_input(1) 3851 element_dtype = pfor_input.get_attr("element_dtype") 3852 3853 # Prepend a dimension to element_shape. 
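  # For example, an element_shape of [2, 3] with loop_len 4 becomes [4, 2, 3];
  # a shape with any unknown dimension (-1) is left fully unknown.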
3854 element_shape = _stack_tensor_list_shape(element_shape, 3855 pfor_input.pfor.loop_len_vector) 3856 handle = list_ops.tensor_list_reserve( 3857 element_shape, num_elements, element_dtype=element_dtype) 3858 3859 return wrap(_tile_variant(handle, pfor_input), True) 3860 3861 3862@RegisterPFor("TensorListElementShape") 3863def _convert_tensor_list_element_shape(pfor_input): 3864 handle = _untile_variant(pfor_input.stacked_input(0)) 3865 shape_type = pfor_input.get_attr("shape_type") 3866 shape = list_ops.tensor_list_element_shape(handle, shape_type) 3867 shape = array_ops.reshape(shape, [-1]) 3868 shape = shape[1:] 3869 return wrap(shape, False) 3870 3871 3872@RegisterPFor("TensorListLength") 3873def _convert_tensor_list_length(pfor_input): 3874 handle = _untile_variant(pfor_input.stacked_input(0)) 3875 return wrap(list_ops.tensor_list_length(handle), False) 3876 3877 3878def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None): 3879 if element_shape is None: 3880 element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32) 3881 length = list_ops.tensor_list_length(handle) 3882 new_handle = list_ops.tensor_list_reserve( 3883 _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype) 3884 3885 def _body_fn(i, h): 3886 elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape) 3887 elem = _stack(elem, loop_len_vector).t 3888 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 3889 3890 return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn, 3891 [0, new_handle])[1] 3892 3893 3894@RegisterPFor("TensorListGetItem") 3895def _convert_tensor_list_get_item(pfor_input): 3896 handle, handle_stacked, _ = pfor_input.input(0) 3897 index, index_stacked, _ = pfor_input.input(1) 3898 element_shape = pfor_input.unstacked_input(2) 3899 element_dtype = pfor_input.get_attr("element_dtype") 3900 3901 if handle_stacked: 3902 handle = _untile_variant(handle) 3903 element_shape = _stack_tensor_list_shape(element_shape, 3904 pfor_input.pfor.loop_len_vector) 3905 if index_stacked: 3906 # We use a sequential loop since that may be more efficient than first 3907 # gathering and concatenating all the element corresponding to `index`, 3908 # and then doing a gather on it. 3909 def _map_fn(i): 3910 item_i = list_ops.tensor_list_get_item( 3911 handle, 3912 index[i], 3913 element_dtype=element_dtype) 3914 return array_ops.gather(item_i, i) 3915 3916 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 3917 return wrap(output, True) 3918 else: 3919 output = list_ops.tensor_list_get_item( 3920 handle, 3921 index, 3922 element_shape=element_shape, 3923 element_dtype=element_dtype) 3924 return wrap(output, True) 3925 else: 3926 assert index_stacked 3927 return wrap( 3928 list_ops.tensor_list_gather( 3929 handle, 3930 index, 3931 element_shape=element_shape, 3932 element_dtype=element_dtype), True) 3933 3934 3935@RegisterPFor("TensorListSetItem") 3936def _convert_tensor_array_set_item(pfor_input): 3937 handle, handle_stacked, _ = pfor_input.input(0) 3938 index, index_stacked, _ = pfor_input.input(1) 3939 item, item_stacked, _ = pfor_input.input(2) 3940 3941 if not handle_stacked: 3942 # Special case where we can statically guarantee that the indices are 3943 # disjoint. 
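    # This covers the common pattern where iteration i writes to list index i
    # (i.e. the loop index itself is used as the write index), so the writes
    # are guaranteed disjoint and can be expressed as a single scatter below.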
3944 if index is pfor_input.pfor.all_indices: 3945 if not item_stacked: 3946 item = _stack(item, pfor_input.pfor.loop_len_vector).t 3947 return wrap( 3948 list_ops.tensor_list_scatter(item, index, input_handle=handle), False) 3949 else: 3950 handle = _stack_tensor_list(handle, item.dtype, 3951 pfor_input.pfor.loop_len_vector) 3952 else: 3953 handle = _untile_variant(handle) 3954 3955 if index_stacked: 3956 # TODO(agarwal): handle this. 3957 raise ValueError("Vectorizing writes to a TensorList with loop " 3958 "variant indices is currently unsupported.") 3959 3960 else: 3961 if not item_stacked: 3962 item = _stack(item, pfor_input.pfor.loop_len_vector).t 3963 handle = list_ops.tensor_list_set_item(handle, index, item) 3964 return wrap(_tile_variant(handle, pfor_input), True) 3965 3966 3967@RegisterPFor("TensorListPushBack") 3968def _convert_tensor_list_push_back(pfor_input): 3969 handle, handle_stacked, _ = pfor_input.input(0) 3970 tensor, tensor_stacked, _ = pfor_input.input(1) 3971 if handle_stacked: 3972 handle = _untile_variant(handle) 3973 else: 3974 handle = _stack_tensor_list(handle, tensor.dtype, 3975 pfor_input.pfor.loop_len_vector) 3976 if not tensor_stacked: 3977 tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t 3978 handle = list_ops.tensor_list_push_back(handle, tensor) 3979 return wrap(_tile_variant(handle, pfor_input), True) 3980 3981 3982@RegisterPFor("TensorListPopBack") 3983def _convert_tensor_array_push_back(pfor_input): 3984 handle = pfor_input.stacked_input(0) 3985 element_shape = pfor_input.unstacked_input(1) 3986 handle = _untile_variant(handle) 3987 3988 if element_shape.shape.ndims == 0: 3989 # Default / unspecified 3990 vectorized_shape = -1 3991 else: 3992 # PopBack has an element shape set when it's the gradient of PushBack, only 3993 # used when the list is uninitialized. 3994 vectorized_shape = array_ops.concat( 3995 [pfor_input.pfor.loop_len_vector, element_shape], axis=0) 3996 3997 output_handle, tensor = gen_list_ops.tensor_list_pop_back( 3998 input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"), 3999 element_shape=vectorized_shape) 4000 return wrap(output_handle, True), wrap(tensor, True) 4001 4002 4003@RegisterPFor("TensorListConcatV2") 4004def _convert_tensor_list_concat_v2(pfor_input): 4005 input_handle = pfor_input.stacked_input(0) 4006 element_shape = pfor_input.unstacked_input(1) 4007 leading_dims = pfor_input.unstacked_input(2) 4008 element_dtype = pfor_input.get_attr("element_dtype") 4009 4010 handle = _untile_variant(input_handle) 4011 length = list_ops.tensor_list_length(handle) 4012 # Note that element_shape attribute can have incomplete shapes. This doesn't 4013 # seem to work well when creating another list and then doing a concat on it. 4014 # Hence we try to find the dynamic shape here. 4015 element_shape = control_flow_ops.cond( 4016 length > 0, lambda: array_ops.shape( 4017 list_ops.tensor_list_get_item(handle, 0, element_dtype, None)), 4018 lambda: constant_op.constant([0, 0], dtype=dtypes.int32)) 4019 # The code below creates a copy of the list with each elements' first two 4020 # dimensions transposed. 4021 new_element_shape = array_ops.concat( 4022 [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0) 4023 4024 # Create a new TensorList with elements transposed. 
4025 def _transpose_elem(i, h): 4026 elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None) 4027 elem = _transpose_first_two_dims(elem) 4028 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 4029 4030 new_handle = list_ops.tensor_list_reserve(new_element_shape, length, 4031 element_dtype) 4032 new_handle = control_flow_ops.while_loop(lambda i, _: i < length, 4033 _transpose_elem, [0, new_handle])[1] 4034 output, lengths = gen_list_ops.tensor_list_concat_v2( 4035 input_handle=new_handle, 4036 element_dtype=element_dtype, 4037 element_shape=new_element_shape, 4038 leading_dims=leading_dims) 4039 output = _transpose_first_two_dims(output) 4040 return wrap(output, True), wrap(lengths, False) 4041 4042 4043@RegisterPFor("TensorListStack") 4044def _convert_tensor_list_stack(pfor_input): 4045 handle = pfor_input.stacked_input(0) 4046 input_shape = pfor_input.unstacked_input(1) 4047 element_dtype = pfor_input.get_attr("element_dtype") 4048 num_elements = pfor_input.get_attr("num_elements") 4049 4050 handle = _untile_variant(handle) 4051 input_shape = _stack_tensor_list_shape(input_shape, 4052 pfor_input.pfor.loop_len_vector) 4053 output = list_ops.tensor_list_stack( 4054 handle, 4055 element_dtype, 4056 element_shape=input_shape, 4057 num_elements=num_elements) 4058 output = _transpose_first_two_dims(output) 4059 return wrap(output, True) 4060 4061 4062@RegisterPFor("TensorListGather") 4063def _convert_tensor_list_gather(pfor_input): 4064 handle, handle_stacked, _ = pfor_input.input(0) 4065 index, index_stacked, _ = pfor_input.input(1) 4066 element_shape = pfor_input.unstacked_input(2) 4067 element_dtype = pfor_input.get_attr("element_dtype") 4068 4069 if handle_stacked: 4070 handle = _untile_variant(handle) 4071 element_shape = _stack_tensor_list_shape(element_shape, 4072 pfor_input.pfor.loop_len_vector) 4073 if index_stacked: 4074 # We use a sequential loop since that may be more efficient than first 4075 # gathering and concatenating all the element corresponding to `index`, 4076 # and then doing a gather on it. 
4077 def _map_fn(i): 4078 item_i = list_ops.tensor_list_gather( 4079 handle, 4080 index[i], 4081 element_dtype=element_dtype) 4082 axis = array_ops.rank(index) - 1 4083 return array_ops.gather(item_i, i, axis=axis) 4084 4085 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 4086 return wrap(output, True) 4087 else: 4088 output = list_ops.tensor_list_gather( 4089 handle, 4090 index, 4091 element_shape=element_shape, 4092 element_dtype=element_dtype) 4093 return wrap(output, True) 4094 else: 4095 assert index_stacked 4096 index_shape = array_ops.shape(index) 4097 index = array_ops.reshape(index, [-1]) 4098 values = list_ops.tensor_list_gather( 4099 handle, index, element_shape=element_shape, element_dtype=element_dtype) 4100 final_shape = array_ops.concat( 4101 [index_shape, array_ops.shape(values)[1:]], axis=0) 4102 return wrap(array_ops.reshape(values, final_shape), True) 4103 4104 4105@RegisterPFor("TensorListScatterIntoExistingList") 4106def _convert_tensor_list_scatter(pfor_input): 4107 pfor_input.stack_inputs([1]) 4108 handle, handle_stacked, _ = pfor_input.input(0) 4109 item = pfor_input.stacked_input(1) 4110 indices, indices_stacked, _ = pfor_input.input(2) 4111 if handle_stacked: 4112 handle = _untile_variant(handle) 4113 else: 4114 handle = _stack_tensor_list(handle, item.dtype, 4115 pfor_input.pfor.loop_len_vector) 4116 4117 item = _transpose_first_two_dims(item) 4118 if indices_stacked: 4119 # Pretend the list is a dense tensor: 4120 # list_as_dense: Tensor[list_len, loop_len, ...] 4121 # And indices are a tensor with shape (before transpose): 4122 # indices: Tensor[loop_len, num_scatters] 4123 # The item to scatter has shape (before transpose): 4124 # item: Tensor[loop_len, num_scatters, ...] 4125 # 4126 # We want list_as_dense[indices[i, j], i] = item[i, j] 4127 # 4128 # Since we're not just indexing along the first axis of `list_as_dense`, we 4129 # need to first extract the relevant list entries based on `indices`, 4130 # scatter into them according to the loop index, and re-scatter the chunks 4131 # we updated back into the list. 4132 indices = _transpose_first_two_dims(indices) 4133 indices_flat = array_ops.reshape(indices, [-1]) 4134 # In many cases `indices` will be unique across pfor iterations, but this is 4135 # not guaranteed. If there are duplicates, we need to map multiple updates 4136 # to a single chunk extracted from the list. The last update should win. 4137 unique_indices = array_ops.unique(indices_flat) 4138 gathered_items = list_ops.tensor_list_gather( 4139 handle, unique_indices.y, element_dtype=item.dtype, 4140 element_shape=array_ops.shape(item)[1:]) 4141 loop_idx = math_ops.range(pfor_input.pfor.loop_len_vector[0]) 4142 scatters_per_op = array_ops.shape(indices)[0] 4143 4144 unique_indices_loop_idx = array_ops.reshape(array_ops.tile( 4145 loop_idx[None, :], [scatters_per_op, 1]), [-1]) 4146 scatter_indices = array_ops.stack( 4147 [unique_indices.idx, unique_indices_loop_idx], 4148 axis=1) 4149 # This op does *not* guarantee last-update-wins on GPU, so semantics may not 4150 # be exactly preserved for duplicate updates there. 
    scattered = array_ops.tensor_scatter_nd_update(
        tensor=gathered_items,
        indices=scatter_indices,
        updates=_flatten_first_two_dims(item))
    handle = list_ops.tensor_list_scatter(
        scattered, unique_indices.y, input_handle=handle)
  else:
    handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle)
  return wrap(_tile_variant(handle, pfor_input), True)


@RegisterPFor("TensorListFromTensor")
def _convert_tensor_list_from_tensor(pfor_input):
  tensor = pfor_input.stacked_input(0)
  element_shape = pfor_input.unstacked_input(1)
  tensor = _transpose_first_two_dims(tensor)
  element_shape = _stack_tensor_list_shape(element_shape,
                                           pfor_input.pfor.loop_len_vector)
  handle = list_ops.tensor_list_from_tensor(tensor, element_shape)
  return wrap(_tile_variant(handle, pfor_input), True)


# StackV2 conversion is tricky since we don't have arrays of StackV2. So,
# similar to TensorArrays, we convert them by changing the dimension of the
# elements inside the stack.
#
# We consider two cases:
#
# 1. StackV2 is constructed and used entirely inside the pfor loop.
# We keep a single Stack and perform the push/pop operations of all the
# iterations in lock-step. We also assume that all the iterations perform these
# operations. In case of dynamic control flow, if only some of the iterations
# try to perform a push/pop, then the conversion may not work correctly and may
# cause undefined behavior.
# TODO(agarwal): test StackV2 with dynamic control flow.
#
# 2. StackV2 is constructed outside the pfor loop.
# Performing stack push/pop in a parallel fashion is ill-defined. However,
# given that reading stacks created externally is a common operation when
# computing jacobians, we provide the following special semantics:
# - Push operations to the stack are disallowed.
# - Pop operations are performed in lock step by all iterations, similar to the
# case when the stack is created inside. A single value is popped during the
# lock-step operation and broadcast to all the iterations. Values in the stack
# are assumed to be loop-invariant.
#
# Some other implementation details:
# We use some ugly logic to track whether the values in the Stack data
# structure are loop invariant or not. When converting push/pop operations, we
# keep track of whether the last conversion used a stacked value or not (see
# _stack_cache below). As a result, if an unstacked value is written first,
# subsequent stacked writes are disallowed even though they could have been
# allowed in theory.

# Map from cache key based on StackV2 handle to a bool indicating whether
# values are stacked or not.
# TODO(agarwal): move _stack_cache inside pfor?
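# For example, once a StackPushV2 of a loop dependent value has been converted,
# the entry keyed by (graph, pfor, stack_handle) is set to True, and a later
# StackPopV2 on the same stack will wrap its output as stacked.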
4207_stack_cache = {} 4208 4209 4210def _stack_cache_key(pfor_input): 4211 """Create cache key corresponding to a stack handle.""" 4212 op_type = pfor_input.op_type 4213 assert op_type in ["StackPushV2", "StackPopV2"], op_type 4214 orig_handle = pfor_input.op.inputs[0] 4215 while orig_handle.op.type in ["Identity", "Enter"]: 4216 orig_handle = orig_handle.op.inputs[0] 4217 assert orig_handle.op.type == "StackV2", orig_handle.op 4218 return ops.get_default_graph(), pfor_input.pfor, orig_handle 4219 4220 4221def _stack_handle_inside_pfor(handle, pfor_input): 4222 while handle.op.type in ["Identity", "Enter"]: 4223 handle = handle.op.inputs[0] 4224 assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" % 4225 handle.op) 4226 return pfor_input.pfor.op_is_inside_loop(handle.op) 4227 4228 4229@RegisterPFor("StackPushV2") 4230def _convert_stack_push_v2(pfor_input): 4231 handle = pfor_input.unstacked_input(0) 4232 elem, elem_stacked, _ = pfor_input.input(1) 4233 swap_memory = pfor_input.get_attr("swap_memory") 4234 4235 if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): 4236 raise ValueError("StackPushV2 not allowed on stacks created outside pfor") 4237 stack_cache_key = _stack_cache_key(pfor_input) 4238 stacked = _stack_cache.get(stack_cache_key, None) 4239 if stacked is None: 4240 stacked = elem_stacked 4241 _stack_cache[stack_cache_key] = stacked 4242 else: 4243 # If we previously made it unstacked then we can't revert to being stacked. 4244 if not stacked and elem_stacked: 4245 raise ValueError( 4246 "It looks like the stack was previously determined to be loop" 4247 " invariant, but we are now trying to push a loop dependent value" 4248 " to it. This is currently unsupported.") 4249 if stacked and not elem_stacked: 4250 elem = _stack(elem, pfor_input.pfor.loop_len_vector).t 4251 out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) 4252 return wrap(out, stacked) 4253 4254 4255# Note that inputs to this convertor will be unstacked. However it should get 4256# called since it is a stateful op. 4257@RegisterPFor("StackPopV2") 4258def _convert_stack_pop_v2(pfor_input): 4259 handle = pfor_input.unstacked_input(0) 4260 stack_cache_key = _stack_cache_key(pfor_input) 4261 stacked = _stack_cache.get(stack_cache_key, None) 4262 # If a StackPushV2 has not been converted yet, we default to unstacked since 4263 # the push could be outside of pfor, or the convertor may not be called if the 4264 # inputs are unconverted. 
4265 if stacked is None: 4266 stacked = False 4267 _stack_cache[stack_cache_key] = False 4268 elem_type = pfor_input.get_attr("elem_type") 4269 out = data_flow_ops.stack_pop_v2(handle, elem_type) 4270 return wrap(out, stacked) 4271 4272 4273# parsing_ops 4274 4275 4276@RegisterPFor("DecodeCSV") 4277def _convert_decode_csv(pfor_input): 4278 lines = pfor_input.stacked_input(0) 4279 record_defaults = [ 4280 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 4281 ] 4282 field_delim = pfor_input.get_attr("field_delim") 4283 use_quote_delim = pfor_input.get_attr("use_quote_delim") 4284 select_cols = pfor_input.get_attr("select_cols") 4285 if not select_cols: 4286 select_cols = None 4287 return [ 4288 wrap(t, True) for t in parsing_ops.decode_csv( 4289 lines, 4290 record_defaults, 4291 field_delim=field_delim, 4292 use_quote_delim=use_quote_delim, 4293 select_cols=select_cols) 4294 ] 4295 4296 4297@RegisterPFor("ParseSingleExample") 4298def _convert_parse_single_example(pfor_input): 4299 serialized = pfor_input.stacked_input(0) 4300 dense_defaults = [ 4301 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 4302 ] 4303 sparse_keys = pfor_input.get_attr("sparse_keys") 4304 dense_keys = pfor_input.get_attr("dense_keys") 4305 sparse_types = pfor_input.get_attr("sparse_types") 4306 dense_shapes = pfor_input.get_attr("dense_shapes") 4307 output = gen_parsing_ops.parse_example( 4308 serialized=serialized, 4309 names=[], 4310 dense_defaults=dense_defaults, 4311 sparse_keys=sparse_keys, 4312 dense_keys=dense_keys, 4313 sparse_types=sparse_types, 4314 dense_shapes=dense_shapes) 4315 return [wrap(t, True, True) for t in nest.flatten(output)] 4316 4317 4318@RegisterPFor("ParseExampleV2") 4319def _convert_parse_example_v2(pfor_input): 4320 serialized = pfor_input.stacked_input(0) 4321 sparse_keys = pfor_input.unstacked_input(2) 4322 dense_keys = pfor_input.unstacked_input(3) 4323 ragged_keys = pfor_input.unstacked_input(4) 4324 dense_defaults = [ 4325 pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs) 4326 ] 4327 num_sparse = pfor_input.get_attr("num_sparse") 4328 sparse_types = pfor_input.get_attr("sparse_types") 4329 ragged_value_types = pfor_input.get_attr("ragged_value_types") 4330 ragged_split_types = pfor_input.get_attr("ragged_split_types") 4331 dense_shapes = pfor_input.get_attr("dense_shapes") 4332 if serialized.shape.ndims not in (None, 1): 4333 raise ValueError("ParseExampleV2 can only be converted if `serialized` " 4334 "is scalar.") 4335 output = gen_parsing_ops.parse_example_v2( 4336 serialized=serialized, 4337 names=[], 4338 sparse_keys=sparse_keys, 4339 dense_keys=dense_keys, 4340 ragged_keys=ragged_keys, 4341 dense_defaults=dense_defaults, 4342 num_sparse=num_sparse, 4343 sparse_types=sparse_types, 4344 ragged_value_types=ragged_value_types, 4345 ragged_split_types=ragged_split_types, 4346 dense_shapes=dense_shapes) 4347 return [wrap(t, True, True) for t in nest.flatten(output)] 4348 4349 4350# functional_ops 4351 4352 4353def _convert_function_call(func, converter, inputs): 4354 assert isinstance(func.graph, func_graph.FuncGraph), func 4355 assert isinstance(converter, PFor) 4356 4357 # TODO(agarwal): consider caching this function definition. 4358 @def_function.function 4359 def f(*args): 4360 assert all(isinstance(arg, WrappedTensor) for arg in args), args 4361 assert len(args) == len(func.graph.inputs), (args, func.graph.inputs) 4362 # Map inputs to function arguments. 
4363 for inp, arg in zip(func.graph.inputs, args): 4364 converter._add_conversion(inp, arg) 4365 # Convert output tensors. 4366 return tuple( 4367 [converter._convert_helper(x).t for x in func._func_graph_outputs]) 4368 4369 call_outputs = f(*inputs) 4370 assert len(call_outputs) == len(func._func_graph_outputs) 4371 outputs = [] 4372 for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs): 4373 func_output = converter._convert_helper(output_tensor) 4374 outputs.append( 4375 wrap(call_output, func_output.is_stacked, 4376 func_output.is_sparse_stacked)) 4377 return outputs 4378 4379 4380@RegisterPFor("StatefulPartitionedCall") 4381@RegisterPFor("PartitionedCall") 4382def _convert_partitioned_call(pfor_input): 4383 func_name = pfor_input.get_attr("f").name 4384 func = pfor_input.op.graph._get_function(compat.as_bytes(func_name)) 4385 assert isinstance(func.graph, func_graph.FuncGraph), ( 4386 "Could not find FuncGraph object for %s. Got func %s" % (func_name, func)) 4387 pfor = pfor_input.pfor 4388 converter = PFor( 4389 loop_var=pfor.loop_var, 4390 loop_len=pfor.loop_len_vector[0], 4391 pfor_ops=func.graph.get_operations(), 4392 fallback_to_while_loop=pfor.fallback_to_while_loop, 4393 all_indices=pfor.all_indices, 4394 all_indices_partitioned=pfor.all_indices_partitioned, 4395 pfor_config=pfor.pfor_config) 4396 return _convert_function_call(func, converter, pfor_input.inputs) 4397 4398 4399def _partition_inputs_for_indices(inputs, indices): 4400 new_inputs = [] 4401 for inp in inputs: 4402 if inp.is_stacked: 4403 new_inputs.append(wrap(array_ops.gather(inp.t, indices), True)) 4404 else: 4405 new_inputs.append(inp) 4406 return new_inputs 4407 4408 4409def _outputs_for_branch(func_name, indices, pfor_input, inputs): 4410 if indices is None: 4411 indices = pfor_input.pfor.all_indices 4412 partitioned = pfor_input.pfor.all_indices_partitioned 4413 else: 4414 partitioned = True 4415 func = pfor_input.op.graph._get_function(func_name) 4416 converter = PFor( 4417 loop_var=pfor_input.pfor.loop_var, 4418 loop_len=array_ops.size(indices), 4419 pfor_ops=func.graph.get_operations(), 4420 fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop, 4421 all_indices=indices, 4422 all_indices_partitioned=partitioned, 4423 pfor_config=pfor_input.pfor.pfor_config) 4424 outputs = _convert_function_call(func, converter, inputs) 4425 stacked_outputs = [] 4426 for out in outputs: 4427 if not out.is_stacked: 4428 stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t) 4429 else: 4430 stacked_outputs.append(out.t) 4431 return stacked_outputs 4432 4433 4434# TODO(agarwal): Currently the converted code aggressively tiles loop variant 4435# outputs from the then/else branches. Instead, it could do so only if at least 4436# one of the branch outputs is loop variant. 4437@RegisterPFor("StatelessIf") 4438@RegisterPFor("If") 4439def _convert_if(pfor_input): 4440 cond, cond_stacked, _ = pfor_input.input(0) 4441 inputs = pfor_input.inputs[1:] 4442 then_branch = pfor_input.get_attr("then_branch") 4443 else_branch = pfor_input.get_attr("else_branch") 4444 4445 if cond_stacked: 4446 cond_int = math_ops.cast(cond, dtypes.int32) 4447 # Compute loop indices for the different branches 4448 false_indices, true_indices = data_flow_ops.dynamic_partition( 4449 pfor_input.pfor.all_indices, cond_int, 2) 4450 # Compute indices for cond being True or False. 
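    # For example, with cond = [True, False, True], the partition above gives
    # false_indices = [1] and true_indices = [0, 2]; each branch is converted
    # with loop length equal to its partition size, and the branch outputs are
    # merged back in order via dynamic_stitch below.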
4451 if pfor_input.pfor.all_indices_partitioned: 4452 else_indices, then_indices = data_flow_ops.dynamic_partition( 4453 math_ops.range(pfor_input.pfor.loop_len_vector[0]), 4454 cond_int, 2) 4455 else: 4456 else_indices, then_indices = false_indices, true_indices 4457 # Partition inputs 4458 then_inputs = _partition_inputs_for_indices(inputs, then_indices) 4459 else_inputs = _partition_inputs_for_indices(inputs, else_indices) 4460 4461 # Convert "then" branch. 4462 then_outputs = _outputs_for_branch(then_branch.name, true_indices, 4463 pfor_input, then_inputs) 4464 4465 # Convert "else" branch. 4466 else_outputs = _outputs_for_branch(else_branch.name, false_indices, 4467 pfor_input, else_inputs) 4468 4469 assert len(then_outputs) == len(else_outputs) 4470 # Note that if the "then" and "else" branches are updating the same state, 4471 # and possibly reading them as well, it could lead to undefined behavior 4472 # since the ordering of those operations is not well defined. 4473 # One possibility is to order all the "then" branches to execute before all 4474 # the "else" branches so that the side-effects in the former are visible to 4475 # the latter. For now, we leave that as undefined behavior. 4476 outputs = [] 4477 # Merge outputs 4478 for then_output, else_output in zip(then_outputs, else_outputs): 4479 out = data_flow_ops.dynamic_stitch([then_indices, else_indices], 4480 [then_output, else_output]) 4481 outputs.append(wrap(out, True)) 4482 return outputs 4483 else: 4484 outputs = control_flow_ops.cond( 4485 cond, 4486 lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs), 4487 lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs)) 4488 return [wrap(t, True) for t in outputs] 4489 4490 4491class WhileV2(object): 4492 """Object for vectorizing V2 while_loop op.""" 4493 4494 def __init__(self, pfor_input): 4495 self._pfor_input = pfor_input 4496 self._pfor = pfor_input.pfor 4497 cond_func_name = pfor_input.get_attr("cond").name 4498 self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes( 4499 cond_func_name)) 4500 body_func_name = pfor_input.get_attr("body").name 4501 self._body_func = pfor_input.op.graph._get_function(compat.as_bytes( 4502 body_func_name)) 4503 if self._cond_func is None or self._body_func is None: 4504 raise ValueError("Error extracting cond and body functions for op %s." % ( 4505 self._pfor_input.op)) 4506 # Indices of inputs that are passed unchanged through the while loop body. 4507 # Typically these are tensors captured from outside the body context. 4508 self._body_pass_through_indices = set() 4509 for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs, 4510 self._body_func.graph.outputs)): 4511 if id(inp) == id(out): 4512 self._body_pass_through_indices.add(i) 4513 self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations") 4514 4515 def _output_shapes(self): 4516 # Calculate output shape for vectorized loop. This will be used as 4517 # shape_invariant. Merges shape inference outputs with the `output_shapes` 4518 # attribute of the op. 
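    # For example (hypothetical shapes): an unvectorized loop variable of shape
    # [3, 4] that is stacked over the pfor loop gets the invariant [None, 3, 4],
    # since the number of active iterations shrinks as iterations finish.
    # Variants with internal stacking (e.g. TensorLists) keep their merged
    # shape unchanged, as their handles are not given an extra leading
    # dimension.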
4519 output_shapes = [out.shape for out in self._pfor_input.op.outputs] 4520 shapes = self._pfor_input.get_attr("output_shapes") 4521 if not shapes: 4522 shapes = [tensor_shape.TensorShape(None) for _ in output_shapes] 4523 else: 4524 shapes = [tensor_shape.TensorShape(shape) for shape in shapes] 4525 for i, shape in enumerate(shapes): 4526 shape = shape.merge_with(output_shapes[i]) 4527 pfor_input = self._pfor_input.input(i) 4528 if pfor_input.is_stacked: 4529 if _is_variant_with_internal_stacking(pfor_input.t): 4530 shape = tensor_shape.TensorShape([]).concatenate(shape) 4531 else: 4532 shape = tensor_shape.TensorShape([None]).concatenate(shape) 4533 output_shapes[i] = shape 4534 assert len(output_shapes) == self._pfor_input.num_inputs 4535 return output_shapes 4536 4537 def _init_values(self): 4538 """Create arguments passed to converted while_loop.""" 4539 loop_len = self._pfor.loop_len_vector[0] 4540 inputs = [] 4541 # TensorArrays for outputs of converted while loop 4542 output_tas = [] 4543 4544 with ops.name_scope("while_init"): 4545 for inp in self._pfor_input.inputs: 4546 inputs.append(inp.t) 4547 variant_type_id = _variant_type_id(inp.t) 4548 if variant_type_id in _INTERNAL_STACKING_TYPE_IDS: 4549 if variant_type_id != full_type_pb2.TFT_ARRAY: 4550 raise NotImplementedError( 4551 ("While loop conversion is only supported for TensorLists. Got " 4552 "another variant {}, probably an optional. Please file a bug.") 4553 .format(inp.t)) 4554 # For TensorLists, the input format is: 4555 # 4556 # List[user_list_len, Tensor[loop_len, ...]] 4557 # 4558 # rather than the usual 4559 # 4560 # Tensor[loop_len, ...] 4561 # 4562 # The body of the loop will take and return lists in this "internal 4563 # vectorization" format, so we want to keep it that way as much as 4564 # possible. We'll accumulate finished iterations (only relevant for 4565 # pfor-loop-variant while_loop conditions) in an accumulator with 4566 # type: 4567 # 4568 # List[user_list_len, List[loop_len, Tensor[...]]] 4569 # 4570 # This means that each while_loop iteration, we'll iterate over the 4571 # length of the TensorList, dividing done/remaining pfor loop indices 4572 # and scattering the done indices into the inner nested list of the 4573 # accumulator. 4574 element_shape = list_ops.tensor_list_element_shape( 4575 inp.t, dtypes.int32) 4576 if inp.is_stacked: 4577 # Shapes may be tf.constant(-1) for fully dynamic, in which case 4578 # slicing is an error. 4579 element_shape = control_flow_ops.cond( 4580 math_ops.equal(array_ops.rank(element_shape), 0), 4581 lambda: element_shape, 4582 lambda: element_shape[1:]) 4583 dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype 4584 4585 def _init_loop_body(index, output_ta): 4586 output_ta = output_ta.write( 4587 index, 4588 list_ops.tensor_list_reserve(element_shape, loop_len, dtype)) 4589 return index + 1, output_ta 4590 4591 length = list_ops.tensor_list_length(inp.t) 4592 output_ta = tensor_array_ops.TensorArray( 4593 inp.t.dtype, # Variant; this is a nested TensorList 4594 size=length, 4595 dynamic_size=True, 4596 infer_shape=False) 4597 _, output_ta = control_flow_ops.while_loop( 4598 lambda index, _: index < length, 4599 _init_loop_body, 4600 [0, output_ta]) 4601 else: 4602 output_ta = tensor_array_ops.TensorArray( 4603 inp.t.dtype, 4604 size=loop_len, 4605 dynamic_size=False, 4606 infer_shape=True) 4607 output_tas.append(output_ta) 4608 # See documentation for __call__ for the structure of init_values. 
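    # For example (hypothetical): a while_loop with two loop variables yields
    # init_values of the form
    #   [True, indices, inp0, inp1, output_ta0, output_ta1]
    # i.e. the not_all_done flag, the active pfor iteration indices, the
    # (possibly stacked) loop inputs, and one accumulator TensorArray per
    # output.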
4609 indices = ( 4610 math_ops.range(self._pfor.loop_len_vector[0]) 4611 if self._pfor.all_indices_partitioned else self._pfor.all_indices) 4612 return [True, indices] + inputs + output_tas 4613 4614 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): 4615 """Handles case when condition is pfor loop invariant.""" 4616 # Note that all iterations end together. So we don't need to partition the 4617 # inputs. 4618 not_all_done = array_ops.reshape(conditions, []) 4619 return not_all_done, indices, inputs, output_tas 4620 4621 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, 4622 output_tas): 4623 """Handles case when condition is pfor loop dependent.""" 4624 # Compute if all iterations are done. 4625 not_all_done = math_ops.reduce_any(conditions) 4626 conditions_int = math_ops.cast(conditions, dtypes.int32) 4627 # Partition the indices. 4628 done_indices, new_indices = data_flow_ops.dynamic_partition( 4629 indices, conditions_int, 2) 4630 4631 new_inputs = [] 4632 new_output_tas = [] 4633 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): 4634 pass_through = i in self._body_pass_through_indices 4635 if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: 4636 shape_and_type = _parse_variant_shapes_and_types(inp)[0] 4637 element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32) 4638 user_list_len = list_ops.tensor_list_length(inp) 4639 4640 def _split_vectorized_ta_element(index, new_inp, new_out_ta): 4641 elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype, 4642 element_shape) 4643 if stacked: 4644 done_elem, new_elem = data_flow_ops.dynamic_partition( 4645 elem, conditions_int, 2) 4646 new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem) 4647 else: 4648 done_elem = _stack(elem, [array_ops.size(done_indices)]).t 4649 done_accum = new_out_ta.read(index) 4650 done_accum = list_ops.tensor_list_scatter( 4651 tensor=done_elem, indices=done_indices, input_handle=done_accum) 4652 new_out_ta = new_out_ta.write(index, done_accum) 4653 return index + 1, new_inp, new_out_ta 4654 4655 length = list_ops.tensor_list_length(inp) 4656 new_inp = list_ops.tensor_list_reserve( 4657 tensor_shape.TensorShape([None]) 4658 + tensor_shape.TensorShape(shape_and_type.shape)[1:], 4659 user_list_len, shape_and_type.dtype) 4660 _, new_inp, out_ta = control_flow_ops.while_loop( 4661 lambda index, unused_new_inp, unused_new_out_ta: index < length, 4662 _split_vectorized_ta_element, 4663 [0, new_inp, output_tas[i]]) 4664 else: 4665 # Partition the inputs. 4666 if stacked: 4667 done_inp, new_inp = data_flow_ops.dynamic_partition( 4668 inp, conditions_int, 2) 4669 else: 4670 if not pass_through: 4671 done_inp = _stack(inp, [array_ops.size(done_indices)]).t 4672 new_inp = inp 4673 4674 out_ta = output_tas[i] 4675 if not pass_through: 4676 # Note that done_indices can be empty. done_inp should also be empty 4677 # in that case. 4678 out_ta = out_ta.scatter(done_indices, done_inp) 4679 new_inputs.append(new_inp) 4680 new_output_tas.append(out_ta) 4681 4682 assert len(new_output_tas) == len(output_tas) 4683 assert len(new_inputs) == len(inputs) 4684 return not_all_done, new_indices, new_inputs, new_output_tas 4685 4686 def _process_body(self, inputs_stacked, new_indices, cond_stacked, 4687 new_inputs, not_all_done): 4688 """Convert the body function.""" 4689 # This is used to store the indices of inputs to the while op that need to 4690 # be stacked. 
This stacking may be needed in cases where the input to the 4691 # while_loop is loop_invariant but the corresponding output is not. 4692 mismatching_stacked_indices = [] 4693 4694 def true_fn(): 4695 """Converts the body function for all but last iteration.""" 4696 wrapped_inputs = [wrap(inp, stacked) for inp, stacked in 4697 zip(new_inputs, inputs_stacked)] 4698 # Note the iterative process below to figure out loop invariance. 4699 # Here we iterate on vectorization process till a fixed point. The issue 4700 # is that the while body can take pfor loop invariant inputs but return 4701 # loop variant outputs. For any loop variant output, the corresponding 4702 # input has to be then made loop variant (since subsequent while 4703 # iterations will need to see loop variant values). 4704 # However once we make a new input loop variant, we might make other 4705 # outputs loop variant. Hence we need to iterate till we get fixed point. 4706 while True: 4707 if self._pfor.all_indices_partitioned: 4708 indices = array_ops.gather(self._pfor.all_indices, new_indices) 4709 else: 4710 indices = new_indices 4711 body_pfor = PFor( 4712 loop_var=self._pfor.loop_var, 4713 loop_len=array_ops.size(new_indices), 4714 pfor_ops=self._body_func.graph.get_operations(), 4715 fallback_to_while_loop=self._pfor.fallback_to_while_loop, 4716 all_indices=indices, 4717 all_indices_partitioned=(self._pfor.all_indices_partitioned or 4718 cond_stacked), 4719 pfor_config=self._pfor.pfor_config) 4720 stacking_mismatch = False 4721 outputs = _convert_function_call(self._body_func, 4722 body_pfor, 4723 wrapped_inputs) 4724 for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)): 4725 if out.is_stacked != inp.is_stacked: 4726 stacking_mismatch = True 4727 mismatching_stacked_indices.append(i) 4728 stacked = _stack(inp.t, [array_ops.size(new_indices)]) 4729 if inp.t.dtype == dtypes.variant: 4730 stacked = wrap( 4731 _tile_variant_with_length(stacked.t, 4732 [array_ops.size(new_indices)])) 4733 wrapped_inputs[i] = stacked 4734 if not stacking_mismatch: 4735 if mismatching_stacked_indices: 4736 # We needed to stack some inputs. This code will be abandoned and 4737 # should not get executed. Hence we simply return `new_inputs` to 4738 # make sure the graph construction code completes. 4739 with ops.control_dependencies([ 4740 control_flow_ops.Assert( 4741 False, ["pfor ERROR: this branch should never execute"])]): 4742 return [array_ops.identity(x) for x in new_inputs] 4743 else: 4744 return [out.t for out in outputs] 4745 4746 # If all are done, we simply return `new_inputs`. Else we need to run the 4747 # body function. 4748 return control_flow_ops.cond( 4749 not_all_done, 4750 true_fn, 4751 lambda: list(new_inputs)), mismatching_stacked_indices 4752 4753 def __call__(self): 4754 """Converter for the V2 while_loop. 4755 4756 The conversion of a while_loop is another while_loop. 4757 4758 The arguments to this converted while_loop are as follows: 4759 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations 4760 are done. 4761 indices: int32 1-D Tensor storing the id of the pfor iterations that are not 4762 done. 4763 args: Remaining arguments. These can be divided into 2 categories: 4764 - The first set of arguments correspond one-to-one to the inputs to the 4765 unvectorized while_loop. 4766 - The second set are TensorArrays, corresponding one-to-one to each output 4767 of the unvectorized while_loop. Each TensorArray has `PFor.loop_len` 4768 elements, i.e. the number of pfor iterations. 
At the end, the i'th 4769 element of each TensorArray will contain the output computed by the i'th 4770 iteration of pfor. Note that elements can be written into these tensors 4771 arrays in any order, depending on when the corresponding pfor iteration 4772 is done. 4773 In each iteration, the while_loop body recomputes the condition for all 4774 active pfor iterations to see which of them are now done. It then partitions 4775 all the inputs and passes them along to the converted body. Values for all 4776 the iterations that are done are written to TensorArrays indexed by the pfor 4777 iteration number. When all iterations are done, the TensorArrays are stacked 4778 to get the final value. 4779 4780 Returns: 4781 List of converted outputs. 4782 """ 4783 output_shapes = self._output_shapes() 4784 # Note that we use these lists as a hack since we need the `body` to compute 4785 # these values during construction of the while_loop graph. 4786 cond_is_stacked = [None] 4787 indices_to_stack = [] 4788 4789 def cond(not_all_done, *_): 4790 return not_all_done 4791 4792 def body(not_all_done, indices, *args): 4793 # See documentation for __call__ for the structure of *args. 4794 num_inputs = self._pfor_input.num_inputs 4795 inputs = args[:num_inputs] 4796 output_tas = args[num_inputs:] 4797 inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs] 4798 assert len(inputs) >= len(output_tas) 4799 assert len(inputs) == len(inputs_stacked) 4800 # Convert condition 4801 with ops.name_scope("while_cond"): 4802 # Note that we set all_indices_partitioned to True here. At this point 4803 # we don't know if indices will be partitioned. Hence we use the 4804 # conservative value. 4805 cond_pfor = PFor( 4806 loop_var=self._pfor.loop_var, 4807 loop_len=array_ops.size(indices), 4808 pfor_ops=self._cond_func.graph.get_operations(), 4809 fallback_to_while_loop=self._pfor.fallback_to_while_loop, 4810 all_indices=indices, 4811 all_indices_partitioned=True, 4812 pfor_config=self._pfor.pfor_config) 4813 4814 wrapped_inputs = [wrap(inp, stacked) for inp, stacked 4815 in zip(inputs, inputs_stacked)] 4816 conditions, cond_stacked, _ = _convert_function_call( 4817 self._cond_func, 4818 cond_pfor, 4819 wrapped_inputs)[0] 4820 cond_is_stacked[0] = cond_stacked 4821 4822 # Recompute the new condition, write outputs of done iterations, and 4823 # partition the inputs if needed. 4824 if not cond_stacked: 4825 (not_all_done, new_indices, new_inputs, 4826 new_output_tas) = self._process_cond_unstacked(conditions, indices, 4827 inputs, output_tas) 4828 else: 4829 (not_all_done, new_indices, new_inputs, 4830 new_output_tas) = self._process_cond_stacked(conditions, indices, 4831 inputs, inputs_stacked, 4832 output_tas) 4833 # Convert body 4834 with ops.name_scope("while_body"): 4835 # Compute the outputs from the body. 4836 new_outputs, mismatching_stacked_indices = self._process_body( 4837 inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done) 4838 4839 indices_to_stack[:] = mismatching_stacked_indices 4840 for i, new_output in enumerate(new_outputs): 4841 new_output.set_shape(output_shapes[i]) 4842 new_args = ([not_all_done, new_indices] + new_outputs + 4843 list(new_output_tas)) 4844 return tuple(new_args) 4845 4846 # Note that we run the code below in a function since we might abandon the 4847 # generated code in cases where the conversion dictates that some inputs be 4848 # further stacked. 
Hence we run the graph construction using 4849 # `get_concrete_function` and avoid calling the constructed function if not 4850 # needed. 4851 @def_function.function 4852 def while_fn(): 4853 # Create init_values that will be passed to the while_loop. 4854 init_values = self._init_values() 4855 ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in 4856 self._pfor_input.outputs] 4857 shape_invariants = ( 4858 [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])] 4859 + output_shapes + ta_shape_invariants) 4860 4861 while_outputs = control_flow_ops.while_loop( 4862 cond, body, init_values, 4863 shape_invariants=shape_invariants, 4864 parallel_iterations=self._parallel_iterations) 4865 if indices_to_stack: 4866 # This function will be abandoned. 4867 return while_outputs 4868 else: 4869 num_inputs = self._pfor_input.num_inputs 4870 new_inputs = while_outputs[2:num_inputs+2] 4871 output_tas = while_outputs[num_inputs+2:] 4872 assert cond_is_stacked[0] is not None 4873 outputs = [] 4874 for i, inp in enumerate(new_inputs): 4875 if cond_is_stacked[0]: 4876 if i in self._body_pass_through_indices: 4877 outputs.append(init_values[i + 2]) 4878 else: 4879 ta = output_tas[i] 4880 if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: 4881 shape_and_type = _parse_variant_shapes_and_types(inp)[0] 4882 length = list_ops.tensor_list_length(inp) 4883 4884 # We have been accumulating values in a: 4885 # 4886 # List[user_list_len, List[loop_len, Tensor[...]]] 4887 # 4888 # We want to return an output in the same format as the input: 4889 # 4890 # List[user_list_len, Tensor[loop_len, ...]] 4891 # 4892 # So we need to loop over the list and stack its contents. 4893 def _stack_loop_body(index, output_list): 4894 current_value = ta.read(index) 4895 output_list = list_ops.tensor_list_set_item( 4896 output_list, index, 4897 list_ops.tensor_list_stack( 4898 current_value, shape_and_type.dtype)) 4899 return index + 1, output_list 4900 4901 output_list = list_ops.tensor_list_reserve( 4902 tensor_shape.TensorShape(shape_and_type.shape), length, 4903 shape_and_type.dtype) 4904 _, output_list = control_flow_ops.while_loop( 4905 lambda index, _: index < length, 4906 _stack_loop_body, 4907 [0, output_list]) 4908 outputs.append(output_list) 4909 else: 4910 outputs.append(ta.stack()) 4911 else: 4912 outputs.append(inp) 4913 return outputs 4914 4915 _ = while_fn.get_concrete_function() 4916 if indices_to_stack: 4917 # Need to abandon the current conversion, stack some inputs and restart. 4918 self._pfor_input.stack_inputs( 4919 stack_indices=indices_to_stack, tile_variants=True) 4920 # Note that this call will recurse at most one time. The first call will 4921 # do the required stacking, based on the iterative procedure in 4922 # _process_body, and the next invocation to __call__ should not need to do 4923 # any more stacking. 4924 # We invoke `self()` here as a way to discard any corrupted state. 
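      # For example (hypothetical): if the first pass found that input 1 was
      # loop invariant while its corresponding output was loop variant, then
      # indices_to_stack == [1]; `stack_inputs` above has already tiled that
      # input across the loop dimension, so the second pass below converts the
      # while_loop with consistent stacking and returns its outputs.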
4925 return self() 4926 else: 4927 outputs = while_fn() 4928 wrapped_outputs = [] 4929 for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)): 4930 if i not in self._body_pass_through_indices and cond_is_stacked[0]: 4931 wrapped_outputs.append(wrap(out, True)) 4932 else: 4933 wrapped_outputs.append(wrap(out, inp.is_stacked)) 4934 return wrapped_outputs 4935 4936 4937@RegisterPFor("StatelessWhile") 4938@RegisterPFor("While") 4939def _convert_while(pfor_input): 4940 converter = WhileV2(pfor_input) 4941 return converter() 4942 4943 4944# spectral_ops 4945 4946 4947@RegisterPForWithArgs("FFT", gen_spectral_ops.fft) 4948@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d) 4949@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d) 4950@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft) 4951@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d) 4952@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d) 4953def _convert_fft(pfor_input, _, op_func): 4954 return wrap(op_func(pfor_input.stacked_input(0)), True) 4955 4956 4957@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex") 4958@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex") 4959@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex") 4960@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal") 4961@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal") 4962@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal") 4963def _convert_rfft(pfor_input, _, op_func, attr_name): 4964 inp = pfor_input.stacked_input(0) 4965 fft_length = pfor_input.unstacked_input(1) 4966 attr = pfor_input.get_attr(attr_name) 4967 return wrap(op_func(inp, fft_length, attr), True) 4968
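# The FFT family vectorizes trivially: these ops transform the innermost
# dimension(s), so the stacked input, which only adds a leading loop dimension,
# can be passed through unchanged. As a hedged usage sketch (not part of this
# module), vectorizing tf.signal.rfft over a batch:
#
#   x = tf.random.normal([8, 128])   # 8 pfor iterations, each a 128-sample frame
#   y = tf.vectorized_map(lambda v: tf.signal.rfft(v, fft_length=[64]), x)
#   # y has shape [8, 33]; the conversion above emits a single rfft on the
#   # stacked input instead of 8 separate ops.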