1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Compiled parallel-for loop.""" 16# pylint: disable=missing-docstring,g-direct-tensorflow-import 17 18import collections 19import string 20import sys 21import traceback 22 23import numpy as np 24from functools import partial 25 26from tensorflow.compiler.tf2xla.python import xla 27from tensorflow.core.framework import full_type_pb2 28from tensorflow.python.eager import context 29from tensorflow.python.eager import def_function 30from tensorflow.python.eager import execute 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import func_graph 34from tensorflow.python.framework import ops 35from tensorflow.python.framework import smart_cond 36from tensorflow.python.framework import sparse_tensor 37from tensorflow.python.framework import tensor_shape 38from tensorflow.python.framework import tensor_spec 39from tensorflow.python.framework import tensor_util 40from tensorflow.python.ops import array_ops 41from tensorflow.python.ops import control_flow_ops 42from tensorflow.python.ops import data_flow_ops 43from tensorflow.python.ops import gen_array_ops 44from tensorflow.python.ops import gen_dataset_ops 45from tensorflow.python.ops import gen_image_ops 46from tensorflow.python.ops import gen_linalg_ops 47from tensorflow.python.ops import gen_list_ops 48from tensorflow.python.ops import gen_math_ops 49from tensorflow.python.ops import gen_nn_ops 50from tensorflow.python.ops import gen_parsing_ops 51from tensorflow.python.ops import gen_random_ops 52from tensorflow.python.ops import gen_sparse_ops 53from tensorflow.python.ops import gen_spectral_ops 54from tensorflow.python.ops import handle_data_util 55from tensorflow.python.ops import linalg_ops 56from tensorflow.python.ops import list_ops 57from tensorflow.python.ops import manip_ops 58from tensorflow.python.ops import map_fn 59from tensorflow.python.ops import math_ops 60from tensorflow.python.ops import nn_ops 61from tensorflow.python.ops import parsing_ops 62from tensorflow.python.ops import resource_variable_ops 63from tensorflow.python.ops import sparse_ops 64from tensorflow.python.ops import special_math_ops 65from tensorflow.python.ops import tensor_array_ops 66from tensorflow.python.platform import flags 67from tensorflow.python.platform import tf_logging as logging 68from tensorflow.python.util import compat 69from tensorflow.python.util import nest 70from tensorflow.python.util import object_identity 71 72 73# TODO(agarwal): remove flag. 74flags.DEFINE_bool( 75 "op_conversion_fallback_to_while_loop", True, 76 "DEPRECATED: Flag is ignored.") 77 78 79def _variant_handle_data(t): 80 """Fetches handle data for a variant tensor `t`, or None if unavailable.""" 81 handle_data = resource_variable_ops.get_eager_safe_handle_data(t) 82 if not handle_data.is_set: 83 return None 84 return handle_data.shape_and_type 85 86 87def _variant_type_id(t): 88 """Returns the full_type_pb2 type of `t`, or None if it is not available.""" 89 if t.dtype != dtypes.variant: 90 return None 91 shapes_and_types = _variant_handle_data(t) 92 if shapes_and_types is None or not shapes_and_types: 93 # TODO(b/169968286): Identify all variant tensors (e.g. maps) and we can 94 # make this an error instead of assuming TensorLists have handle data. 95 return None # Presumed not a TensorList/Optional 96 return shapes_and_types[0].type.type_id 97 98 99_INTERNAL_STACKING_TYPE_IDS = ( 100 full_type_pb2.TFT_ARRAY, 101 full_type_pb2.TFT_OPTIONAL) 102 103 104def _is_variant_with_internal_stacking(t): 105 """Identifies variant tensors which pfor always maintains as scalars. 106 107 For these, the pfor tensor is recorded as "stacked" if the content of the 108 variant tensor (e.g. the elements of a TensorList) are all stacked. 109 110 Args: 111 t: A tensor to identify. 112 Returns: 113 True if `t` is a TensorList/Optional, False not, None if unknown. 114 """ 115 type_id = _variant_type_id(t) 116 return type_id in _INTERNAL_STACKING_TYPE_IDS 117 118 119def _parse_variant_shapes_and_types(t): 120 """Extracts shape and dtype information from a variant tensor `t`.""" 121 shapes_and_types = _variant_handle_data(t) 122 if shapes_and_types is None or not shapes_and_types: 123 raise ValueError("Required handle data not set for {!r}".format(t)) 124 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY: 125 return shapes_and_types 126 else: 127 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_UNSET: 128 return shapes_and_types 129 else: 130 raise ValueError( 131 "Attempted to stack a variant-dtype tensor with no type set ({!r})" 132 .format(t)) 133 134 135def _stack(t, length): 136 """stacks `t` `length` times.""" 137 # Note that this stacking may currently be triggered, for example, when a 138 # loop invariant tensor with dtype variant is input to a while_loop which then 139 # produces a loop dependent output. Simply stacking the variants may not be 140 # suitable since operations on stacked handles may expect a vectorized version 141 # of the variant. 142 if t.dtype == dtypes.variant: 143 shapes_and_types = _parse_variant_shapes_and_types(t) 144 if shapes_and_types[0].type.type_id == full_type_pb2.TFT_ARRAY: 145 if len(shapes_and_types) != 1: 146 raise ValueError( 147 f"Expected handle data of length 1, got {shapes_and_types!r} of " 148 f"length {len(shapes_and_types)}.") 149 return wrap( 150 _stack_tensor_list(t, shapes_and_types[0].dtype, length), 151 True) 152 else: 153 raise ValueError( 154 "Attempted to stack an unhandled variant-dtype tensor of " 155 f"type {shapes_and_types[0].type!r} ({t!r}).") 156 ones = array_ops.ones_like(array_ops.shape(t)) 157 ones = array_ops.reshape(ones, [-1]) 158 length = array_ops.reshape(length, [-1]) 159 multiples = array_ops.concat([length, ones], 0) 160 t = array_ops.tile(array_ops.expand_dims(t, 0), multiples) 161 return wrap(t, True) 162 163 164# The following stateful ops can be safely called once, and with the same 165# signature as the unconverted version, if their inputs are loop invariant. 166# TODO(agarwal): implement a strategy for converting Variable reads/writes. The 167# plan is to map each read/write in the loop_fn to a corresponding merged 168# read/write in the converted graph. Writes need to be mergeable (e.g. 169# AssignAdd) to be used in `pfor`. Given a certain read/write order in the 170# loop_fn, doing a one-to-one conversion will simulate executing such 171# instructions in lock-step across all iterations. 172passthrough_stateful_ops = set([ 173 "VariableV2", 174 "VarHandleOp", 175 "VariableShape", 176 "ReadVariableOp", 177 "StackV2", 178 "TensorArrayWriteV3", 179 "TensorArrayReadV3", 180 "TensorArraySizeV3", 181]) 182 183 184# Ops which we will treat like stateful for the purpose of vectorization. 185# Typically this is used to force pfor converters to run for these ops. 186force_stateful_ops = set([ 187 # We vectorize this since we need to change the element shape set on the 188 # list. 189 "TensorListReserve", 190]) 191 192 193def _is_stateful_pfor_op(op): 194 if isinstance(op, WhileOp): 195 return op.is_stateful 196 if op.type == "Const": 197 # Const didn't have an op_def. 198 return False 199 if op.type in passthrough_stateful_ops: 200 return False 201 if op.type in force_stateful_ops: 202 return True 203 assert hasattr(op, "op_def") and op.op_def is not None, op 204 return op.op_def.is_stateful 205 206 207# pylint: disable=protected-access 208class WhileOp: 209 """Object for storing state for converting the outputs of a while_loop.""" 210 211 def __init__(self, exit_node, pfor_ops, fallback_to_while_loop, pfor_config): 212 """Initializer. 213 214 Args: 215 exit_node: A tensor output from the while_loop. 216 pfor_ops: list of ops inside the current pfor loop. 217 fallback_to_while_loop: If True, fallback to while loop when conversion of 218 an op is not supported 219 pfor_config: PForConfig object used while constructing loop body. 220 """ 221 self._fallback_to_while_loop = fallback_to_while_loop 222 self._pfor_config = pfor_config 223 self._pfor_ops = set(pfor_ops) 224 self._pfor_op_ids = set(x._id for x in pfor_ops) 225 assert isinstance(exit_node, ops.Tensor) 226 self._while_context = exit_node.op._get_control_flow_context() 227 assert isinstance(self._while_context, control_flow_ops.WhileContext) 228 self._context_name = self._while_context.name 229 self._condition = self._while_context.pivot.op.inputs[0] 230 # Parts of an external while_loop could be created inside a pfor loop. 231 # However for the purpose here, we declare such loops to be external. Also 232 # note that we check if the condition was created inside or outside to 233 # determine if the while_loop was first created inside or outside. 234 # TODO(agarwal): check that the Enter and Exit of this loop are unstacked. 235 self._is_inside_loop = self.op_is_inside_loop(self._condition.op) 236 if self._is_inside_loop: 237 for e in self._while_context.loop_exits: 238 assert self.op_is_inside_loop(e.op) 239 240 # Note the code below tries to reverse engineer an existing while_loop graph 241 # by assuming the following pattern of nodes. 242 # 243 # NextIteration <---- Body <--- Enter 244 # | ^ 245 # V ___| Y 246 # Enter -> Merge -> Switch___ 247 # ^ | N 248 # | V 249 # LoopCond Exit 250 251 # Node that elements in the list below correspond one-to-one with each 252 # other. i.e. these lists are the same size, and the i_th entry corresponds 253 # to different Operations/Tensors of a single cycle as illustrated above. 254 # List of Switch ops (ops.Operation) that feed into an Exit Node. 255 self._exit_switches = [] 256 # List of inputs (ops.Tensor) to NextIteration. 257 self._body_outputs = [] 258 # List of list of control inputs of the NextIteration nodes. 259 self._next_iter_control_inputs = [] 260 # List of Merge ops (ops.Operation). 261 self._enter_merges = [] 262 # List of output (ops.Tensor) of Exit nodes. 263 self._outputs = [] 264 265 # List of Enter Tensors. 266 # There are two types of Enter nodes: 267 # - The Enter nodes that are used in the `loop_vars` argument to 268 # `while_loop` (see 269 # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect 270 # these Enter nodes immediately below by tracing backwards from the Exit 271 # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the 272 # diagram above. This allows us to have a 1:1 correspondence between the 273 # self._outputs and the first elements in self._enters. 274 # - The Enter nodes that are used only by the body. They don't appear in the 275 # `loop_vars` and are not returned from the `while_loop`. In Python code, 276 # they are usually captured by the body lambda. We collect them below by 277 # iterating over all the ops in the graph. They are appended to the end of 278 # self._enters or self._direct_enters, and don't correspond to any outputs 279 # in self._outputs. Note that we keep the resource/variant Enter nodes in 280 # self._direct_enters and the constructed while_loop's body uses them 281 # directly as opposed to passing them as loop variables. This is done 282 # because the while_body cannot partition the resource/variant Tensors, so 283 # it has to leave them unchanged. 284 self._enters = [] 285 self._direct_enters = [] 286 287 for e in self._while_context.loop_exits: 288 self._outputs.append(e.op.outputs[0]) 289 switch = e.op.inputs[0].op 290 assert switch.type == "Switch", switch 291 self._exit_switches.append(switch) 292 merge = switch.inputs[0].op 293 assert merge.type == "Merge", merge 294 self._enter_merges.append(merge) 295 enter = merge.inputs[0].op 296 assert enter.type == "Enter", enter 297 self._enters.append(enter.outputs[0]) 298 next_iter = merge.inputs[1].op 299 assert next_iter.type == "NextIteration", next_iter 300 self._body_outputs.append(next_iter.inputs[0]) 301 self._next_iter_control_inputs.append(next_iter.control_inputs) 302 303 # Collect all the Enter nodes that are not part of `loop_vars`, the second 304 # category described above. 305 # Also track whether the loop body has any stateful ops. 306 self._is_stateful = False 307 for op in ops.get_default_graph().get_operations(): 308 # TODO(agarwal): make sure this works with nested case. 309 control_flow_context = op._get_control_flow_context() 310 if control_flow_context is None: 311 continue 312 if control_flow_context.name == self._context_name: 313 self._is_stateful |= _is_stateful_pfor_op(op) 314 if op.type == "Enter": 315 output = op.outputs[0] 316 if output not in self._enters: 317 if output.dtype in (dtypes.resource, dtypes.variant): 318 if output not in self._direct_enters: 319 self._direct_enters.append(output) 320 else: 321 self._enters.append(output) 322 323 def __str__(self): 324 """String representation.""" 325 return "while_loop(%s)" % self.name 326 327 @property 328 def inputs(self): 329 """Input to all the Enter nodes.""" 330 return [x.op.inputs[0] for x in self._enters + self._direct_enters] 331 332 @property 333 def control_inputs(self): 334 """Control input to all the Enter nodes.""" 335 control_inputs = [] 336 for x in self._enters + self._direct_enters: 337 control_inputs.extend(x.op.control_inputs) 338 return control_inputs 339 340 @property 341 def outputs(self): 342 """Outputs of all the Exit nodes.""" 343 return self._outputs 344 345 @property 346 def name(self): 347 """Context name for the while loop.""" 348 return self._context_name 349 350 @property 351 def is_inside_loop(self): 352 """Returns true if the while_loop was created inside the pfor.""" 353 return self._is_inside_loop 354 355 def op_is_inside_loop(self, op): 356 """True if op was created inside the pfor loop body.""" 357 assert isinstance(op, ops.Operation) 358 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops 359 # since it appears there tensorflow API could return different python 360 # objects representing the same Operation node. 361 return op._id in self._pfor_op_ids 362 363 @property 364 def is_stateful(self): 365 return self._is_stateful 366 367 @property 368 def pfor_converter(self): 369 """Return a converter for the while loop.""" 370 return self 371 372 def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs, 373 inputs_stacked): 374 """Create a PFor object for converting parts of the while_loop. 375 376 Args: 377 parent_pfor: PFor object being used for converting the while_loop. 378 indices: int32 Tensor of ids for the iterations that are still active 379 (i.e. did not exit the while_loop). 380 cond_stacked: True if the while_loop condition is stacked. 381 inputs: list of input Tensors corresponding 1-to-1 with self._enters. Note 382 that these Tensors are a subset of the loop variables for the generated 383 while_loop. 384 inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`, 385 indicating if the value is stacked or not. 386 387 Returns: 388 A PFor instance. The instance is initialized by adding conversion mappings 389 of nodes that will be external to the conversion that the returned 390 instance will be used for. e.g. Enter nodes as well as Merge and Switch 391 outputs are mapped to converted values. 392 """ 393 num_outputs = len(self._outputs) 394 assert len(inputs) == len(self._enters) 395 assert len(inputs_stacked) == len(self._enters) 396 loop_var = parent_pfor.loop_var 397 loop_len = array_ops.size(indices) 398 pfor = PFor( 399 loop_var, 400 loop_len, 401 pfor_ops=self._pfor_ops, 402 all_indices=indices, 403 all_indices_partitioned=cond_stacked, 404 fallback_to_while_loop=self._fallback_to_while_loop, 405 pfor_config=self._pfor_config) 406 # Map all inputs of Enter nodes in self._direct_enters to their converted 407 # values. 408 for enter in self._direct_enters: 409 enter_input = enter.op.inputs[0] 410 converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper( 411 enter_input) 412 # Since these are resources / variants, they should be unstacked. 413 assert not stacked and not is_sparse_stacked, (enter, converted_enter) 414 pfor._add_conversion(enter, wrap(converted_enter, False)) 415 416 # Map all Enter nodes to the inputs. 417 for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked): 418 pfor._add_conversion(enter, wrap(inp, stacked)) 419 # Map outputs of Switch and Merge. 420 for i in range(num_outputs): 421 wrapped_inp = wrap(inputs[i], inputs_stacked[i]) 422 merge = self._enter_merges[i] 423 pfor._add_conversion(merge.outputs[0], wrapped_inp) 424 # Note that second output of Merge is typically not used, except possibly 425 # as a control dependency. To avoid trying to output the correct value, we 426 # employ a hack here. We output a dummy invalid value with an incorrect 427 # dtype. This will allow control dependency to work but if using it as an 428 # input, it should typically lead to errors during graph construction due 429 # to dtype mismatch. 430 # TODO(agarwal): Check in the original graph to see if there are any 431 # consumers of this Tensor that use it as an input. 432 pfor._add_conversion(merge.outputs[1], 433 wrap(constant_op.constant(-1.0), False)) 434 switch = self._exit_switches[i] 435 # Don't need to worry about switch.output[0] which will feed to Exit node. 436 pfor._add_conversion(switch.outputs[1], wrapped_inp) 437 return pfor 438 439 def _convert_enter(self, parent_pfor, enter): 440 """Converts an Enter node.""" 441 inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0]) 442 control_inputs = [] 443 for x in enter.op.control_inputs: 444 converted = parent_pfor._convert_helper(x) 445 if not isinstance(converted, ops.Operation): 446 converted = converted.t 447 control_inputs.append(converted) 448 if control_inputs: 449 with ops.control_dependencies(control_inputs): 450 inp = array_ops.identity(inp) 451 return inp, stacked 452 453 def _maybe_stacked(self, cache, inp): 454 """Heuristic to figure out if the converting inp leads to a stacked value. 455 456 457 Args: 458 cache: map from Tensor to boolean indicating stacked/unstacked. 459 inp: input Tensor. 460 461 Returns: 462 True if `inp` could get stacked. If the function returns False, the 463 converted value should be guaranteed to be unstacked. If returning True, 464 it may or may not be stacked. 465 """ 466 if inp in cache: 467 return cache[inp] 468 if not self.op_is_inside_loop(inp.op): 469 return False 470 op = inp.op 471 output = False 472 if op.type in [ 473 "Shape", 474 "Rank", 475 "ShapeN", 476 "ZerosLike", 477 "TensorArrayV3", 478 "TensorArraySizeV3", 479 ]: 480 output = False 481 elif _is_stateful_pfor_op(op): 482 # This may be fairly aggressive. 483 output = True 484 elif op.type == "Exit": 485 # This may be fairly aggressive. 486 output = True 487 else: 488 for t in op.inputs: 489 if self._maybe_stacked(cache, t): 490 output = True 491 break 492 cache[inp] = output 493 return output 494 495 def _create_init_values(self, pfor_input): 496 """Create arguments passed to converted while_loop.""" 497 with ops.name_scope("while_init"): 498 loop_len_vector = pfor_input.pfor.loop_len_vector 499 loop_len = loop_len_vector[0] 500 num_outputs = len(self._outputs) 501 502 inputs = [] 503 maybe_stacked_cache = {} 504 # Convert all the Enters. Need to do this before checking for stacking 505 # below. 506 for i, enter in enumerate(self._enters): 507 inp, stacked = self._convert_enter(pfor_input.pfor, enter) 508 inputs.append(inp) 509 maybe_stacked_cache[enter] = stacked 510 # Since this enter node is part of the `loop_vars`, it corresponds to an 511 # output and its preceding switch. We mark this switch's output the same 512 # stackness, to act at the base case for the logic below. Below, we will 513 # be going through the body figuring out which inputs might need to be 514 # stacked and which inputs can safely remain unstacked. 515 if i < num_outputs: 516 maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked 517 518 # Shape invariants for init_values corresponding to self._enters. 519 input_shape_invariants = [] 520 # TensorArrays for outputs of converted while loop 521 output_tas = [] 522 # Shape invariants for output TensorArrays. 523 ta_shape_invariants = [] 524 # List of booleans indicating stackness of inputs, i.e. tensors 525 # corresponding to self._enters. 526 inputs_stacked = [] 527 for i, inp in enumerate(inputs): 528 enter = self._enters[i] 529 inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter) 530 # Note that even when an input is unstacked, the body could make it 531 # stacked. we use a heuristic below to figure out if body may be making 532 # it stacked. 533 if i < num_outputs: 534 body_output = self._body_outputs[i] 535 if enter.op in self._pfor_ops: 536 body_output_stacked = self._maybe_stacked(maybe_stacked_cache, 537 body_output) 538 else: 539 # If constructed outside of pfor loop, then the output would not be 540 # stacked. 541 body_output_stacked = False 542 if body_output_stacked and not inp_stacked: 543 inp = _stack(inp, loop_len_vector).t 544 inputs[i] = inp 545 inp_stacked = True 546 # TODO(agarwal): other attributes for the TensorArray ? 547 output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len)) 548 ta_shape_invariants.append(tensor_shape.TensorShape(None)) 549 550 inputs_stacked.append(inp_stacked) 551 input_shape_invariants.append(tensor_shape.TensorShape(None)) 552 553 # See documentation for __call__ for the structure of init_values. 554 init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas 555 # TODO(agarwal): try stricter shape invariants 556 shape_invariants = ( 557 [tensor_shape.TensorShape(None), 558 tensor_shape.TensorShape(None)] + input_shape_invariants + 559 ta_shape_invariants) 560 561 return init_values, inputs_stacked, shape_invariants 562 563 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): 564 """Handles case when condition is unstacked. 565 566 Note that all iterations end together. So we don't need to partition the 567 inputs. When all iterations are done, we write the inputs to the 568 TensorArrays. Note that we only write to index 0 of output_tas. Since all 569 iterations end together, they can all be output together. 570 """ 571 not_all_done = array_ops.reshape(conditions, []) 572 new_output_tas = [] 573 # pylint: disable=cell-var-from-loop 574 for i, out_ta in enumerate(output_tas): 575 inp = inputs[i] 576 new_output_tas.append( 577 control_flow_ops.cond(not_all_done, lambda: out_ta, 578 lambda: out_ta.write(0, inp))) 579 # pylint: enable=cell-var-from-loop 580 return not_all_done, indices, inputs, new_output_tas 581 582 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, 583 output_tas): 584 num_outputs = len(self._outputs) 585 # Compute if all iterations are done. 586 not_all_done = math_ops.reduce_any(conditions) 587 conditions_int = math_ops.cast(conditions, dtypes.int32) 588 # Partition the indices. 589 done_indices, new_indices = data_flow_ops.dynamic_partition( 590 indices, conditions_int, 2) 591 592 new_inputs = [] 593 new_output_tas = [] 594 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): 595 # Partition the inputs. 596 if stacked: 597 done_inp, new_inp = data_flow_ops.dynamic_partition( 598 inp, conditions_int, 2) 599 else: 600 # TODO(agarwal): avoid this stacking. See TODO earlier in 601 # _process_cond_unstacked. 602 done_inp = _stack(inp, [array_ops.size(done_indices)]).t 603 new_inp = inp 604 new_inputs.append(new_inp) 605 # For iterations that are done, write them to TensorArrays. 606 if i < num_outputs: 607 out_ta = output_tas[i] 608 # Note that done_indices can be empty. done_inp should also be empty in 609 # that case. 610 new_output_tas.append(out_ta.scatter(done_indices, done_inp)) 611 return not_all_done, new_indices, new_inputs, new_output_tas 612 613 def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked, 614 new_inputs, not_all_done): 615 """Convert the body function.""" 616 617 def true_fn(control_inputs, body_pfor, body_output, stacked): 618 """Converts the body function for all but last iteration. 619 620 This essentially converts body_output. Additionally, it needs to handle 621 any control dependencies on the NextIteration node. So it creates another 622 Identity node with the converted dependencies. 623 """ 624 converted_control_inp = [] 625 for x in control_inputs: 626 for t in x.outputs: 627 converted_control_inp.append(body_pfor._convert_helper(t).t) 628 if stacked: 629 # Note convert always does the stacking. 630 output = body_pfor.convert(body_output) 631 else: 632 output, convert_stacked, _ = body_pfor._convert_helper(body_output) 633 assert convert_stacked == stacked, body_output 634 with ops.control_dependencies(converted_control_inp): 635 return array_ops.identity(output) 636 637 body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked, 638 new_inputs, inputs_stacked) 639 new_outputs = [] 640 641 for i, (body_output, 642 stacked) in enumerate(zip(self._body_outputs, inputs_stacked)): 643 control_inp = self._next_iter_control_inputs[i] 644 out_dtype = body_output.dtype 645 # Note that we want to run the body only if not all pfor iterations are 646 # done. If all are done, we return empty tensors since these values will 647 # not be used. Notice that the value returned by the loop is based on 648 # TensorArrays and not directly on these returned values. 649 # pylint: disable=cell-var-from-loop 650 new_output = control_flow_ops.cond( 651 not_all_done, 652 lambda: true_fn(control_inp, body_pfor, body_output, stacked), 653 lambda: constant_op.constant([], dtype=out_dtype)) 654 # pylint: enable=cell-var-from-loop 655 new_outputs.append(new_output) 656 return new_outputs 657 658 def __call__(self, pfor_input): 659 """Converter for the while_loop. 660 661 The conversion of a while_loop is another while_loop. 662 663 The arguments to this converted while_loop are as follows: 664 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations 665 are done. 666 indices: int32 1-D Tensor storing the id of the iterations that are not 667 done. 668 args: Remaining arguments. These can be divided into 3 categories: 669 - First set of arguments are the tensors that correspond to the initial 670 elements of self._enters. The elements that appear in original while 671 loop's `loop_vars`. 672 - The second set of arguments are the tensors that correspond to the 673 remaining elements of self._enters. These are the tensors that directly 674 enter the original while loop body. 675 - Finally, the last set of arguments are TensorArrays. These TensorArrays 676 correspond to the outputs of the original while_loop, i.e. to the 677 elements in self._outputs. Each TensorArray has `PFor.loop_len` 678 elements, i.e. the number of pfor iterations. At the end, the i'th 679 element of each TensorArray will contain the output computed by the 680 i'th iteration of pfor. Note that elements can be written into these 681 tensors arrays in any order, depending on when the corresponding pfor 682 iteration is done. 683 If the original while_loop had `k` tensors in its `loop_vars` and its body 684 directly captured `m` tensors, the `args` will contain `2 * k + m` values. 685 686 In each iteration, the while_loop body recomputes the condition for all 687 active pfor iterations to see which of them are now done. It then partitions 688 all the inputs and passes them along to the converted body. Values for all 689 the iterations that are done are written to TensorArrays indexed by the pfor 690 iteration number. When all iterations are done, the TensorArrays are stacked 691 to get the final value. 692 693 Args: 694 pfor_input: A PForInput object corresponding to the output of any Exit 695 node from this while loop. 696 697 Returns: 698 List of converted outputs. 699 """ 700 # Create init_values that will be passed to the while_loop. 701 init_values, inputs_stacked, shape_invariants = self._create_init_values( 702 pfor_input) 703 # Note that we use a list as a hack since we need the nested function body 704 # to set the value of cond_is_stacked. python2.x doesn't support nonlocal 705 # variables. 706 cond_is_stacked = [None] 707 708 def cond(not_all_done, *_): 709 return not_all_done 710 711 def body(not_all_done, indices, *args): 712 # See documentation for __call__ for the structure of *args. 713 num_enters = len(self._enters) 714 inputs = args[:num_enters] 715 output_tas = args[num_enters:] 716 # TODO(agarwal): see which outputs have consumers and only populate the 717 # TensorArrays corresponding to those. Or do those paths get trimmed out 718 # from inside the while_loop body? 719 assert len(inputs) >= len(output_tas) 720 assert len(inputs) == len(inputs_stacked) 721 722 # Convert condition 723 with ops.name_scope("while_cond"): 724 # Note that we set cond_stacked to True here. At this point we don't 725 # know if it could be loop invariant, hence the conservative value is 726 # to assume stacked. 727 cond_pfor = self._init_pfor( 728 pfor_input.pfor, 729 indices, 730 cond_stacked=True, 731 inputs=inputs, 732 inputs_stacked=inputs_stacked) 733 conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition) 734 cond_is_stacked[0] = cond_stacked 735 736 # Recompute the new condition, write outputs of done iterations, and 737 # partition the inputs if needed. 738 if not cond_stacked: 739 (not_all_done, new_indices, new_inputs, 740 new_output_tas) = self._process_cond_unstacked(conditions, indices, 741 inputs, output_tas) 742 else: 743 (not_all_done, new_indices, new_inputs, 744 new_output_tas) = self._process_cond_stacked(conditions, indices, 745 inputs, inputs_stacked, 746 output_tas) 747 748 # Convert body 749 with ops.name_scope("while_body"): 750 # Compute the outputs from the body. 751 new_outputs = self._process_body(pfor_input, inputs_stacked, 752 new_indices, cond_stacked, new_inputs, 753 not_all_done) 754 755 # Note that the first num_outputs new values of inputs are computed using 756 # the body. Rest of them were direct Enters into the condition/body and 757 # the partitioning done earlier is sufficient to give the new value. 758 num_outputs = len(self._outputs) 759 new_args = ([not_all_done, new_indices] + new_outputs + 760 list(new_inputs[num_outputs:]) + new_output_tas) 761 return tuple(new_args) 762 763 while_outputs = control_flow_ops.while_loop( 764 cond, body, init_values, shape_invariants=shape_invariants) 765 output_tas = while_outputs[-len(self._outputs):] 766 outputs = [] 767 assert cond_is_stacked[0] is not None 768 for inp_stacked, ta in zip(inputs_stacked, output_tas): 769 if cond_is_stacked[0]: 770 outputs.append(wrap(ta.stack(), True)) 771 else: 772 # Note that if while_loop condition is unstacked, all iterations exit at 773 # the same time and we wrote those outputs in index 0 of the tensor 774 # array. 775 outputs.append(wrap(ta.read(0), inp_stacked)) 776 return outputs 777 778 779class ConversionNotImplementedError(Exception): 780 pass 781 782 783class _PforInput: 784 """Input object passed to registered pfor converters.""" 785 786 __slots__ = ["pfor", "_op", "_inputs"] 787 788 def __init__(self, pfor, op, inputs): 789 """Creates a _PforInput object. 790 791 Args: 792 pfor: PFor converter object. 793 op: the Operation object that is being converted. 794 inputs: list of WrappedTensor objects representing converted values of the 795 inputs of `op`. 796 """ 797 self.pfor = pfor 798 self._op = op 799 self._inputs = inputs 800 801 def stack_inputs(self, stack_indices=None, tile_variants=False): 802 """Stacks unstacked inputs at `stack_indices`. 803 804 Args: 805 stack_indices: indices of inputs at which stacking is done. If None, 806 stacking is done at all indices. 807 tile_variants: If True, affected indices which have a variant dtype will 808 be tiled after this operation to match the expected shape of a 809 vectorized tensor. Variants generally need to be un-tiled when they are 810 inputs to operations and tiled when returned. 811 """ 812 if stack_indices is None: 813 stack_indices = range(len(self._inputs)) 814 length = self.pfor.loop_len_vector 815 for i in stack_indices: 816 inp = self._inputs[i] 817 is_variant = inp.t.dtype == dtypes.variant 818 if not inp.is_stacked: 819 self._inputs[i] = _stack(inp.t, length) 820 if tile_variants and is_variant: 821 self._inputs[i] = wrap( 822 _tile_variant_with_length(self._inputs[i].t, length), True) 823 elif not tile_variants and is_variant: 824 self._inputs[i] = wrap(_untile_variant(self._inputs[i].t), True) 825 826 def expanddim_inputs_for_broadcast(self): 827 """Reshapes stacked inputs to prepare them for broadcast. 828 829 Since stacked inputs have an extra leading dimension, automatic broadcasting 830 rules could incorrectly try to expand dimensions before that leading 831 dimension. To avoid that, we reshape these stacked inputs to the maximum 832 rank they will need to be broadcasted to. 833 """ 834 if not self._inputs: 835 return 836 837 # Find max rank 838 def _get_rank(x): 839 rank = array_ops.rank(x.t) 840 if not x.is_stacked: 841 rank += 1 842 return rank 843 844 ranks = [_get_rank(x) for x in self._inputs] 845 max_rank = ranks[0] 846 for rank in ranks[1:]: 847 max_rank = math_ops.maximum(rank, max_rank) 848 849 for i, inp in enumerate(self._inputs): 850 if inp.is_stacked: 851 shape = array_ops.shape(inp.t) 852 rank_diff = array_ops.reshape(max_rank - ranks[i], [1]) 853 ones = array_ops.tile([1], rank_diff) 854 new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) 855 self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True) 856 857 @property 858 def inputs(self): 859 return self._inputs 860 861 @property 862 def num_inputs(self): 863 return len(self._inputs) 864 865 def input(self, index): 866 assert len(self._inputs) > index, (index, self._inputs) 867 return self._inputs[index] 868 869 def stacked_input(self, index): 870 t, is_stacked, _ = self.input(index) 871 if not is_stacked: 872 op_type = self.op_type 873 op_def = getattr(self._op, "op_def", None) 874 if op_def is None: 875 input_name = "at index %d" % index 876 else: 877 input_name = "\"%s\"" % op_def.input_arg[index].name 878 raise ConversionNotImplementedError( 879 f"Input {input_name} of op '{op_type}' expected to be not loop " 880 "invariant.") 881 return t 882 883 def unstacked_input(self, index): 884 t, is_stacked, _ = self.input(index) 885 if is_stacked: 886 op_type = self.op_type 887 op_def = getattr(self._op, "op_def", None) 888 if op_def is None: 889 input_name = "at index %d" % index 890 else: 891 input_name = "\"%s\"" % op_def.input_arg[index].name 892 raise ConversionNotImplementedError( 893 f"Input {input_name} of op '{op_type}' expected to be loop " 894 "invariant.") 895 return t 896 897 @property 898 def op(self): 899 return self._op 900 901 @property 902 def op_type(self): 903 return self._op.type 904 905 def get_attr(self, attr): 906 return self._op.get_attr(attr) 907 908 @property 909 def outputs(self): 910 return self._op.outputs 911 912 def output(self, index): 913 assert index < len(self._op.outputs) 914 return self._op.outputs[index] 915 916 917_pfor_converter_registry = {} 918 919 920class RegisterPFor: 921 """Utility to register converters for pfor. 922 923 Usage: 924 @RegisterPFor(foo_op_type) 925 def _foo_converter(pfor_input): 926 ... 927 928 The above will register conversion function `_foo_converter` for handling 929 conversion of `foo_op_type`. These converters are called during vectorization 930 of a `pfor` loop body. For each operation node in this loop body, 931 the vectorization process will call the converter corresponding to the 932 operation type of the node. 933 934 During conversion, the registered function will be called with a single 935 argument `pfor_input`, of type `PForInput`, which will contain state needed 936 for the conversion. When the converter is called for a node, all its inputs 937 should already have been converted and these converted values are stored in 938 `pfor_input.inputs`. This registered function should output a list of 939 WrappedTensor objects with the same length as the number of outputs of the 940 node being converted. If the node had zero outputs, then it should return an 941 ops.Operation object. These new sets of nodes should implement the 942 functionality of running that operation for the number of iterations specified 943 by `pfor_input.pfor.loop_len_vector[0]` where the inputs of the node for each 944 iteration are picked from `pfor_inputs.inputs()`. 945 946 One tricky aspect of the conversion process is keeping track of, and 947 leveraging loop invariance of computation. Each converted input is a 948 WrappedTensor which indicates whether the input was loop invariant or not. If 949 the converted value is loop invariant, its rank should match the rank of the 950 corresponding tensor in the loop body, else its rank is larger by 1. The 951 converter should look at the loop invariance of the inputs and generate new 952 nodes based on that. Note that the converter will not be called if all inputs 953 are loop invariant and the operation is not stateful. The converter should 954 determine if its own output is loop invariant and `wrap` its output 955 accordingly. 956 957 Example: 958 959 Here, the converter is trying to convert a Reshape node in the loop body. This 960 node will have two inputs: the tensor to reshape, and the new shape. The 961 example here only handles the case where the shape is loop invariant. 962 963 @RegisterPFor("Reshape") 964 def _convert_reshape(pfor_input): 965 # We assume that input is not loop invariant. Call to `stacked_input` 966 # asserts that and returns the converted value. This value will have a rank 967 # larger by 1 compared to the rank of the input in the loop body. 968 t = pfor_input.stacked_input(0) 969 970 # We assume that shape input is loop invariant. Call to `unstacked_input` 971 # asserts that and returns the converted value. 972 shape = pfor_input.unstacked_input(1) 973 974 # We compute `new_shape` by prepending the number of iterations to the 975 # original shape. 976 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], 977 axis=0) 978 979 # The vectorized output involves reshaping the converted input `t` using 980 # `new_shape`. 981 new_output = array_ops.reshape(t, new_shape) 982 983 # The converted output is marked as not loop invariant using the call to 984 # wrap. 985 return wrap(new_output, True) 986 """ 987 988 def __init__(self, op_type): 989 """Creates an object to register a converter for op with type `op_type`.""" 990 self.op_type = op_type 991 992 def __call__(self, converter): 993 name = self.op_type 994 assert name not in _pfor_converter_registry, "Re-registering %s " % name 995 _pfor_converter_registry[name] = converter 996 return converter 997 998 999class RegisterPForWithArgs(RegisterPFor): 1000 """Utility to register converters for pfor. 1001 1002 Usage: 1003 @RegisteRPFor(foo_op_type, foo=value, ....) 1004 def _foo_converter(pfor_input, foo=None, ....): 1005 ... 1006 1007 See RegisterPFor for details on the conversion function. 1008 `RegisterPForWithArgs` allows binding extra arguments to the 1009 conversion function at registration time. 1010 """ 1011 1012 def __init__(self, op_type, *args, **kw_args): 1013 super(RegisterPForWithArgs, self).__init__(op_type) 1014 self._args = args 1015 self._kw_args = kw_args 1016 1017 def __call__(self, converter): 1018 1019 def _f(pfor_input): 1020 return converter(pfor_input, self.op_type, *self._args, **self._kw_args) 1021 1022 super(RegisterPForWithArgs, self).__call__(_f) 1023 return converter 1024 1025 1026# TODO(agarwal): call raw_ops instead of calling these low level routines. 1027def _create_op(op_type, inputs, op_dtypes, attrs=None): 1028 """Utility to create an op.""" 1029 op = ops.get_default_graph().create_op( 1030 op_type, inputs, op_dtypes, attrs=attrs, compute_device=True) 1031 flat_attrs = [] 1032 # The tape expects an alternating flat list of names and attribute values. 1033 for a in attrs: 1034 flat_attrs.append(str(a)) 1035 flat_attrs.append(op.get_attr(str(a))) 1036 execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:]) 1037 return op 1038 1039 1040WrappedTensor = collections.namedtuple("WrappedTensor", 1041 ["t", "is_stacked", "is_sparse_stacked"]) 1042"""Wrapper around the result of a Tensor conversion. 1043 1044The additional fields are useful for keeping track of the conversion state as 1045data flows through the ops in the loop body. For every op whose output is a 1046Tensor, its converter should return either a WrappedTensor or a list of 1047WrappedTensors. 1048 1049Args: 1050 t: The converted tensor 1051 is_stacked: True if the tensor is stacked, i.e. represents the results of all 1052 the iterations of the loop, where each row i of the tensor corresponds to 1053 that op's output on iteration i of the loop. False if the tensor is not 1054 stacked, i.e. represents the result of the op on of a single iteration of 1055 the loop, where the result does not vary between iterations. 1056 is_sparse_stacked: True if the tensor corresponds to a component tensor 1057 (indices, values, or dense_shape) of a sparse tensor, and has been logically 1058 stacked via a sparse conversion. 1059""" 1060 1061 1062def wrap(tensor, is_stacked=True, is_sparse_stacked=False): 1063 """Helper to create a WrappedTensor object.""" 1064 assert isinstance(is_stacked, bool) 1065 assert isinstance(is_sparse_stacked, bool) 1066 assert isinstance(tensor, ops.Tensor) 1067 assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is " 1068 "stacked via a sparse " 1069 "conversion, it must also be " 1070 "stacked.") 1071 return WrappedTensor(tensor, is_stacked, is_sparse_stacked) 1072 1073 1074def _wrap_and_tile_variants(tensor, length): 1075 if tensor.dtype == dtypes.variant: 1076 tensor = _tile_variant_with_length(tensor, length) 1077 return wrap(tensor) 1078 1079 1080def _fallback_converter(pfor_input, root_cause="", warn=True): 1081 if warn: 1082 logging.warning("Using a while_loop for converting %s cause %s", 1083 pfor_input.op_type, root_cause) 1084 output_dtypes = [x.dtype for x in pfor_input.outputs] 1085 iters = pfor_input.pfor.loop_len_vector[0] 1086 1087 def while_body(i, *ta_list): 1088 """Body of while loop.""" 1089 inputs = [ 1090 x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs 1091 ] 1092 op_outputs = _create_op( 1093 pfor_input.op_type, 1094 inputs, 1095 output_dtypes, 1096 attrs=pfor_input.op.node_def.attr).outputs 1097 1098 outputs = [] 1099 # TODO(agarwal): Add tf.debugging asserts to check that the shapes across 1100 # the different iterations are the same. 1101 for out, ta in zip(op_outputs, ta_list): 1102 assert isinstance(out, ops.Tensor) 1103 outputs.append(ta.write(i, array_ops.expand_dims(out, 0))) 1104 return tuple([i + 1] + outputs) 1105 1106 ta_list = control_flow_ops.while_loop( 1107 lambda i, *ta: i < iters, while_body, [0] + 1108 [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes 1109 ])[1:] 1110 return tuple([wrap(ta.concat(), True) for ta in ta_list]) 1111 1112 1113class PForConfig: 1114 """A configuration object used to communicate with loop body function.""" 1115 1116 def __init__(self): 1117 # This may be set to the number of iterations. 1118 self._maybe_iters = None 1119 # Map from reduction node, created by `reduce`, to the bundle of reduction 1120 # function and arguments. 1121 self._reduce_map = {} 1122 1123 def _has_reductions(self): 1124 """True if some reductions where performed by loop body.""" 1125 return len(self._reduce_map) 1126 1127 def _set_iters(self, iters): 1128 """Set number of pfor iterations.""" 1129 if isinstance(iters, ops.Tensor): 1130 iters = tensor_util.constant_value(iters) 1131 self._maybe_iters = iters 1132 1133 def reduce(self, fn, *args): 1134 """Performs reduction `fn` on `args` vectorized across pfor iterations. 1135 1136 Note that `fn` is traced once inside the loop function context. Hence any 1137 captures or side-effects will happen in that context. Call to the traced 1138 version of `fn` happens during the construction of the vectorized code. 1139 1140 Note that this currently may not work inside a control flow construct. 1141 Args: 1142 fn: a reduction function. It will be called with arguments that have the 1143 same structure as *args but with individual values whose rank may be 1144 higher by 1 since they represent loop invariant vectorized versions of 1145 the corresponding Tensors in *args. 1146 *args: unvectorized Tensors. 1147 1148 Returns: 1149 The result of running `fn` on the vectorized versions of `*args`. These 1150 outputs will be available as loop invariant values to all the iterations. 1151 """ 1152 assert not context.executing_eagerly() 1153 # Creates a concrete function that will be used for reduction. 1154 tensor_specs = [] 1155 for arg in args: 1156 if not isinstance(arg, ops.Tensor): 1157 raise ValueError(f"Got a non-Tensor argument {arg} in reduce.") 1158 batched_shape = tensor_shape.TensorShape([self._maybe_iters 1159 ]).concatenate(arg.shape) 1160 tensor_specs.append( 1161 tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype)) 1162 concrete_function = def_function.function(fn).get_concrete_function( 1163 *tensor_specs) 1164 1165 # Creates PlaceholderWithDefault and IdentityN nodes corresponding the 1166 # reduction. 1167 pl_outputs = [] 1168 with ops.control_dependencies(args): 1169 for output in concrete_function.outputs: 1170 if not isinstance(output, ops.Tensor): 1171 raise ValueError(f"Got a non-Tensor output {output} while running " 1172 "reduce.") 1173 # Note that we use placeholder_with_default just to make XLA happy since 1174 # it does not like placeholder ops. 1175 if output.shape.is_fully_defined(): 1176 dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype) 1177 pl_outputs.append( 1178 array_ops.placeholder_with_default(dummy, shape=output.shape)) 1179 else: 1180 # TODO(agarwal): support case when under XLA and output.shape is not 1181 # fully defined. 1182 pl_outputs.append( 1183 array_ops.placeholder(output.dtype, shape=output.shape)) 1184 1185 reduction_op = array_ops.identity_n(pl_outputs)[0].op 1186 self._reduce_map[reduction_op] = (concrete_function, args) 1187 if len(reduction_op.outputs) == 1: 1188 return reduction_op.outputs[0] 1189 else: 1190 return tuple(reduction_op.outputs) 1191 1192 # TODO(agarwal): handle reductions inside control flow constructs. 1193 def reduce_concat(self, x): 1194 """Performs a concat reduction on `x` across pfor iterations. 1195 1196 Note that this currently may not work inside a control flow construct. 1197 Args: 1198 x: an unvectorized Tensor. 1199 1200 Returns: 1201 A Tensor that has rank one higher than `x`. The value is the vectorized 1202 version of `x`, i.e. stacking the value of `x` across different pfor 1203 iterations. 1204 """ 1205 return self.reduce(lambda y: y, x) 1206 1207 def reduce_mean(self, x): 1208 """Performs a mean reduction on `x` across pfor iterations. 1209 1210 Note that this currently may not work inside a control flow construct. 1211 Args: 1212 x: an unvectorized Tensor. 1213 1214 Returns: 1215 A Tensor that has same rank as `x`. The value is the mean of the values 1216 of `x` across the pfor iterations. 1217 """ 1218 return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x) 1219 1220 def reduce_sum(self, x): 1221 """Performs a sum reduction on `x` across pfor iterations. 1222 1223 Note that this currently may not work inside a control flow construct. 1224 Args: 1225 x: an unvectorized Tensor. 1226 1227 Returns: 1228 A Tensor that has same rank as `x`. The value is the sum of the values 1229 of `x` across the pfor iterations. 1230 """ 1231 return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x) 1232 1233 def _lookup_reduction(self, t): 1234 """Lookups Tensor `t` in the reduction maps.""" 1235 assert isinstance(t, ops.Tensor), t 1236 return self._reduce_map.get(t.op) 1237 1238 1239class PFor: 1240 """Implementation of rewrite of parallel-for loops. 1241 1242 This class takes a DAG or a set of DAGs representing the body of a 1243 parallel-for loop, and adds new operations to the graph that implements 1244 functionality equivalent to running that loop body for a specified number of 1245 iterations. This new set of nodes may or may not use a tensorflow loop 1246 construct. 1247 1248 The process of conversion does not delete or change any existing operations. 1249 It only adds operations that efficiently implement the equivalent 1250 functionality. We refer to the added ops as "converted ops". 1251 1252 The conversion process uses a simple greedy heuristic. It walks the loop body 1253 and tries to express the functionality of running each node in a loop with a 1254 new set of nodes. When converting an op several cases are possible: 1255 - The op is not inside the loop body. Hence it can be used as is. 1256 - The op does not depend on the iteration number and is stateless. In this 1257 case, it can be used as is. 1258 - The op is not stateful, and depends on iteration number only through control 1259 dependencies. In this case, we can create a single op with same inputs and 1260 attributes, but with "converted" control dependencies. 1261 - The op is not stateful, and all its inputs are loop invariant. In this 1262 case, similar to above, we can create a single op with same inputs and 1263 attributes, but with "converted" control dependencies. 1264 - The op is stateful or at least one of the inputs is not loop invariant. In 1265 this case, we run the registered converter for that op to create a set of 1266 converted ops. All nodes in the set will have converted control dependencies 1267 corresponding to control dependencies of the original op. If the op returned 1268 multiple outputs, "converted outputs" could be produced by different ops in 1269 this set. 1270 """ 1271 1272 def __init__(self, 1273 loop_var, 1274 loop_len, 1275 pfor_ops, 1276 fallback_to_while_loop, 1277 all_indices=None, 1278 all_indices_partitioned=False, 1279 pfor_config=None, 1280 warn=False): 1281 """Creates an object to rewrite a parallel-for loop. 1282 1283 Args: 1284 loop_var: ops.Tensor output of a Placeholder operation. The value should 1285 be an int32 scalar representing the loop iteration number. 1286 loop_len: A scalar or scalar Tensor representing the number of iterations 1287 the loop is run for. 1288 pfor_ops: List of all ops inside the loop body. 1289 fallback_to_while_loop: If True, on failure to vectorize an op, a while 1290 loop is used to sequentially execute that op. 1291 all_indices: If not None, an int32 vector with size `loop_len` 1292 representing the iteration ids that are still active. These values 1293 should be unique and sorted. However they may not be contiguous. This is 1294 typically the case when inside a control flow construct which has 1295 partitioned the indices of the iterations that are being converted. 1296 all_indices_partitioned: If True, this object is being constructed from a 1297 control flow construct where not all the pfor iterations are guaranteed 1298 to be active. 1299 pfor_config: PForConfig object used while constructing the loop body. 1300 warn: Whether or not to warn on while loop conversions. 1301 """ 1302 assert isinstance(loop_var, ops.Tensor) 1303 assert loop_var.op.type == "PlaceholderWithDefault" 1304 self._loop_var = loop_var 1305 loop_len_value = tensor_util.constant_value(loop_len) 1306 if loop_len_value is not None: 1307 loop_len = loop_len_value 1308 self._loop_len_vector = array_ops.reshape(loop_len, [1]) 1309 self._all_indices_partitioned = all_indices_partitioned 1310 if all_indices_partitioned: 1311 assert all_indices is not None 1312 self.all_indices = ( 1313 math_ops.range(loop_len) if all_indices is None else all_indices) 1314 1315 self._conversion_map = object_identity.ObjectIdentityDictionary() 1316 self._conversion_map[loop_var] = wrap(self.all_indices, True) 1317 self._pfor_ops = set(pfor_ops) 1318 self._pfor_op_ids = set(x._id for x in pfor_ops) 1319 self._fallback_to_while_loop = fallback_to_while_loop 1320 self._warn = warn 1321 self._pfor_config = pfor_config 1322 1323 def op_is_inside_loop(self, op): 1324 """True if op was created inside the pfor loop body.""" 1325 assert isinstance(op, ops.Operation) 1326 # Note that we use self._pfor_op_ids for the check and not self._pfor_ops 1327 # since it appears there tensorflow API could return different python 1328 # objects representing the same Operation node. 1329 return op._id in self._pfor_op_ids 1330 1331 def _convert_sparse(self, y): 1332 """Returns the converted value corresponding to SparseTensor y. 1333 1334 For SparseTensors, instead of stacking the component tensors separately, 1335 resulting in component tensors with shapes (N, m, rank), (N, m), and (N, 1336 rank) respectively for indices, values, and dense_shape (where N is the loop 1337 length and m is the number of sparse tensor values per loop iter), we want 1338 to logically stack the SparseTensors, to create a SparseTensor whose 1339 components are size (N * m, rank + 1), (N * m, ), and (rank + 1,) 1340 respectively. 1341 1342 Here, we try to get the conversion of each component tensor. 1343 If the tensors are stacked via a sparse conversion, return the resulting 1344 SparseTensor composed of the converted components. Otherwise, the component 1345 tensors are either unstacked or stacked naively. In the latter case, we 1346 unstack the component tensors to reform loop_len SparseTensor elements, 1347 then correctly batch them. 1348 1349 The unstacked tensors must have the same rank. Each dimension of each 1350 SparseTensor will expand to be the largest among all SparseTensor elements 1351 for that dimension. For example, if there are N SparseTensors of rank 3 1352 being stacked, with N dense shapes, where the i_th shape is (x_i, y_i, z_i), 1353 the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)). 1354 1355 Args: 1356 y: A tf.sparse.SparseTensor. 1357 1358 Returns: 1359 A tf.sparse.SparseTensor that is the converted value corresponding to y. 1360 """ 1361 outputs = [ 1362 self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) 1363 ] 1364 assert all(isinstance(o, WrappedTensor) for o in outputs) 1365 1366 if all(w.is_sparse_stacked for w in outputs): 1367 return sparse_tensor.SparseTensor(*[w.t for w in outputs]) 1368 1369 assert not any(w.is_sparse_stacked for w in outputs), ( 1370 "Error converting SparseTensor. All components should be logically " 1371 "stacked, or none.") 1372 1373 # If component tensors were not sparsely stacked, they are either unstacked 1374 # or stacked without knowledge that they are components of sparse tensors. 1375 # In this case, we have to restack them. 1376 return self._restack_sparse_tensor_logically( 1377 *[self._unwrap_or_tile(w) for w in outputs]) 1378 1379 def _restack_sparse_tensor_logically(self, indices, values, shape): 1380 sparse_tensor_rank = indices.get_shape().dims[-1].value 1381 if sparse_tensor_rank is not None: 1382 sparse_tensor_rank += 1 1383 1384 def fn(args): 1385 res = gen_sparse_ops.serialize_sparse( 1386 args[0], args[1], args[2], out_type=dtypes.variant) 1387 return res 1388 1389 # Applies a map function to the component tensors to serialize each 1390 # sparse tensor element and batch them all, then deserializes the batch. 1391 # TODO(rachelim): Try to do this without map_fn -- add the right offsets 1392 # to shape and indices tensors instead. 1393 result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant) 1394 return sparse_ops.deserialize_sparse( 1395 result, dtype=values.dtype, rank=sparse_tensor_rank) 1396 1397 def _unwrap_or_tile(self, wrapped_tensor): 1398 """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it.""" 1399 output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked 1400 if is_stacked: 1401 return output 1402 else: 1403 return _stack(output, self._loop_len_vector).t 1404 1405 def convert(self, y): 1406 """Returns the converted value corresponding to y. 1407 1408 Args: 1409 y: A ops.Tensor or a ops.Operation object. If latter, y should not have 1410 any outputs. 1411 1412 Returns: 1413 If y does not need to be converted, it returns y as is. Else it returns 1414 the "converted value" corresponding to y. 1415 """ 1416 if y is None: 1417 return None 1418 if isinstance(y, sparse_tensor.SparseTensor): 1419 return self._convert_sparse(y) 1420 assert isinstance(y, (ops.Tensor, ops.Operation)), y 1421 output = self._convert_helper(y) 1422 if isinstance(output, WrappedTensor): 1423 assert isinstance(y, ops.Tensor) 1424 return self._unwrap_or_tile(output) 1425 else: 1426 assert isinstance(y, ops.Operation) 1427 assert not y.outputs 1428 assert isinstance(output, ops.Operation) 1429 return output 1430 1431 def _was_converted(self, t): 1432 """True if t is not a conversion of itself.""" 1433 converted_t = self._conversion_map[t] 1434 return converted_t.t is not t 1435 1436 def _add_conversion(self, old_output, new_output): 1437 assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output 1438 assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output 1439 self._conversion_map[old_output] = new_output 1440 1441 def _convert_reduction(self, y): 1442 # Handle reductions. 1443 if self._pfor_config is None or isinstance(y, ops.Operation): 1444 return None 1445 reduction = self._pfor_config._lookup_reduction(y) 1446 if reduction is None: 1447 return None 1448 (reduction_fn, reduction_args) = reduction 1449 batched_args = [] 1450 for reduction_arg in reduction_args: 1451 assert isinstance(reduction_arg, ops.Tensor), reduction_arg 1452 # Tensor being reduced should already be converted due to a control 1453 # dependency on the created placeholder. 1454 # Note that in cases where reduction_arg is in an outer context, one 1455 # needs to locate the corresponding Enter node and use that to lookup 1456 # the conversion. 1457 # TODO(agarwal): handle reductions inside control flow constructs. 1458 assert reduction_arg in self._conversion_map, ( 1459 "Unable to handle reduction of %s, possibly as it was used " 1460 "inside a control flow construct. Note that reductions across " 1461 "pfor iterations are currently not supported inside control flow " 1462 "constructs." % reduction_arg) 1463 batched_arg = self._conversion_map[reduction_arg] 1464 batched_args.append(self._unwrap_or_tile(batched_arg)) 1465 outputs = reduction_fn(*batched_args) 1466 return [wrap(output, False) for output in nest.flatten(outputs)] 1467 1468 def _convert_helper(self, op_or_tensor): 1469 stack = collections.deque([op_or_tensor]) 1470 while stack: 1471 y = stack[0] 1472 if y in self._conversion_map: 1473 assert isinstance(self._conversion_map[y], 1474 (WrappedTensor, ops.Operation)) 1475 stack.popleft() 1476 continue 1477 if isinstance(y, ops.Operation): 1478 assert not y.outputs, ( 1479 "We only support converting Operation objects with no outputs. " 1480 "Got %s", y) 1481 y_op = y 1482 else: 1483 assert isinstance(y, ops.Tensor), y 1484 y_op = y.op 1485 1486 is_while_loop = y_op.type == "Exit" 1487 if is_while_loop: 1488 while_op = WhileOp( 1489 y, pfor_ops=self._pfor_ops, 1490 fallback_to_while_loop=self.fallback_to_while_loop, 1491 pfor_config=self._pfor_config) 1492 is_inside_loop = while_op.is_inside_loop 1493 # If all nodes in the while_loop graph were created inside the pfor, we 1494 # treat the whole loop subgraph as a single op (y_op) and try to convert 1495 # it. For while_loops that are created completely or partially outside, 1496 # we treat them as external and should be able to simply return the Exit 1497 # node output as is without needing any conversion. Note that for 1498 # while_loops that are partially constructed inside, we assume they will 1499 # be loop invariant. If that is not the case, it will create runtime 1500 # errors since the converted graph would depend on the self._loop_var 1501 # placeholder. 1502 if is_inside_loop: 1503 y_op = while_op 1504 else: 1505 is_inside_loop = self.op_is_inside_loop(y_op) 1506 1507 # If this op was not created inside the loop body, we will return as is. 1508 # 1. Convert inputs and control inputs. 1509 1510 def _add_to_stack(x): 1511 if x not in self._conversion_map: 1512 stack.appendleft(x) 1513 return True 1514 else: 1515 return False 1516 1517 if is_inside_loop: 1518 added_to_stack = False 1519 for inp in y_op.inputs: 1520 added_to_stack |= _add_to_stack(inp) 1521 for cinp in y_op.control_inputs: 1522 if cinp.outputs: 1523 for t in cinp.outputs: 1524 added_to_stack |= _add_to_stack(t) 1525 else: 1526 added_to_stack |= _add_to_stack(cinp) 1527 if added_to_stack: 1528 continue 1529 1530 converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs] 1531 some_input_converted = any(self._was_converted(x) for x in y_op.inputs) 1532 some_input_stacked = any(x.is_stacked for x in converted_inputs) 1533 1534 converted_control_ops = set() 1535 some_control_input_converted = False 1536 for cinp in y_op.control_inputs: 1537 if cinp.outputs: 1538 for t in cinp.outputs: 1539 converted_t = self._conversion_map[t] 1540 if self._was_converted(t): 1541 some_control_input_converted = True 1542 converted_control_ops.add(converted_t.t.op) 1543 else: 1544 converted_cinp = self._conversion_map[cinp] 1545 assert isinstance(converted_cinp, ops.Operation) 1546 if converted_cinp != cinp: 1547 some_control_input_converted = True 1548 converted_control_ops.add(converted_cinp) 1549 converted_control_ops = list(converted_control_ops) 1550 is_stateful = _is_stateful_pfor_op(y_op) 1551 else: 1552 converted_inputs = [] 1553 converted_control_ops = [] 1554 logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op, 1555 converted_inputs, converted_control_ops) 1556 1557 # 2. Convert y_op 1558 # If converting a while_loop, we let the while_loop convertor deal with 1559 # putting the control dependencies appropriately. 1560 control_dependencies = [] if is_while_loop else converted_control_ops 1561 with ops.control_dependencies(control_dependencies), ops.name_scope( 1562 y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op): 1563 # Op is a placeholder for a reduction. 1564 reduce_output = self._convert_reduction(y) 1565 if reduce_output is not None: 1566 new_outputs = reduce_output 1567 # None of the inputs and control inputs were converted. 1568 elif ((not is_inside_loop or 1569 (not is_stateful and not some_input_converted and 1570 not some_control_input_converted)) and 1571 y.graph == ops.get_default_graph()): 1572 if y is y_op: 1573 assert not isinstance(y_op, WhileOp) 1574 new_outputs = y_op 1575 else: 1576 new_outputs = [wrap(x, False) for x in y_op.outputs] 1577 elif not (is_stateful or is_while_loop or some_input_stacked): 1578 # All inputs are unstacked or unconverted but some control inputs are 1579 # converted. 1580 # TODO(rachelim): Handle the case where some inputs are sparsely 1581 # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs)) 1582 new_op = _create_op(y_op.type, [x.t for x in converted_inputs], 1583 [x.dtype for x in y_op.outputs], 1584 y_op.node_def.attr) 1585 if y is y_op: 1586 new_outputs = new_op 1587 else: 1588 new_outputs = [] 1589 for old_output, new_output in zip(y_op.outputs, new_op.outputs): 1590 handle_data_util.copy_handle_data(old_output, new_output) 1591 new_outputs.append(wrap(new_output, False)) 1592 else: 1593 # Either some inputs are not loop invariant or op is stateful. 1594 if hasattr(y_op, "pfor_converter"): 1595 converter = y_op.pfor_converter 1596 else: 1597 converter = _pfor_converter_registry.get(y_op.type, None) 1598 if converter is None: 1599 root_cause = (f"there is no registered converter for this op.") 1600 has_variant_outputs = any(x.dtype == dtypes.variant for x in 1601 y_op.outputs) 1602 has_vectorized_variant_inputs = any( 1603 _is_variant_with_internal_stacking(x) for x in 1604 y_op.inputs) 1605 if (self._fallback_to_while_loop and not has_variant_outputs 1606 and not has_vectorized_variant_inputs): 1607 converter = partial( 1608 _fallback_converter, root_cause=root_cause, warn=self._warn) 1609 else: 1610 message = (f"No pfor vectorization defined for {y_op.type}\n" 1611 f"{y_op}\n inputs: {converted_inputs}.") 1612 if not self._fallback_to_while_loop: 1613 message += ("Consider enabling the fallback_to_while_loop " 1614 "option to pfor, which may run slower.") 1615 raise ValueError(message) 1616 # TODO(rachelim): Handle the case where some inputs are sparsely 1617 # stacked. We should only call the converter if it supports handling 1618 # those inputs. 1619 pfor_inputs = _PforInput(self, y_op, converted_inputs) 1620 try: 1621 try: 1622 new_outputs = converter(pfor_inputs) 1623 except ConversionNotImplementedError as e: 1624 has_vectorized_variant_inputs = any( 1625 _is_variant_with_internal_stacking(x) for x in 1626 y_op.inputs) 1627 if (self._fallback_to_while_loop 1628 and not has_vectorized_variant_inputs): 1629 new_outputs = _fallback_converter( 1630 pfor_inputs, root_cause=str(e)) 1631 else: 1632 raise ValueError(str(e)).with_traceback(sys.exc_info()[2]) 1633 except Exception as e: # pylint: disable=broad-except 1634 logging.error( 1635 f"Got error while pfor was converting op {y_op} with inputs " 1636 f"{y_op.inputs[:]}\n, converted inputs {pfor_inputs.inputs}\n" 1637 f"Here are the pfor conversion stack traces: {e}") 1638 original_op = y_op 1639 while isinstance(original_op, ops.Operation): 1640 logging.error( 1641 "%s\ncreated at:\n %s", original_op, 1642 " ".join(traceback.format_list(original_op.traceback))) 1643 original_op = original_op._original_op 1644 raise 1645 1646 if isinstance(new_outputs, WrappedTensor): 1647 new_outputs = [new_outputs] 1648 assert isinstance(new_outputs, 1649 (list, tuple, ops.Operation)), new_outputs 1650 logging.vlog(2, f"converted {y_op} {new_outputs}") 1651 1652 # Insert into self._conversion_map 1653 if y is y_op: 1654 assert isinstance(new_outputs, ops.Operation) 1655 self._add_conversion(y_op, new_outputs) 1656 else: 1657 assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs, 1658 new_outputs) 1659 for old_output, new_output in zip(y_op.outputs, new_outputs): 1660 assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) 1661 assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op) 1662 # Set shape for converted output. 1663 output_shape = old_output.shape 1664 if not new_output.is_sparse_stacked: 1665 if new_output.is_stacked: 1666 loop_len = tensor_util.constant_value(self.loop_len_vector) 1667 if loop_len is None: 1668 batch_dim = tensor_shape.TensorShape([None]) 1669 else: 1670 batch_dim = tensor_shape.TensorShape(loop_len) 1671 output_shape = batch_dim.concatenate(output_shape) 1672 if _is_variant_with_internal_stacking(new_output.t): 1673 new_output.t.set_shape([]) 1674 else: 1675 new_output.t.set_shape(output_shape) 1676 self._add_conversion(old_output, new_output) 1677 stack.popleft() 1678 1679 return self._conversion_map[op_or_tensor] 1680 1681 @property 1682 def loop_len_vector(self): 1683 """Returns a single element vector whose value is number of iterations.""" 1684 return self._loop_len_vector 1685 1686 @property 1687 def loop_var(self): 1688 """Returns placeholder loop variable.""" 1689 return self._loop_var 1690 1691 @property 1692 def pfor_ops(self): 1693 return self._pfor_ops 1694 1695 @property 1696 def pfor_config(self): 1697 return self._pfor_config 1698 1699 @property 1700 def all_indices_partitioned(self): 1701 """all_indices_partitioned property. 1702 1703 Returns: 1704 True if we are inside a control flow construct and not all pfor iterations 1705 may be active. 1706 """ 1707 return self._all_indices_partitioned 1708 1709 @property 1710 def fallback_to_while_loop(self): 1711 return self._fallback_to_while_loop 1712 1713 1714# The code below defines converters for different operations. Please see comment 1715# for RegisterPFor to see how converters should be defined. 1716 1717 1718# image_ops 1719 1720 1721@RegisterPFor("AdjustContrastv2") 1722def _convert_adjust_contrastv2(pfor_input): 1723 images = pfor_input.stacked_input(0) 1724 contrast_factor = pfor_input.unstacked_input(1) 1725 return wrap(gen_image_ops.adjust_contrastv2(images, contrast_factor), True) 1726 1727 1728@RegisterPFor("AdjustHue") 1729def _convert_adjust_hue(pfor_input): 1730 images = pfor_input.stacked_input(0) 1731 delta = pfor_input.unstacked_input(1) 1732 return wrap(gen_image_ops.adjust_hue(images, delta), True) 1733 1734 1735@RegisterPFor("AdjustSaturation") 1736def _convert_adjust_saturation(pfor_input): 1737 images = pfor_input.stacked_input(0) 1738 scale = pfor_input.unstacked_input(1) 1739 return wrap(gen_image_ops.adjust_saturation(images, scale), True) 1740 1741 1742# nn_ops 1743 1744 1745def _flatten_first_two_dims(x): 1746 """Merges first two dimensions.""" 1747 old_shape = array_ops.shape(x) 1748 new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) 1749 return array_ops.reshape(x, new_shape) 1750 1751 1752def _unflatten_first_dim(x, first_dim): 1753 """Splits first dimension into [first_dim, -1].""" 1754 old_shape = array_ops.shape(x) 1755 new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) 1756 return array_ops.reshape(x, new_shape) 1757 1758 1759def _inputs_with_flattening(pfor_input, input_indices): 1760 """Stacks and flattens first dim of inputs at indices `input_indices`.""" 1761 if input_indices is None: 1762 input_indices = [] 1763 pfor_input.stack_inputs(stack_indices=input_indices) 1764 inputs = [] 1765 for i in range(pfor_input.num_inputs): 1766 if i in input_indices: 1767 inp = pfor_input.stacked_input(i) 1768 inp = _flatten_first_two_dims(inp) 1769 else: 1770 inp = pfor_input.unstacked_input(i) 1771 inputs.append(inp) 1772 return inputs 1773 1774 1775@RegisterPForWithArgs("Conv2D", dims=[0]) 1776@RegisterPForWithArgs("DepthToSpace", dims=[0]) 1777@RegisterPForWithArgs("AvgPool", dims=[0]) 1778@RegisterPForWithArgs("AvgPool3D", dims=[0]) 1779@RegisterPForWithArgs("MaxPool", dims=[0]) 1780@RegisterPForWithArgs("MaxPoolV2", dims=[0]) 1781@RegisterPForWithArgs("MaxPool3D", dims=[0]) 1782@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2]) 1783@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) 1784@RegisterPForWithArgs("MaxPoolGradV2", dims=[0, 1, 2]) 1785@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2]) 1786@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2]) 1787@RegisterPForWithArgs("MaxPoolGradGradV2", dims=[0, 1, 2]) 1788@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1789@RegisterPForWithArgs("SparseSoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1790@RegisterPForWithArgs("SpaceToDepth", dims=[0]) 1791def _convert_flatten_batch(pfor_input, op_type, dims): 1792 del op_type 1793 inputs = _inputs_with_flattening(pfor_input, dims) 1794 outputs = _create_op( 1795 pfor_input.op_type, 1796 inputs, [x.dtype for x in pfor_input.outputs], 1797 attrs=pfor_input.op.node_def.attr).outputs 1798 n = pfor_input.pfor.loop_len_vector 1799 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1800 return [wrap(x, True) for x in outputs] 1801 1802 1803_channel_flatten_input_cache = {} 1804 1805 1806@RegisterPFor("BatchToSpaceND") 1807def _convert_batch_to_space_nd(pfor_input): 1808 inp = pfor_input.stacked_input(0) 1809 block_shape = pfor_input.unstacked_input(1) 1810 crops = pfor_input.unstacked_input(2) 1811 1812 inp_shape = array_ops.shape(inp) 1813 n = pfor_input.pfor.loop_len_vector 1814 1815 # Reshape and transpose to move the vectorization axis inside the axes that 1816 # will move to space. 1817 # Reshape to 4D and transpose 1818 block_size = math_ops.reduce_prod(block_shape) 1819 new_shape = [n[0], block_size, inp_shape[1] // block_size, -1] 1820 inp = array_ops.reshape(inp, new_shape) 1821 inp = array_ops.transpose(inp, [1, 0, 2, 3]) 1822 # Reshape back to merge the block, vectorization and batch dimension, and 1823 # restore the other dimensions. 1824 new_shape = array_ops.concat([n * inp_shape[1], inp_shape[2:]], axis=0) 1825 inp = array_ops.reshape(inp, new_shape) 1826 # Call batch_to_space and then split the new batch axis. 1827 output = gen_array_ops.batch_to_space_nd(inp, block_shape, crops) 1828 output = _unflatten_first_dim(output, n) 1829 return wrap(output, True) 1830 1831 1832@RegisterPFor("SpaceToBatchND") 1833def _convert_space_to_batch_nd(pfor_input): 1834 inp = pfor_input.stacked_input(0) 1835 block_shape = pfor_input.unstacked_input(1) 1836 paddings = pfor_input.unstacked_input(2) 1837 1838 n = pfor_input.pfor.loop_len_vector 1839 inp_shape = array_ops.shape(inp) 1840 inp = _flatten_first_two_dims(inp) 1841 output = gen_array_ops.space_to_batch_nd(inp, block_shape, paddings) 1842 output_shape = array_ops.shape(output) 1843 block_size = math_ops.reduce_prod(block_shape) 1844 new_shape = [block_size, n[0], -1] 1845 output = array_ops.reshape(output, new_shape) 1846 output = array_ops.transpose(output, [1, 0, 2]) 1847 new_shape = array_ops.concat( 1848 [n, block_size * inp_shape[1:2], output_shape[1:]], axis=0) 1849 output = array_ops.reshape(output, new_shape) 1850 return wrap(output, True) 1851 1852 1853def _channel_flatten_input(x, data_format): 1854 """Merge the stack dimension with the channel dimension. 1855 1856 If S is pfor's stacking dimension, then, 1857 - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose 1858 should be cheap. 1859 - for SNHWC, we transpose to NHWSC. 1860 We then merge the S and C dimension. 1861 1862 Args: 1863 x: ops.Tensor to transform. 1864 data_format: "NCHW" or "NHWC". 1865 1866 Returns: 1867 A 3-element tuple with the transformed value, along with the shape for 1868 reshape and order for transpose required to transform back. 1869 """ 1870 1871 graph = ops.get_default_graph() 1872 cache_key = (graph, x.ref(), data_format) 1873 if cache_key not in _channel_flatten_input_cache: 1874 x_shape = array_ops.shape(x) 1875 if data_format == b"NCHW": 1876 order = [1, 0, 2, 3, 4] 1877 shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) 1878 reverse_order = order 1879 else: 1880 order = [1, 2, 3, 0, 4] 1881 shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) 1882 reverse_order = [3, 0, 1, 2, 4] 1883 # Move S dimension next to C dimension. 1884 x = array_ops.transpose(x, order) 1885 reverse_shape = array_ops.shape(x) 1886 # Reshape to merge the S and C dimension. 1887 x = array_ops.reshape(x, shape) 1888 outputs = x, reverse_order, reverse_shape 1889 _channel_flatten_input_cache[cache_key] = outputs 1890 else: 1891 outputs = _channel_flatten_input_cache[cache_key] 1892 return outputs 1893 1894 1895# Note that with training=True, running FusedBatchNormV3 on individual examples 1896# is very different from running FusedBatchNormV3 on a batch of those examples. 1897# This is because, for the latter case, the operation can be considered as first 1898# computing the mean and variance over all the examples and then using these 1899# to scale all those examples. This creates a data dependency between these 1900# different "iterations" since the inputs to the scaling step depends on the 1901# statistics coming from all these inputs. 1902# As with other kernels, the conversion here effectively runs the kernel 1903# independently for each iteration, and returns outputs by stacking outputs from 1904# each of those iterations. 1905@RegisterPFor("FusedBatchNormV3") 1906def _convert_fused_batch_norm(pfor_input): 1907 is_training = pfor_input.get_attr("is_training") 1908 # When BatchNorm is used with training=False, mean and variance are provided 1909 # externally and used as is by the op. Thus, we can merge the S and N 1910 # dimensions as we do for regular operations. 1911 # When BatchNorm is used with training=True, mean and variance are computed 1912 # for each channel across the batch dimension (first one). If we merge S and N 1913 # dimensions, mean and variances will be computed over a larger set. So, we 1914 # merge the S and C dimensions instead. 1915 if not is_training: 1916 # We return zeros for batch_mean and batch_variance output. Note that CPU 1917 # and GPU seem to have different behavior for those two outputs. CPU outputs 1918 # zero because these values are not used during inference. GPU outputs 1919 # something, probably real means and variances. 1920 inputs = _inputs_with_flattening(pfor_input, [0]) 1921 outputs = _create_op( 1922 pfor_input.op_type, 1923 inputs, [x.dtype for x in pfor_input.outputs], 1924 attrs=pfor_input.op.node_def.attr).outputs 1925 y = outputs[0] 1926 n = pfor_input.pfor.loop_len_vector 1927 y = _unflatten_first_dim(y, n) 1928 mean = pfor_input.unstacked_input(3) 1929 zeros = array_ops.zeros_like(mean) 1930 return [wrap(y, True)] + [wrap(zeros, False)] * 5 1931 1932 pfor_input.stack_inputs() 1933 data_format = pfor_input.get_attr("data_format") 1934 # We merge the first dimension with the "C" dimension, run FusedBatchNormV3, 1935 # and then transpose back. 1936 x = pfor_input.stacked_input(0) 1937 x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format) 1938 # Note that we stack all the other inputs as well so that they are the same 1939 # size as the new size of the channel dimension. 1940 inputs = [x] + [ 1941 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1942 for i in range(1, pfor_input.num_inputs) 1943 ] 1944 outputs = _create_op( 1945 pfor_input.op_type, 1946 inputs, [x.dtype for x in pfor_input.outputs], 1947 attrs=pfor_input.op.node_def.attr).outputs 1948 y = outputs[0] 1949 y = array_ops.reshape(y, reverse_shape) 1950 y = array_ops.transpose(y, reverse_order) 1951 n = pfor_input.pfor.loop_len_vector 1952 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1953 outputs = [y] + outputs 1954 return [wrap(x, True) for x in outputs] 1955 1956 1957@RegisterPFor("FusedBatchNormGradV3") 1958def _convert_fused_batch_norm_grad(pfor_input): 1959 pfor_input.stack_inputs() 1960 data_format = pfor_input.get_attr("data_format") 1961 y_backprop = pfor_input.stacked_input(0) 1962 y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) 1963 x = pfor_input.stacked_input(1) 1964 x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) 1965 inputs = [y_backprop, x] + [ 1966 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1967 for i in range(2, pfor_input.num_inputs) 1968 ] 1969 outputs = _create_op( 1970 pfor_input.op_type, 1971 inputs, [x.dtype for x in pfor_input.outputs], 1972 attrs=pfor_input.op.node_def.attr).outputs 1973 x_backprop = outputs[0] 1974 x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) 1975 x_backprop = array_ops.transpose(x_backprop, x_reverse_order) 1976 n = pfor_input.pfor.loop_len_vector 1977 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1978 outputs = [x_backprop] + outputs 1979 return [wrap(output, True) for output in outputs] 1980 1981 1982@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) 1983@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) 1984@RegisterPForWithArgs("AvgPool3DGrad", flatten_dims=[1], shape_dim=0) 1985def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, 1986 shape_dim): 1987 del op_type 1988 inputs = _inputs_with_flattening(pfor_input, flatten_dims) 1989 n = pfor_input.pfor.loop_len_vector 1990 # Adjust the `input_sizes` input. 1991 ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1], 1992 dtype=n.dtype) 1993 inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) 1994 outputs = _create_op( 1995 pfor_input.op_type, 1996 inputs, [x.dtype for x in pfor_input.outputs], 1997 attrs=pfor_input.op.node_def.attr).outputs 1998 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1999 return [wrap(x, True) for x in outputs] 2000 2001 2002@RegisterPFor("Conv2DBackpropFilter") 2003def _convert_conv2d_backprop_filter(pfor_input): 2004 pfor_input.stack_inputs(stack_indices=[2]) 2005 inputs, inputs_stacked, _ = pfor_input.input(0) 2006 filter_sizes = pfor_input.unstacked_input(1) 2007 grads = pfor_input.stacked_input(2) 2008 strides = pfor_input.get_attr("strides") 2009 padding = pfor_input.get_attr("padding") 2010 use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") 2011 data_format = pfor_input.get_attr("data_format") 2012 dilations = pfor_input.get_attr("dilations") 2013 if inputs_stacked: 2014 # TODO(agarwal): Implement this efficiently. 2015 logging.warning("Conv2DBackpropFilter uses a while_loop. Fix that!") 2016 2017 def while_body(i, ta): 2018 inp_i = inputs[i, ...] 2019 grad_i = grads[i, ...] 2020 output = nn_ops.conv2d_backprop_filter( 2021 inp_i, 2022 filter_sizes, 2023 grad_i, 2024 strides=strides, 2025 padding=padding, 2026 use_cudnn_on_gpu=use_cudnn_on_gpu, 2027 data_format=data_format, 2028 dilations=dilations) 2029 return i + 1, ta.write(i, array_ops.expand_dims(output, 0)) 2030 2031 n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) 2032 _, ta = control_flow_ops.while_loop( 2033 lambda i, ta: i < n, while_body, 2034 (0, tensor_array_ops.TensorArray(inputs.dtype, n))) 2035 output = ta.concat() 2036 return wrap(output, True) 2037 else: 2038 # We merge the stack dimension with the channel dimension of the gradients 2039 # and pretend we had a larger filter (see change to filter_sizes below). 2040 # Once the filter backprop is computed, we reshape and transpose back 2041 # appropriately. 2042 grads, _, _ = _channel_flatten_input(grads, data_format) 2043 n = pfor_input.pfor.loop_len_vector 2044 old_filter_sizes = filter_sizes 2045 filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) 2046 output = nn_ops.conv2d_backprop_filter( 2047 inputs, 2048 filter_sizes, 2049 grads, 2050 strides=strides, 2051 padding=padding, 2052 use_cudnn_on_gpu=use_cudnn_on_gpu, 2053 data_format=data_format, 2054 dilations=dilations) 2055 new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) 2056 output = array_ops.reshape(output, new_filter_shape) 2057 output = array_ops.transpose(output, [3, 0, 1, 2, 4]) 2058 return wrap(output, True) 2059 2060 2061def _flatten_with_inner_dim(x, dim, x_rank): 2062 """Merges the first dim with the specified dim.""" 2063 shape = array_ops.shape(x) 2064 x = array_ops.transpose(x, 2065 list(range(1, dim)) + [0] + list(range(dim, x_rank))) 2066 2067 if dim < x_rank - 1: 2068 new_shape_pieces = [shape[1:dim], [-1], shape[dim + 1:]] 2069 else: 2070 new_shape_pieces = [shape[1:dim], [-1]] 2071 new_shape = array_ops.concat(new_shape_pieces, axis=0) 2072 return array_ops.reshape(x, new_shape) 2073 2074 2075def _unflatten_with_inner_dim(x, dim, x_rank, stack_size): 2076 """Undoes _flatten_with_inner_dim.""" 2077 shape = array_ops.shape(x) 2078 if dim < x_rank - 1: 2079 new_shape_pieces = [shape[:dim], [stack_size], [-1], shape[dim + 1:]] 2080 else: 2081 new_shape_pieces = [shape[:dim], [stack_size], [-1]] 2082 new_shape = array_ops.concat(new_shape_pieces, axis=0) 2083 x = array_ops.reshape(x, new_shape) 2084 dims_permutation = [dim] + list(range(dim)) + list(range(dim + 1, x_rank + 1)) 2085 return array_ops.transpose(x, dims_permutation) 2086 2087 2088@RegisterPFor("DepthwiseConv2dNative") 2089def _convert_depthwise_conv2d_native(pfor_input): 2090 # Kernel can be vectorized, so folding to batch dimension does not work. We 2091 # instead fold into the channel dimension because it is parallel. 2092 stack_size = pfor_input.pfor.loop_len_vector[0] 2093 data_format = pfor_input.get_attr("data_format") 2094 c_dim = 1 if data_format == b"NCHW" else 3 2095 t = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) 2096 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) 2097 conv = _create_op( 2098 "DepthwiseConv2dNative", [t, kernel], 2099 [x.dtype for x in pfor_input.outputs], 2100 attrs=pfor_input.op.node_def.attr).outputs[0] 2101 return wrap(_unflatten_with_inner_dim(conv, c_dim, 4, stack_size), True) 2102 2103 2104@RegisterPFor("DepthwiseConv2dNativeBackpropInput") 2105def _convert_depthwise_conv2d_native_backprop_input(pfor_input): 2106 stack_size = pfor_input.pfor.loop_len_vector[0] 2107 input_sizes = pfor_input.unstacked_input(0) 2108 data_format = pfor_input.get_attr("data_format") 2109 c_dim = 1 if data_format == b"NCHW" else 3 2110 input_sizes_mutipliers = [ 2111 constant_op.constant([1] * c_dim, dtype=dtypes.int32), [stack_size] 2112 ] 2113 if c_dim < 3: 2114 input_sizes_mutipliers += [ 2115 constant_op.constant([1] * (3 - c_dim), dtype=dtypes.int32) 2116 ] 2117 input_sizes *= array_ops.concat(input_sizes_mutipliers, axis=0) 2118 kernel = _flatten_with_inner_dim(pfor_input.stacked_input(1), 3, 5) 2119 out_backprop = _flatten_with_inner_dim( 2120 pfor_input.stacked_input(2), c_dim + 1, 5) 2121 result = _create_op( 2122 "DepthwiseConv2dNativeBackpropInput", [input_sizes, kernel, out_backprop], 2123 [x.dtype for x in pfor_input.outputs], 2124 attrs=pfor_input.op.node_def.attr).outputs[0] 2125 return wrap(_unflatten_with_inner_dim(result, c_dim, 4, stack_size), True) 2126 2127 2128@RegisterPFor("DepthwiseConv2dNativeBackpropFilter") 2129def _convert_depthwise_conv2d_native_backprop_filter(pfor_input): 2130 stack_size = pfor_input.pfor.loop_len_vector[0] 2131 data_format = pfor_input.get_attr("data_format") 2132 c_dim = 1 if data_format == b"NCHW" else 3 2133 inputs = _flatten_with_inner_dim(pfor_input.stacked_input(0), c_dim + 1, 5) 2134 filter_sizes = pfor_input.unstacked_input(1) 2135 filter_sizes_multipliers = [ 2136 constant_op.constant([1, 1], dtype=dtypes.int32), [stack_size], 2137 constant_op.constant([1], dtype=dtypes.int32) 2138 ] 2139 filter_sizes *= array_ops.concat(filter_sizes_multipliers, axis=0) 2140 out_backprop = _flatten_with_inner_dim( 2141 pfor_input.stacked_input(2), c_dim + 1, 5) 2142 result = _create_op( 2143 "DepthwiseConv2dNativeBackpropFilter", 2144 [inputs, filter_sizes, out_backprop], 2145 [x.dtype for x in pfor_input.outputs], 2146 attrs=pfor_input.op.node_def.attr).outputs[0] 2147 return wrap(_unflatten_with_inner_dim(result, 2, 4, stack_size), True) 2148 2149 2150@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax) 2151@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax) 2152def _convert_softmax(pfor_input, op_type, op_func): 2153 del op_type 2154 return wrap(op_func(pfor_input.stacked_input(0)), True) 2155 2156 2157# array_ops 2158 2159 2160@RegisterPForWithArgs("Identity", array_ops.identity) 2161@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) 2162@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag) 2163@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) 2164@RegisterPForWithArgs("_EagerConst", array_ops.identity) 2165def _convert_identity(pfor_input, op_type, op_func): 2166 del op_type 2167 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 2168 2169 2170@RegisterPFor("IdentityN") 2171def _convert_identity_n(pfor_input): 2172 outputs = array_ops.identity_n([x.t for x in pfor_input.inputs]) 2173 return [ 2174 wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs) 2175 ] 2176 2177 2178@RegisterPFor("Reshape") 2179def _convert_reshape(pfor_input): 2180 t = pfor_input.stacked_input(0) 2181 shape = pfor_input.unstacked_input(1) 2182 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2183 return wrap(array_ops.reshape(t, new_shape), True) 2184 2185 2186@RegisterPFor("Fill") 2187def _convert_fill(pfor_input): 2188 dims = pfor_input.unstacked_input(0) 2189 value = pfor_input.stacked_input(1) 2190 # Expand the rank of `value` 2191 new_shape = array_ops.concat( 2192 [[-1], array_ops.ones([array_ops.size(dims)], dtype=dtypes.int32)], 2193 axis=0) 2194 value = array_ops.reshape(value, new_shape) 2195 # Compute the new output shape 2196 new_dims = array_ops.concat([pfor_input.pfor.loop_len_vector, dims], axis=0) 2197 # Broadcast 2198 return wrap(array_ops.broadcast_to(value, new_dims), True) 2199 2200 2201@RegisterPFor("BroadcastTo") 2202def _convert_broadcast_to(pfor_input): 2203 t = pfor_input.stacked_input(0) 2204 shape = pfor_input.unstacked_input(1) 2205 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2206 2207 # Expand dims of stacked t to broadcast against the new shape. 2208 # TODO(davmre): consider factoring out common code with 2209 # `expanddim_inputs_for_broadcast`, which has similar logic but with 2210 # implicit shapes (of input Tensors) rather than explicit shapes. 2211 rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t) 2212 ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1])) 2213 t_shape = array_ops.shape(t) 2214 t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0) 2215 2216 return wrap( 2217 array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape), 2218 True) 2219 2220 2221@RegisterPFor("ExpandDims") 2222def _convert_expanddims(pfor_input): 2223 t = pfor_input.stacked_input(0) 2224 dim = pfor_input.unstacked_input(1) 2225 dim += math_ops.cast(dim >= 0, dim.dtype) 2226 return wrap(array_ops.expand_dims(t, axis=dim), True) 2227 2228 2229@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound) 2230@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound) 2231def _convert_searchsorted(pfor_input, _, op_func): 2232 pfor_input.stack_inputs() 2233 sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0)) 2234 values = _flatten_first_two_dims(pfor_input.stacked_input(1)) 2235 out_type = pfor_input.get_attr("out_type") 2236 output = op_func(sorted_inputs, values, out_type) 2237 return wrap( 2238 _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True) 2239 2240 2241@RegisterPFor("MatrixBandPart") 2242def _convert_matrix_band_part(pfor_input): 2243 t = pfor_input.stacked_input(0) 2244 num_lower = pfor_input.unstacked_input(1) 2245 num_upper = pfor_input.unstacked_input(2) 2246 return wrap( 2247 array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper), 2248 True) 2249 2250 2251@RegisterPFor("MatrixSetDiag") 2252def _convert_matrix_set_diag(pfor_input): 2253 pfor_input.stack_inputs() 2254 t = pfor_input.stacked_input(0) 2255 diag = pfor_input.stacked_input(1) 2256 return wrap(array_ops.matrix_set_diag(t, diag), True) 2257 2258 2259# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3. 2260# The input orders defined in the OpKernel and the actual python API are 2261# different (for compatibility with V1), so we cannot use _convert_identity. 2262# v2 is not compatible with v3 and is never exposed on the public API. 2263@RegisterPFor("MatrixDiagV2") 2264@RegisterPFor("MatrixDiagV3") 2265def _convert_matrix_diag_v2(pfor_input): 2266 params = { 2267 "diagonal": pfor_input.stacked_input(0), 2268 "k": pfor_input.unstacked_input(1), 2269 "num_rows": pfor_input.unstacked_input(2), 2270 "num_cols": pfor_input.unstacked_input(3), 2271 "padding_value": pfor_input.unstacked_input(4) 2272 } 2273 if pfor_input.op_type == "MatrixDiagV2": 2274 return wrap(array_ops.matrix_diag_v2(**params), True) 2275 params["align"] = pfor_input.get_attr("align") 2276 return wrap(array_ops.matrix_diag(**params), True) 2277 2278 2279@RegisterPFor("Diag") 2280def _convert_diag(pfor_input): 2281 diag = pfor_input.stacked_input(0) 2282 if diag.shape.ndims == 2: 2283 # We can use matrix_diag. 2284 return wrap(array_ops.matrix_diag(diag), True) 2285 else: 2286 # It is not clear if we can do better than a while loop here with existing 2287 # kernels. 2288 return _fallback_converter(pfor_input, warn=False) 2289 2290 2291# See notes for MatrixDiagV2 2292@RegisterPFor("MatrixDiagPartV2") 2293@RegisterPFor("MatrixDiagPartV3") 2294def _convert_matrix_diag_part_v2(pfor_input): 2295 params = { 2296 "input": pfor_input.stacked_input(0), 2297 "k": pfor_input.unstacked_input(1), 2298 "padding_value": pfor_input.unstacked_input(2) 2299 } 2300 if pfor_input.op_type == "MatrixDiagPartV2": 2301 return wrap(array_ops.matrix_diag_part_v2(**params), True) 2302 params["align"] = pfor_input.get_attr("align") 2303 return wrap(array_ops.matrix_diag_part(**params), True) 2304 2305 2306# See notes for MatrixDiagV2 2307@RegisterPFor("MatrixSetDiagV2") 2308@RegisterPFor("MatrixSetDiagV3") 2309def _convert_matrix_set_diag_v2(pfor_input): 2310 pfor_input.stack_inputs([0, 1]) 2311 params = { 2312 "input": pfor_input.stacked_input(0), 2313 "diagonal": pfor_input.stacked_input(1), 2314 "k": pfor_input.unstacked_input(2) 2315 } 2316 if pfor_input.op_type == "MatrixSetDiagV2": 2317 return wrap(array_ops.matrix_set_diag_v2(**params), True) 2318 params["align"] = pfor_input.get_attr("align") 2319 return wrap(array_ops.matrix_set_diag(**params), True) 2320 2321 2322@RegisterPFor("DiagPart") 2323def _convert_diag_part(pfor_input): 2324 inp = pfor_input.stacked_input(0) 2325 if inp.shape.ndims == 3: 2326 # We can use matrix_diag_part. 2327 return wrap(array_ops.matrix_diag_part(inp), True) 2328 else: 2329 # It is not clear if we can do better than a while loop here with existing 2330 # kernels. 2331 return _fallback_converter(pfor_input, warn=False) 2332 2333 2334@RegisterPFor("OneHot") 2335def _convert_one_hot(pfor_input): 2336 indices = pfor_input.stacked_input(0) 2337 depth = pfor_input.unstacked_input(1) 2338 on_value = pfor_input.unstacked_input(2) 2339 off_value = pfor_input.unstacked_input(3) 2340 axis = pfor_input.get_attr("axis") 2341 if axis >= 0: 2342 axis += 1 2343 return wrap( 2344 array_ops.one_hot(indices, depth, on_value, off_value, axis), True) 2345 2346 2347@RegisterPFor("Slice") 2348def _convert_slice(pfor_input): 2349 t = pfor_input.stacked_input(0) 2350 begin, begin_stacked, _ = pfor_input.input(1) 2351 size = pfor_input.unstacked_input(2) 2352 if not begin_stacked: 2353 begin = array_ops.concat([[0], begin], axis=0) 2354 size = array_ops.concat([[-1], size], axis=0) 2355 return wrap(array_ops.slice(t, begin, size), True) 2356 else: 2357 # Handle negative sizes. 2358 # 2359 # If the `begin` entry corresponding to a negative `size` is loop-variant, 2360 # the output would be ragged. This case is not supported. But `size` having 2361 # some negative values and some loop-variant `begin`s is OK (and it's hard 2362 # to tell the difference statically). 2363 original_unstacked_shape = _stack( 2364 array_ops.shape(t)[1:], pfor_input.pfor.loop_len_vector).t 2365 broadcast_size = _stack(size, pfor_input.pfor.loop_len_vector).t 2366 result_shape = array_ops.where( 2367 math_ops.less(broadcast_size, 0), 2368 original_unstacked_shape - begin + broadcast_size + 1, broadcast_size) 2369 result_shape = math_ops.cast(math_ops.reduce_max(result_shape, axis=0), 2370 dtypes.int64) 2371 2372 # Now we enumerate points in the sliced region for each pfor iteration and 2373 # gather them. 2374 cumsize = math_ops.cumprod(result_shape, exclusive=True, reverse=True) 2375 result_num_elements = math_ops.reduce_prod(result_shape) 2376 # Offsets are loop-variant. We first compute loop-invariant gather 2377 # coordinates, then broadcast-add the loop-variant `begin` offsets. 2378 result_base_coordinates = ( 2379 math_ops.range(result_num_elements, dtype=dtypes.int64)[:, None] 2380 // cumsize[None, :]) % result_shape[None, :] 2381 result_coordinates = ( 2382 begin[:, None, :] 2383 + math_ops.cast(result_base_coordinates, begin.dtype)[None, :, :]) 2384 result_flat = array_ops.gather_nd(params=t, indices=result_coordinates, 2385 batch_dims=1) 2386 result_stacked_shape = array_ops.concat( 2387 [math_ops.cast(pfor_input.pfor.loop_len_vector, result_shape.dtype), 2388 result_shape], 2389 axis=0) 2390 return wrap(array_ops.reshape(result_flat, result_stacked_shape), True) 2391 2392 2393@RegisterPFor("Tile") 2394def _convert_tile(pfor_input): 2395 t = pfor_input.stacked_input(0) 2396 multiples = pfor_input.unstacked_input(1) 2397 multiples = array_ops.concat([[1], multiples], 0) 2398 return wrap(array_ops.tile(t, multiples), True) 2399 2400 2401@RegisterPFor("Pack") 2402def _convert_pack(pfor_input): 2403 pfor_input.stack_inputs() 2404 axis = pfor_input.get_attr("axis") 2405 if axis >= 0: 2406 axis += 1 2407 return wrap( 2408 array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True) 2409 2410 2411@RegisterPFor("Unpack") 2412def _convert_unpack(pfor_input): 2413 value = pfor_input.stacked_input(0) 2414 axis = pfor_input.get_attr("axis") 2415 if axis >= 0: 2416 axis += 1 2417 num = pfor_input.get_attr("num") 2418 return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)] 2419 2420 2421@RegisterPFor("Pad") 2422def _convert_pad(pfor_input): 2423 t = pfor_input.stacked_input(0) 2424 paddings = pfor_input.unstacked_input(1) 2425 paddings = array_ops.concat([[[0, 0]], paddings], 0) 2426 return wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) 2427 2428 2429@RegisterPFor("PadV2") 2430def _convert_pad_v2(pfor_input): 2431 t = pfor_input.stacked_input(0) 2432 paddings = pfor_input.unstacked_input(1) 2433 paddings = array_ops.concat([[[0, 0]], paddings], 0) 2434 return wrap(array_ops.pad_v2(t, paddings, mode="CONSTANT"), True) 2435 2436 2437@RegisterPFor("Split") 2438def _convert_split(pfor_input): 2439 split_dim = pfor_input.unstacked_input(0) 2440 t = pfor_input.stacked_input(1) 2441 num_split = pfor_input.get_attr("num_split") 2442 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2443 return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] 2444 2445 2446@RegisterPFor("SplitV") 2447def _convert_split_v(pfor_input): 2448 t = pfor_input.stacked_input(0) 2449 splits = pfor_input.unstacked_input(1) 2450 split_dim = pfor_input.unstacked_input(2) 2451 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2452 return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] 2453 2454 2455@RegisterPFor("Squeeze") 2456def _convert_squeeze(pfor_input): 2457 t = pfor_input.stacked_input(0) 2458 squeeze_dims = pfor_input.get_attr("squeeze_dims") 2459 squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims] 2460 return wrap(array_ops.squeeze(t, axis=squeeze_dims), True) 2461 2462 2463@RegisterPFor("ReverseV2") 2464def _convert_reverse(pfor_input): 2465 value = pfor_input.stacked_input(0) 2466 axis = pfor_input.unstacked_input(1) 2467 new_axis = array_ops.where_v2(axis >= 0, axis + 1, axis) 2468 return wrap(gen_array_ops.reverse_v2(value, axis=new_axis), True) 2469 2470 2471@RegisterPForWithArgs("Transpose", gen_array_ops.transpose) 2472@RegisterPForWithArgs("ConjugateTranspose", gen_array_ops.conjugate_transpose) 2473def _convert_transpose(pfor_input, _, op_func): 2474 t = pfor_input.stacked_input(0) 2475 perm = pfor_input.unstacked_input(1) 2476 new_perm = array_ops.concat([[0], perm + 1], axis=0) 2477 return wrap(op_func(t, new_perm), True) 2478 2479 2480@RegisterPFor("ZerosLike") 2481def _convert_zeroslike(pfor_input): 2482 t = pfor_input.stacked_input(0) 2483 shape = array_ops.shape(t)[1:] 2484 return wrap(array_ops.zeros(shape, dtype=t.dtype), False) 2485 2486 2487@RegisterPFor("Gather") 2488@RegisterPFor("GatherV2") 2489def _convert_gather(pfor_input): 2490 param, param_stacked, _ = pfor_input.input(0) 2491 indices, indices_stacked, _ = pfor_input.input(1) 2492 batch_dims = pfor_input.get_attr("batch_dims") 2493 2494 op_type = pfor_input.op_type 2495 if op_type == "Gather": 2496 validate_indices = pfor_input.get_attr("validate_indices") 2497 axis = 0 2498 else: 2499 validate_indices = None 2500 # Assume we will never have a Tensor with rank > 2**32. 2501 axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32) 2502 axis_value = tensor_util.constant_value(axis) 2503 if axis_value is not None: 2504 axis = axis_value 2505 if indices_stacked and not param_stacked: 2506 if indices is pfor_input.pfor.all_indices and axis == 0: 2507 param_shape0 = tensor_shape.dimension_value(param.shape[0]) 2508 indices_shape0 = tensor_shape.dimension_value(indices.shape[0]) 2509 if param_shape0 is not None and indices_shape0 == param_shape0: 2510 # Note that with loops and conditionals, indices may not be contiguous. 2511 # However they will be sorted and unique. So if the shape matches, then 2512 # it must be picking up all the rows of param. 2513 return wrap(param, True) 2514 2515 if batch_dims != 0: 2516 # Convert `batch_dims` to its positive equivalent if necessary. 2517 batch_dims_pos = batch_dims 2518 if batch_dims < 0: 2519 batch_dims_pos += array_ops.rank(indices) 2520 # In order to maintain 2521 # indices.shape[:batch_dims] == params.shape[:batch_dims] 2522 # with stacked indices, we move the first dimension of `indices` to the 2523 # `batch_dims + 1`th position. The (non-batch) index dimensions will be 2524 # inserted into the shape of `output` at the `axis` dimension, which is 2525 # then transposed to the front (below). 2526 order = array_ops.concat([ 2527 math_ops.range(1, batch_dims_pos + 1), 2528 [0], 2529 math_ops.range(batch_dims_pos + 1, array_ops.rank(indices))], axis=0) 2530 indices = array_ops.transpose(indices, order) 2531 2532 output = array_ops.gather( 2533 param, indices, validate_indices=validate_indices, axis=axis, 2534 batch_dims=batch_dims) 2535 if axis != 0: 2536 axis = smart_cond.smart_cond(axis < 0, 2537 lambda: axis + array_ops.rank(param), 2538 lambda: ops.convert_to_tensor(axis)) 2539 order = array_ops.concat( 2540 [[axis], 2541 math_ops.range(axis), 2542 math_ops.range(axis + 1, array_ops.rank(output))], 2543 axis=0) 2544 output = smart_cond.smart_cond( 2545 math_ops.equal(axis, 0), lambda: output, 2546 lambda: array_ops.transpose(output, order)) 2547 return wrap(output, True) 2548 if param_stacked: 2549 pfor_input.stack_inputs(stack_indices=[1]) 2550 indices = pfor_input.stacked_input(1) 2551 if isinstance(axis, ops.Tensor): 2552 axis = array_ops.where(axis >= 0, axis + 1, axis) 2553 else: 2554 axis = axis + 1 if axis >= 0 else axis 2555 batch_dims = batch_dims + 1 if batch_dims >= 0 else batch_dims 2556 output = array_ops.gather(param, indices, axis=axis, batch_dims=batch_dims) 2557 return wrap(output, True) 2558 2559 2560@RegisterPFor("GatherNd") 2561def _convert_gather_nd(pfor_input): 2562 # TODO(jmenick): Add support for unstacked params. 2563 pfor_input.stack_inputs(stack_indices=[1]) 2564 params = pfor_input.stacked_input(0) 2565 indices = pfor_input.stacked_input(1) 2566 stacked_result = array_ops.gather_nd(params, indices, batch_dims=1) 2567 return wrap(stacked_result, True) 2568 2569 2570@RegisterPFor("ConcatV2") 2571def _convert_concatv2(pfor_input): 2572 n = pfor_input.num_inputs 2573 pfor_input.stack_inputs(stack_indices=range(n - 1)) 2574 axis = pfor_input.unstacked_input(n - 1) 2575 axis += math_ops.cast(axis >= 0, axis.dtype) 2576 return wrap( 2577 array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), 2578 True) 2579 2580 2581@RegisterPFor("StridedSlice") 2582def _convert_strided_slice(pfor_input): 2583 inp = pfor_input.stacked_input(0) 2584 begin = pfor_input.unstacked_input(1) 2585 end = pfor_input.unstacked_input(2) 2586 strides = pfor_input.unstacked_input(3) 2587 begin_mask = pfor_input.get_attr("begin_mask") 2588 end_mask = pfor_input.get_attr("end_mask") 2589 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 2590 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2591 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2592 2593 begin = array_ops.concat([[0], begin], axis=0) 2594 end = array_ops.concat([[0], end], axis=0) 2595 strides = array_ops.concat([[1], strides], axis=0) 2596 begin_mask = begin_mask << 1 | 1 2597 end_mask = end_mask << 1 | 1 2598 ellipsis_mask <<= 1 2599 new_axis_mask <<= 1 2600 shrink_axis_mask <<= 1 2601 return wrap( 2602 array_ops.strided_slice( 2603 inp, 2604 begin, 2605 end, 2606 strides, 2607 begin_mask=begin_mask, 2608 end_mask=end_mask, 2609 ellipsis_mask=ellipsis_mask, 2610 new_axis_mask=new_axis_mask, 2611 shrink_axis_mask=shrink_axis_mask), True) 2612 2613 2614@RegisterPFor("StridedSliceGrad") 2615def _convert_strided_slice_grad(pfor_input): 2616 shape = pfor_input.unstacked_input(0) 2617 begin = pfor_input.unstacked_input(1) 2618 end = pfor_input.unstacked_input(2) 2619 strides = pfor_input.unstacked_input(3) 2620 dy = pfor_input.stacked_input(4) 2621 begin_mask = pfor_input.get_attr("begin_mask") 2622 end_mask = pfor_input.get_attr("end_mask") 2623 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 2624 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2625 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2626 2627 shape = array_ops.concat( 2628 [math_ops.cast(pfor_input.pfor.loop_len_vector, shape.dtype), shape], 2629 axis=0) 2630 begin = array_ops.concat([[0], begin], axis=0) 2631 end = array_ops.concat([[0], end], axis=0) 2632 strides = array_ops.concat([[1], strides], axis=0) 2633 begin_mask = begin_mask << 1 | 1 2634 end_mask = end_mask << 1 | 1 2635 ellipsis_mask <<= 1 2636 new_axis_mask <<= 1 2637 shrink_axis_mask <<= 1 2638 return wrap( 2639 array_ops.strided_slice_grad( 2640 shape, 2641 begin, 2642 end, 2643 strides, 2644 dy, 2645 begin_mask=begin_mask, 2646 end_mask=end_mask, 2647 ellipsis_mask=ellipsis_mask, 2648 new_axis_mask=new_axis_mask, 2649 shrink_axis_mask=shrink_axis_mask), True) 2650 2651 2652@RegisterPFor("CheckNumerics") 2653def _convert_check_numerics(pfor_input): 2654 t = pfor_input.stacked_input(0) 2655 message = pfor_input.get_attr("message") 2656 return wrap(gen_array_ops.check_numerics(t, message), True) 2657 2658 2659@RegisterPFor("EnsureShape") 2660def _convert_ensure_shape(pfor_input): 2661 t = pfor_input.stacked_input(0) 2662 shape = tensor_shape.TensorShape(pfor_input.get_attr("shape")) 2663 return wrap(gen_array_ops.ensure_shape(t, [None] + shape), True) 2664 2665 2666# manip_ops 2667 2668 2669@RegisterPFor("Roll") 2670def _convert_roll(pfor_input): 2671 t = pfor_input.stacked_input(0) 2672 shift, shift_stacked, _ = pfor_input.input(1) 2673 axis = pfor_input.unstacked_input(2) 2674 if not shift_stacked: 2675 return wrap(manip_ops.roll(t, shift, axis + 1), True) 2676 else: 2677 # `axis` and `shift` may both be vectors, with repeated axes summing the 2678 # corresponding `shift`s. We scatter shifts into a dense array of shape 2679 # [loop_len, num_unstacked_axes] indicating the offset for each axis. 2680 num_unstacked_axes = math_ops.cast(array_ops.rank(t), dtypes.int64) - 1 2681 axis = math_ops.cast(array_ops.reshape(axis, [-1]), dtypes.int64) 2682 loop_len = math_ops.cast(pfor_input.pfor.loop_len_vector[0], dtypes.int64) 2683 shift = math_ops.cast(array_ops.reshape(shift, [loop_len, -1]), 2684 dtypes.int64) 2685 axis_segment_ids = ( 2686 math_ops.range(loop_len, dtype=dtypes.int64)[:, None] 2687 * num_unstacked_axes + axis[None, :]) 2688 axis_offsets = array_ops.reshape( 2689 math_ops.unsorted_segment_sum( 2690 data=shift, segment_ids=axis_segment_ids, 2691 num_segments=loop_len * num_unstacked_axes), 2692 [loop_len, num_unstacked_axes]) 2693 2694 # Determine the coordinates in the input array of each result and gather 2695 # them. 2696 unstacked_shape = array_ops.shape(t, out_type=dtypes.int64)[1:] 2697 cumsize = math_ops.cumprod(unstacked_shape, exclusive=True, reverse=True) 2698 num_unstacked_elements = math_ops.reduce_prod(unstacked_shape) 2699 result_coordinates = ( 2700 (math_ops.range(num_unstacked_elements, 2701 dtype=dtypes.int64)[None, :, None] 2702 // cumsize[None, None, :] - axis_offsets[:, None, :]) 2703 % unstacked_shape[None, None, :]) 2704 result_flat = array_ops.gather_nd(params=t, indices=result_coordinates, 2705 batch_dims=1) 2706 return wrap(array_ops.reshape(result_flat, array_ops.shape(t)), 2707 True) 2708 2709# math_ops 2710 2711 2712@RegisterPFor("MatMul") 2713def _convert_matmul(pfor_input): 2714 # TODO(agarwal): Check if tiling is faster than two transposes. 2715 a, a_stacked, _ = pfor_input.input(0) 2716 b, b_stacked, _ = pfor_input.input(1) 2717 tr_a = pfor_input.get_attr("transpose_a") 2718 tr_b = pfor_input.get_attr("transpose_b") 2719 if a_stacked and b_stacked: 2720 output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) 2721 return output 2722 elif a_stacked: 2723 if tr_a: 2724 a = array_ops.transpose(a, [0, 2, 1]) 2725 if a.shape.is_fully_defined(): 2726 x, y, z = a.shape 2727 else: 2728 x, y, z = [ 2729 array_ops.reshape(i, []) 2730 for i in array_ops.split(array_ops.shape(a), 3) 2731 ] 2732 a = array_ops.reshape(a, [x * y, z]) 2733 prod = math_ops.matmul(a, b, transpose_b=tr_b) 2734 return wrap(array_ops.reshape(prod, [x, y, -1]), True) 2735 else: 2736 assert b_stacked 2737 if tr_b: 2738 perm = [2, 0, 1] 2739 b = array_ops.transpose(b, perm) 2740 else: 2741 # As an optimization, if one of the first two dimensions is 1, then we can 2742 # reshape instead of transpose. 2743 # TODO(agarwal): This check can be done inside Transpose kernel. 2744 b_shape = array_ops.shape(b) 2745 min_dim = math_ops.minimum(b_shape[0], b_shape[1]) 2746 perm = array_ops.where( 2747 math_ops.equal(min_dim, 1), [0, 1, 2], [1, 0, 2]) 2748 new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]]) 2749 b = array_ops.transpose(b, perm) 2750 b = array_ops.reshape(b, new_shape) 2751 2752 if b.shape.is_fully_defined(): 2753 x, y, z = b.shape 2754 else: 2755 x, y, z = [ 2756 array_ops.reshape(i, []) 2757 for i in array_ops.split(array_ops.shape(b), 3) 2758 ] 2759 b = array_ops.reshape(b, [x, y * z]) 2760 prod = math_ops.matmul(a, b, transpose_a=tr_a) 2761 prod = array_ops.reshape(prod, [-1, y, z]) 2762 prod = array_ops.transpose(prod, [1, 0, 2]) 2763 return wrap(prod, True) 2764 2765 2766# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window 2767# is met. 2768@RegisterPFor("BatchMatMul") 2769def _convert_batch_mat_mul(pfor_input): 2770 # TODO(agarwal): There may be a more efficient way to do this instead of 2771 # stacking the inputs. 2772 pfor_input.stack_inputs() 2773 x = pfor_input.stacked_input(0) 2774 y = pfor_input.stacked_input(1) 2775 adj_x = pfor_input.get_attr("adj_x") 2776 adj_y = pfor_input.get_attr("adj_y") 2777 2778 x = _flatten_first_two_dims(x) 2779 y = _flatten_first_two_dims(y) 2780 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2781 output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) 2782 return wrap(output, True) 2783 2784 2785@RegisterPFor("BatchMatMulV2") 2786def _convert_batch_mat_mul_v2(pfor_input): 2787 pfor_input.expanddim_inputs_for_broadcast() 2788 x = pfor_input.input(0)[0] 2789 y = pfor_input.input(1)[0] 2790 adj_x = pfor_input.get_attr("adj_x") 2791 adj_y = pfor_input.get_attr("adj_y") 2792 2793 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2794 return wrap(output, True) 2795 2796 2797@RegisterPForWithArgs("Sum", math_ops.reduce_sum) 2798@RegisterPForWithArgs("Prod", math_ops.reduce_prod) 2799@RegisterPForWithArgs("Max", math_ops.reduce_max) 2800@RegisterPForWithArgs("Min", math_ops.reduce_min) 2801@RegisterPForWithArgs("Mean", math_ops.reduce_mean) 2802@RegisterPForWithArgs("All", math_ops.reduce_all) 2803@RegisterPForWithArgs("Any", math_ops.reduce_any) 2804def _convert_reduction(pfor_input, _, op_func): 2805 t = pfor_input.stacked_input(0) 2806 indices = pfor_input.unstacked_input(1) 2807 # Shift positive indices by one to account for the extra dimension. 2808 indices += math_ops.cast(indices >= 0, indices.dtype) 2809 keep_dims = pfor_input.get_attr("keep_dims") 2810 return wrap(op_func(t, indices, keepdims=keep_dims), True) 2811 2812 2813@RegisterPForWithArgs("ArgMax", math_ops.argmax) 2814@RegisterPForWithArgs("ArgMin", math_ops.argmin) 2815def _convert_argmax_argmin(pfor_input, _, op_func): 2816 t = pfor_input.stacked_input(0) 2817 dimension = pfor_input.unstacked_input(1) 2818 dimension += math_ops.cast(dimension >= 0, dimension.dtype) 2819 output_type = pfor_input.get_attr("output_type") 2820 return wrap(op_func(t, axis=dimension, output_type=output_type), True) 2821 2822 2823@RegisterPFor("Bucketize") 2824def _convert_bucketize(pfor_input): 2825 t = pfor_input.stacked_input(0) 2826 boundaries = pfor_input.get_attr("boundaries") 2827 return wrap(math_ops.bucketize(t, boundaries), True) 2828 2829 2830@RegisterPFor("ClipByValue") 2831def _convert_clip_by_value(pfor_input): 2832 t = pfor_input.stacked_input(0) 2833 clip_value_min = pfor_input.unstacked_input(1) 2834 clip_value_max = pfor_input.unstacked_input(2) 2835 return wrap(gen_math_ops.clip_by_value(t, clip_value_min, clip_value_max), 2836 True) 2837 2838 2839@RegisterPForWithArgs("Cumsum", math_ops.cumsum) 2840@RegisterPForWithArgs("Cumprod", math_ops.cumprod) 2841def _convert_cumfoo(pfor_input, _, op_func): 2842 t = pfor_input.stacked_input(0) 2843 axis = pfor_input.unstacked_input(1) 2844 # Shift positive indices by one to account for the extra dimension. 2845 axis += math_ops.cast(axis >= 0, axis.dtype) 2846 exclusive = pfor_input.get_attr("exclusive") 2847 reverse = pfor_input.get_attr("reverse") 2848 return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) 2849 2850 2851@RegisterPFor("BiasAdd") 2852def _convert_biasadd(pfor_input): 2853 t, t_stacked, _ = pfor_input.input(0) 2854 bias, bias_stacked, _ = pfor_input.input(1) 2855 data_format = pfor_input.get_attr("data_format").decode() 2856 if bias_stacked: 2857 # BiasAdd only supports 1-D biases, so cast bias to match value and use Add. 2858 pfor_input.expanddim_inputs_for_broadcast() 2859 t, _, _ = pfor_input.input(0) 2860 bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype) 2861 if compat.as_bytes(data_format) == b"NCHW": 2862 b_shape = array_ops.shape(bias) 2863 new_b_shape = array_ops.concat( 2864 [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0) 2865 bias = array_ops.reshape(bias, new_b_shape) 2866 return wrap(math_ops.add(t, bias), True) 2867 else: 2868 assert t_stacked, "At least one input to BiasAdd should be loop variant." 2869 if compat.as_bytes(data_format) == b"NCHW": 2870 shape = array_ops.shape(t) 2871 flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) 2872 t = array_ops.reshape(t, flattened_shape) 2873 t = nn_ops.bias_add(t, bias, data_format="NCHW") 2874 t = array_ops.reshape(t, shape) 2875 return wrap(t, True) 2876 return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) 2877 2878 2879@RegisterPForWithArgs("UnsortedSegmentSum", math_ops.unsorted_segment_sum) 2880@RegisterPForWithArgs("UnsortedSegmentMax", math_ops.unsorted_segment_max) 2881@RegisterPForWithArgs("UnsortedSegmentMin", math_ops.unsorted_segment_min) 2882@RegisterPForWithArgs("UnsortedSegmentProd", math_ops.unsorted_segment_prod) 2883def _convert_unsortedsegmentsum(pfor_input, _, op_func): 2884 pfor_input.stack_inputs([0, 1]) 2885 data = pfor_input.stacked_input(0) 2886 segment_ids = pfor_input.stacked_input(1) 2887 # TODO(agarwal): handle stacked? 2888 num_segments = pfor_input.unstacked_input(2) 2889 if segment_ids.dtype != num_segments.dtype: 2890 segment_ids = math_ops.cast(segment_ids, dtypes.int64) 2891 num_segments = math_ops.cast(num_segments, dtypes.int64) 2892 dtype = segment_ids.dtype 2893 segment_shape = array_ops.shape(segment_ids, out_type=dtype) 2894 n = segment_shape[0] 2895 ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:] 2896 segment_offset = num_segments * math_ops.range(n, dtype=dtype) 2897 segment_offset = array_ops.reshape(segment_offset, 2898 array_ops.concat([[n], ones], axis=0)) 2899 segment_ids += segment_offset 2900 num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast( 2901 n, dtypes.int64) 2902 output = op_func(data, segment_ids, num_segments) 2903 new_output_shape = array_ops.concat( 2904 [[n, -1], array_ops.shape(output)[1:]], axis=0) 2905 output = array_ops.reshape(output, new_output_shape) 2906 return wrap(output, True) 2907 2908 2909def _flatten_array_with_offset(ids, offset_delta, num_rows): 2910 """Flattens a rank 2 tensor, adding an offset to each row.""" 2911 # Note that if `ids` is rank 1, it is broadcast to rank 2. 2912 offset_delta = math_ops.cast(offset_delta, ids.dtype) 2913 n = math_ops.cast(num_rows, dtype=ids.dtype) 2914 offsets = math_ops.range( 2915 start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype) 2916 offsets = array_ops.expand_dims(offsets, -1) 2917 ids += offsets 2918 return array_ops.reshape(ids, [-1]) 2919 2920 2921@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2) 2922@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2) 2923@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2) 2924@RegisterPForWithArgs("SparseSegmentSumWithNumSegments", 2925 math_ops.sparse_segment_sum_v2) 2926@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments", 2927 math_ops.sparse_segment_mean_v2) 2928@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments", 2929 math_ops.sparse_segment_sqrt_n_v2) 2930def _convert_sparse_segment(pfor_input, _, op_func): 2931 _, segment_ids_stacked, _ = pfor_input.input(2) 2932 if segment_ids_stacked: 2933 pfor_input.stack_inputs([1]) 2934 data, data_stacked, _ = pfor_input.input(0) 2935 indices, _, _ = pfor_input.input(1) 2936 num_inputs = len(pfor_input.inputs) 2937 assert num_inputs in (3, 4) 2938 if num_inputs == 3: 2939 # `segment_ids` needs to be unstacked since otherwise output sizes could 2940 # differ across pfor iterations. 2941 segment_ids = pfor_input.unstacked_input(2) 2942 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2943 else: 2944 segment_ids, _, _ = pfor_input.input(2) 2945 num_segments = pfor_input.unstacked_input(3) 2946 2947 n = pfor_input.pfor.loop_len_vector[0] 2948 if data_stacked: 2949 indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n) 2950 data = _flatten_first_two_dims(data) 2951 else: 2952 indices = array_ops.reshape(indices, [-1]) 2953 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2954 2955 if num_inputs == 3: 2956 num_segments = None 2957 else: 2958 num_segments *= n 2959 output = op_func(data, indices, segment_ids, num_segments=num_segments) 2960 output = _unflatten_first_dim(output, [n]) 2961 return wrap(output, True) 2962 2963 2964@RegisterPForWithArgs("SparseSegmentSumGrad", math_ops.sparse_segment_sum_grad) 2965@RegisterPForWithArgs("SparseSegmentMeanGrad", 2966 math_ops.sparse_segment_mean_grad) 2967@RegisterPForWithArgs("SparseSegmentSqrtNGrad", 2968 math_ops.sparse_segment_sqrt_n_grad) 2969def _convert_sparse_segment_grad(pfor_input, _, op_func): 2970 grad = pfor_input.stacked_input(0) 2971 indices = pfor_input.unstacked_input(1) 2972 segment_ids = pfor_input.unstacked_input(2) 2973 dim0 = pfor_input.unstacked_input(3) 2974 2975 n = pfor_input.pfor.loop_len_vector[0] 2976 indices = _flatten_array_with_offset(indices, dim0, n) 2977 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2978 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2979 grad = _flatten_first_two_dims(grad) 2980 dim0 *= n 2981 output = op_func(grad, indices, segment_ids, dim0) 2982 output = _unflatten_first_dim(output, [n]) 2983 return wrap(output, True) 2984 2985 2986@RegisterPFor("Cast") 2987def _convert_cast(pfor_input): 2988 inp = pfor_input.stacked_input(0) 2989 dtype = pfor_input.get_attr("DstT") 2990 return wrap(math_ops.cast(inp, dtype), True) 2991 2992 2993@RegisterPFor("Abs") 2994@RegisterPFor("Acos") 2995@RegisterPFor("Acosh") 2996@RegisterPFor("Add") 2997@RegisterPFor("AddV2") 2998@RegisterPFor("Angle") 2999@RegisterPFor("Asin") 3000@RegisterPFor("Asinh") 3001@RegisterPFor("Atan") 3002@RegisterPFor("Atan2") 3003@RegisterPFor("Atanh") 3004@RegisterPFor("BesselI0") 3005@RegisterPFor("BesselI1") 3006@RegisterPFor("BesselI0e") 3007@RegisterPFor("BesselI1e") 3008@RegisterPFor("BesselK0") 3009@RegisterPFor("BesselK1") 3010@RegisterPFor("BesselK0e") 3011@RegisterPFor("BesselK1e") 3012@RegisterPFor("BesselJ0") 3013@RegisterPFor("BesselJ1") 3014@RegisterPFor("BesselY0") 3015@RegisterPFor("BesselY1") 3016@RegisterPFor("BitwiseAnd") 3017@RegisterPFor("BitwiseOr") 3018@RegisterPFor("BitwiseXor") 3019@RegisterPFor("Ceil") 3020@RegisterPFor("Complex") 3021@RegisterPFor("ComplexAbs") 3022@RegisterPFor("Conj") 3023@RegisterPFor("Cos") 3024@RegisterPFor("Cosh") 3025@RegisterPFor("Dawsn") 3026@RegisterPFor("Digamma") 3027@RegisterPFor("Div") 3028@RegisterPFor("DivNoNan") 3029@RegisterPFor("Elu") 3030@RegisterPFor("Erf") 3031@RegisterPFor("Erfc") 3032@RegisterPFor("Erfinv") 3033@RegisterPFor("Exp") 3034@RegisterPFor("Expint") 3035@RegisterPFor("Expm1") 3036@RegisterPFor("Floor") 3037@RegisterPFor("FloorDiv") 3038@RegisterPFor("FloorMod") 3039@RegisterPFor("FresnelCos") 3040@RegisterPFor("FresnelSin") 3041@RegisterPFor("Greater") 3042@RegisterPFor("GreaterEqual") 3043@RegisterPFor("Igamma") 3044@RegisterPFor("IgammaGradA") 3045@RegisterPFor("Igammac") 3046@RegisterPFor("Imag") 3047@RegisterPFor("Inv") 3048@RegisterPFor("Invert") 3049@RegisterPFor("IsFinite") 3050@RegisterPFor("IsInf") 3051@RegisterPFor("IsNan") 3052@RegisterPFor("LeftShift") 3053@RegisterPFor("Less") 3054@RegisterPFor("LessEqual") 3055@RegisterPFor("Lgamma") 3056@RegisterPFor("Log") 3057@RegisterPFor("Log1p") 3058@RegisterPFor("LogicalAnd") 3059@RegisterPFor("LogicalNot") 3060@RegisterPFor("LogicalOr") 3061@RegisterPFor("LogicalXor") 3062@RegisterPFor("Maximum") 3063@RegisterPFor("Minimum") 3064@RegisterPFor("Mod") 3065@RegisterPFor("Mul") 3066@RegisterPFor("MulNoNan") 3067@RegisterPFor("Ndtri") 3068@RegisterPFor("Neg") 3069@RegisterPFor("Polygamma") 3070@RegisterPFor("Pow") 3071@RegisterPFor("Real") 3072@RegisterPFor("RealDiv") 3073@RegisterPFor("Reciprocal") 3074@RegisterPFor("Relu") 3075@RegisterPFor("Relu6") 3076@RegisterPFor("RightShift") 3077@RegisterPFor("Rint") 3078@RegisterPFor("Round") 3079@RegisterPFor("Rsqrt") 3080@RegisterPFor("Selu") 3081@RegisterPFor("Sigmoid") 3082@RegisterPFor("Sign") 3083@RegisterPFor("Sin") 3084@RegisterPFor("Sinh") 3085@RegisterPFor("Softplus") 3086@RegisterPFor("Softsign") 3087@RegisterPFor("Spence") 3088@RegisterPFor("Sqrt") 3089@RegisterPFor("Square") 3090@RegisterPFor("SquaredDifference") 3091@RegisterPFor("Sub") 3092@RegisterPFor("Tan") 3093@RegisterPFor("Tanh") 3094@RegisterPFor("TruncateDiv") 3095@RegisterPFor("TruncateMod") 3096@RegisterPFor("Xdivy") 3097@RegisterPFor("Xlogy") 3098@RegisterPFor("Xlog1py") 3099@RegisterPFor("Zeta") 3100def _convert_cwise(pfor_input): 3101 if pfor_input.num_inputs > 1: 3102 pfor_input.expanddim_inputs_for_broadcast() 3103 3104 out = _create_op( 3105 pfor_input.op_type, [x.t for x in pfor_input.inputs], 3106 [x.dtype for x in pfor_input.outputs], 3107 attrs=pfor_input.op.node_def.attr).outputs 3108 assert len(out) == 1 3109 out = out[0] 3110 3111 op_output = wrap(out, True) 3112 return op_output 3113 3114 3115@RegisterPFor("XlaSharding") 3116def _convert_xla_sharding(pfor_input): 3117 t = pfor_input.stacked_input(0) 3118 sharding = pfor_input.get_attr("sharding") 3119 return wrap(xla.sharding(t, sharding=sharding), True) 3120 3121 3122@RegisterPFor("LeakyRelu") 3123def _convert_leaky_relu(pfor_input): 3124 t = pfor_input.stacked_input(0) 3125 alpha = pfor_input.get_attr("alpha") 3126 return wrap(gen_nn_ops.leaky_relu(t, alpha=alpha), True) 3127 3128 3129@RegisterPFor("Equal") 3130def _convert_equal(pfor_input): 3131 pfor_input.expanddim_inputs_for_broadcast() 3132 x = pfor_input.input(0)[0] 3133 y = pfor_input.input(1)[0] 3134 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 3135 return wrap(gen_math_ops.equal( 3136 x, y, incompatible_shape_error=incompatible_shape_error), True) 3137 3138 3139@RegisterPFor("NotEqual") 3140def _convert_not_equal(pfor_input): 3141 pfor_input.expanddim_inputs_for_broadcast() 3142 x = pfor_input.input(0)[0] 3143 y = pfor_input.input(1)[0] 3144 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 3145 return wrap(gen_math_ops.not_equal( 3146 x, y, incompatible_shape_error=incompatible_shape_error), True) 3147 3148 3149@RegisterPFor("ApproximateEqual") 3150def _convert_approximate_equal(pfor_input): 3151 pfor_input.expanddim_inputs_for_broadcast() 3152 x = pfor_input.input(0)[0] 3153 y = pfor_input.input(1)[0] 3154 tolerance = pfor_input.get_attr("tolerance") 3155 return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True) 3156 3157 3158@RegisterPFor("Shape") 3159def _convert_shape(pfor_input): 3160 out_type = pfor_input.get_attr("out_type") 3161 return wrap( 3162 array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], 3163 False) 3164 3165 3166@RegisterPFor("ShapeN") 3167def _convert_shape_n(pfor_input): 3168 out_type = pfor_input.get_attr("out_type") 3169 shapes = [ 3170 array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape( 3171 x, out_type=out_type) for x, stacked, _ in pfor_input.inputs 3172 ] 3173 return [wrap(x, False) for x in shapes] 3174 3175 3176@RegisterPFor("Size") 3177def _convert_size(pfor_input): 3178 out_type = pfor_input.get_attr("out_type") 3179 n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) 3180 return wrap( 3181 array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, 3182 False) 3183 3184 3185@RegisterPFor("Rank") 3186def _convert_rank(pfor_input): 3187 return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) 3188 3189 3190@RegisterPFor("AddN") 3191def _convert_addn(pfor_input): 3192 # AddN does not support broadcasting. 3193 pfor_input.stack_inputs(tile_variants=False) 3194 return _wrap_and_tile_variants( 3195 math_ops.add_n([x.t for x in pfor_input.inputs]), 3196 pfor_input.pfor.loop_len_vector) 3197 3198 3199@RegisterPFor("Cross") 3200def _convert_cross(pfor_input): 3201 pfor_input.stack_inputs() 3202 a = pfor_input.stacked_input(0) 3203 b = pfor_input.stacked_input(1) 3204 return wrap(math_ops.cross(a, b), True) 3205 3206 3207@RegisterPFor("BiasAddGrad") 3208def _convert_biasaddgrad(pfor_input): 3209 grad = pfor_input.stacked_input(0) 3210 fmt = pfor_input.get_attr("data_format") 3211 if fmt == b"NCHW": 3212 output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) 3213 else: 3214 grad_shape = array_ops.shape(grad) 3215 last_dim_shape = grad_shape[-1] 3216 first_dim_shape = grad_shape[0] 3217 output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) 3218 output = math_ops.reduce_sum(output, axis=[1], keepdims=False) 3219 return wrap(output, True) 3220 3221 3222# Some required ops are not exposed under the tf namespace. Hence relying on 3223# _create_op to create them. 3224@RegisterPForWithArgs("EluGrad") 3225@RegisterPForWithArgs("LeakyReluGrad") 3226@RegisterPForWithArgs("ReciprocalGrad") 3227@RegisterPForWithArgs("Relu6Grad") 3228@RegisterPForWithArgs("ReluGrad") 3229@RegisterPForWithArgs("RsqrtGrad") 3230@RegisterPForWithArgs("SeluGrad") 3231@RegisterPForWithArgs("SigmoidGrad") 3232@RegisterPForWithArgs("SoftplusGrad") 3233@RegisterPForWithArgs("SoftsignGrad") 3234@RegisterPForWithArgs("SqrtGrad") 3235@RegisterPForWithArgs("TanhGrad") 3236def _convert_grads(pfor_input, op_type, *args, **kw_args): 3237 del args 3238 del kw_args 3239 # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we 3240 # have to use tiling here. 3241 pfor_input.stack_inputs() 3242 outputs = _create_op( 3243 op_type, [x.t for x in pfor_input.inputs], 3244 [x.dtype for x in pfor_input.outputs], 3245 attrs=pfor_input.op.node_def.attr).outputs 3246 return [wrap(x, True) for x in outputs] 3247 3248 3249@RegisterPFor("Select") 3250def _convert_select(pfor_input): 3251 pfor_input.stack_inputs() 3252 cond = pfor_input.stacked_input(0) 3253 t = pfor_input.stacked_input(1) 3254 e = pfor_input.stacked_input(2) 3255 cond_rank = array_ops.rank(cond) 3256 cond, t, e = smart_cond.smart_cond( 3257 cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), 3258 lambda: [cond, t, e]) 3259 outputs = _create_op( 3260 pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], 3261 attrs=pfor_input.op.node_def.attr).outputs 3262 n = pfor_input.pfor.loop_len_vector 3263 out = smart_cond.smart_cond(cond_rank > 1, 3264 lambda: _unflatten_first_dim(outputs[0], n), 3265 lambda: outputs[0]) 3266 return [wrap(out, True) for x in outputs] 3267 3268 3269@RegisterPFor("SelectV2") 3270def _convert_selectv2(pfor_input): 3271 pfor_input.expanddim_inputs_for_broadcast() 3272 cond = pfor_input.input(0)[0] 3273 t = pfor_input.input(1)[0] 3274 e = pfor_input.input(2)[0] 3275 out = array_ops.where_v2(cond, t, e) 3276 return wrap(out, True) 3277 3278 3279# random_ops 3280 3281 3282def _transpose_dim_to_front(x, dim): 3283 rank = array_ops.rank(x) 3284 return array_ops.transpose( 3285 x, 3286 perm=array_ops.concat( 3287 [[dim], math_ops.range(0, dim), 3288 math_ops.range(dim + 1, rank)], 3289 axis=0)) 3290 3291 3292@RegisterPForWithArgs("RandomUniform") 3293@RegisterPForWithArgs("RandomUniformInt") 3294@RegisterPForWithArgs("RandomStandardNormal") 3295@RegisterPForWithArgs("TruncatedNormal") 3296def _convert_random(pfor_input, op_type, *args, **kw_args): 3297 del args 3298 del kw_args 3299 inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] 3300 # inputs[0] is "shape" 3301 inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]], 3302 axis=0) 3303 # TODO(b/222761732): Turn this warning back on when legacy RNGs are 3304 # deprecated. 3305 # logging.warning( 3306 # "Note that %s inside pfor op may not give same output as " 3307 # "inside a sequential loop.", op_type) 3308 outputs = _create_op( 3309 op_type, 3310 inputs, [x.dtype for x in pfor_input.outputs], 3311 attrs=pfor_input.op.node_def.attr).outputs 3312 return [wrap(x, True) for x in outputs] 3313 3314 3315@RegisterPFor("RandomGamma") 3316@RegisterPFor("RandomPoissonV2") 3317def _convert_random_with_param(pfor_input): 3318 shape = pfor_input.unstacked_input(0) 3319 # param is lam (Poisson rate) or alpha (Gamma shape). 3320 param, param_stacked, _ = pfor_input.input(1) 3321 # TODO(b/222761732): Turn this warning back on when legacy RNGs are 3322 # deprecated. 3323 # logging.warning( 3324 # "Note that %s inside pfor op may not give same output as " 3325 # "inside a sequential loop.", pfor_input.op_type) 3326 3327 if param_stacked: 3328 samples = _create_op( 3329 pfor_input.op_type, 3330 inputs=[shape, param], 3331 op_dtypes=[x.dtype for x in pfor_input.outputs], 3332 attrs=pfor_input.op.node_def.attr).outputs[0] 3333 loop_dim = array_ops.shape(shape)[0] 3334 stacked_samples = _transpose_dim_to_front(samples, loop_dim) 3335 else: 3336 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 3337 stacked_samples = _create_op( 3338 pfor_input.op_type, 3339 inputs=[shape, param], 3340 op_dtypes=[x.dtype for x in pfor_input.outputs], 3341 attrs=pfor_input.op.node_def.attr).outputs[0] 3342 3343 return wrap(stacked_samples, True) 3344 3345 3346@RegisterPFor("Multinomial") 3347def _convert_multinomial(pfor_input): 3348 logits, logits_stacked, _ = pfor_input.input(0) 3349 num_samples = pfor_input.unstacked_input(1) 3350 seed = pfor_input.get_attr("seed") 3351 seed2 = pfor_input.get_attr("seed2") 3352 output_dtype = pfor_input.get_attr("output_dtype") 3353 # TODO(b/222761732): Turn this warning back on when legacy RNGs are 3354 # deprecated. 3355 # logging.warning( 3356 # "Note that Multinomial inside pfor op may not give same output as " 3357 # "inside a sequential loop.") 3358 3359 n = pfor_input.pfor.loop_len_vector[0] 3360 if logits_stacked: 3361 flattened_logits = _flatten_first_two_dims(logits) 3362 samples = gen_random_ops.multinomial( 3363 flattened_logits, 3364 num_samples, 3365 seed=seed, 3366 seed2=seed2, 3367 output_dtype=output_dtype) 3368 stacked_samples = _unflatten_first_dim(samples, [n]) 3369 else: 3370 samples = gen_random_ops.multinomial( 3371 logits, 3372 num_samples * n, 3373 seed=seed, 3374 seed2=seed2, 3375 output_dtype=output_dtype) 3376 stacked_samples = array_ops.transpose( 3377 array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2]) 3378 3379 return wrap(stacked_samples, True) 3380 3381 3382@RegisterPFor("StatelessMultinomial") 3383@RegisterPFor("StatelessParameterizedTruncatedNormal") 3384@RegisterPFor("StatelessRandomBinomial") 3385@RegisterPFor("StatelessRandomGammaV2") 3386@RegisterPFor("StatelessRandomNormal") 3387@RegisterPFor("StatelessRandomPoisson") 3388@RegisterPFor("StatelessRandomUniform") 3389@RegisterPFor("StatelessRandomUniformInt") 3390@RegisterPFor("StatelessRandomUniformFullInt") 3391@RegisterPFor("StatelessTruncatedNormal") 3392def _convert_stateless_multinomial(pfor_input): 3393 # Unlike stateful random ops, for stateless ones we want better 3394 # reproducibility based on seed. Hence we don't want to use a similar strategy 3395 # as used for stateful ones where we generate a possibly different set of 3396 # random numbers under vectorization. 3397 # Unfortunately, the kernels currently are not necessarily setup to do this 3398 # efficiently and hence we fallback to a sequential loop for vectorization. 3399 return _fallback_converter(pfor_input, warn=False) 3400 3401 3402# linalg_ops 3403 3404 3405@RegisterPForWithArgs("XlaEinsum") 3406@RegisterPForWithArgs("Einsum") 3407def _convert_einsum(pfor_input, op_type): 3408 # Einsum may have either 1 or 2 inputs. 3409 inputs, input_stacked, _ = zip(*[ 3410 pfor_input.input(i) 3411 for i in range(pfor_input.num_inputs)]) 3412 3413 # Parse the einsum equation. 3414 equation = pfor_input.get_attr("equation").decode("utf-8") 3415 input_expr, output_expr = equation.split("->") 3416 input_exprs = input_expr.split(",") 3417 3418 # Pick a placeholder symbol to use for the new axis. 3419 chosen_symbol = None 3420 for s in string.ascii_letters: 3421 if s in equation: 3422 continue 3423 else: 3424 chosen_symbol = s 3425 break 3426 3427 if chosen_symbol is None: 3428 raise ValueError("Could not figure out what symbol to use for new axis.") 3429 3430 assert any(input_stacked) 3431 for i in range(len(inputs)): 3432 if input_stacked[i]: 3433 input_exprs[i] = "{}{}".format(chosen_symbol, input_exprs[i]) 3434 output_expr = "{}{}".format(chosen_symbol, output_expr) 3435 3436 new_equation = "{}->{}".format(",".join(input_exprs), output_expr) 3437 3438 if op_type == "XlaEinsum": 3439 if len(inputs) == 1: 3440 result = xla.einsum(equation=new_equation, a=inputs[0]) 3441 else: 3442 result = xla.einsum(equation=new_equation, a=inputs[0], b=inputs[1]) 3443 else: 3444 assert op_type == "Einsum" 3445 result = special_math_ops.einsum(new_equation, *inputs) 3446 3447 return wrap(result, True) 3448 3449 3450@RegisterPFor("Cholesky") 3451def _convert_cholesky(pfor_input): 3452 t = pfor_input.stacked_input(0) 3453 return wrap(linalg_ops.cholesky(t), True) 3454 3455 3456@RegisterPFor("LogMatrixDeterminant") 3457def _convert_log_matrix_determinant(pfor_input): 3458 t = pfor_input.stacked_input(0) 3459 return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)] 3460 3461 3462@RegisterPFor("MatrixInverse") 3463def _convert_matrix_inverse(pfor_input): 3464 t = pfor_input.stacked_input(0) 3465 adjoint = pfor_input.get_attr("adjoint") 3466 return wrap(gen_linalg_ops.matrix_inverse(t, adjoint=adjoint), True) 3467 3468 3469@RegisterPFor("MatrixSolve") 3470def _convert_matrix_solve(pfor_input): 3471 pfor_input.stack_inputs() 3472 matrix = pfor_input.stacked_input(0) 3473 rhs = pfor_input.stacked_input(1) 3474 adjoint = pfor_input.get_attr("adjoint") 3475 output = gen_linalg_ops.matrix_solve( 3476 matrix, rhs, adjoint=adjoint) 3477 return wrap(output, True) 3478 3479 3480@RegisterPFor("MatrixTriangularSolve") 3481def _convert_matrix_triangular_solve(pfor_input): 3482 pfor_input.expanddim_inputs_for_broadcast() 3483 matrix = pfor_input.input(0)[0] 3484 rhs = pfor_input.input(1)[0] 3485 lower = pfor_input.get_attr("lower") 3486 adjoint = pfor_input.get_attr("adjoint") 3487 output = linalg_ops.matrix_triangular_solve( 3488 matrix, rhs, lower=lower, adjoint=adjoint) 3489 return wrap(output, True) 3490 3491 3492@RegisterPFor("SelfAdjointEigV2") 3493def _convert_self_adjoint_eig(pfor_input): 3494 t = pfor_input.stacked_input(0) 3495 compute_v = pfor_input.get_attr("compute_v") 3496 e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v) 3497 # If compute_v is False, v will have shape [0]. 3498 return wrap(e, True), wrap(v, compute_v) 3499 3500 3501# logging_ops 3502 3503 3504@RegisterPFor("Assert") 3505def _convert_assert(pfor_input): 3506 cond, cond_stacked, _ = pfor_input.input(0) 3507 if cond_stacked: 3508 cond = math_ops.reduce_all(cond) 3509 3510 data_list = [x.t for x in pfor_input.inputs][1:] 3511 return _create_op( 3512 "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr) 3513 3514 3515@RegisterPFor("Print") 3516def _convert_print(pfor_input): 3517 # Note that we don't stack all the inputs. Hence unstacked values are printed 3518 # once here vs multiple times in a while_loop. 3519 pfor_input.stack_inputs([0]) 3520 outputs = _create_op( 3521 "Print", [x.t for x in pfor_input.inputs], 3522 [x.dtype for x in pfor_input.outputs], 3523 attrs=pfor_input.op.node_def.attr).outputs 3524 return [wrap(x, True) for x in outputs] 3525 3526 3527@RegisterPFor("PrintV2") 3528def _convert_print_v2(pfor_input): 3529 # Print the full input Tensor(s), including the batch dimension if stacked. 3530 return _create_op( 3531 "PrintV2", [x.t for x in pfor_input.inputs], 3532 [x.dtype for x in pfor_input.outputs], 3533 attrs=pfor_input.op.node_def.attr) 3534 3535 3536@RegisterPFor("StringFormat") 3537def _convert_string_format(pfor_input): 3538 # Format using the full input Tensor(s), including the batch dimension if 3539 # stacked. 3540 op = _create_op( 3541 "StringFormat", [x.t for x in pfor_input.inputs], 3542 [x.dtype for x in pfor_input.outputs], 3543 attrs=pfor_input.op.node_def.attr) 3544 return [wrap(output, False) for output in op.outputs] 3545 3546 3547# data_flow_ops 3548 3549# TensorArray conversion is tricky since we don't support arrays of 3550# TensorArrays. For converting them, we consider two distinct cases: 3551# 3552# 1. The array is constructed outside the pfor call, and read/written inside the 3553# loop. 3554# This is an easier case since we don't need to make an array of TensorArrays. 3555# A correctness requirement is that these parallel iterations shouldn't attempt 3556# to write to the same location. Hence at conversion time we disallow indices to 3557# be loop-invariant as that would guarantee a collision. Even if the indices are 3558# not loop-invariant, they could conflict and that shall trigger runtime errors. 3559# 3560# 2. The array is constructed and used entirely inside each pfor iteration. 3561# For simplicity, here we require that the indices used for write/scatter are 3562# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in 3563# different pfor iterations. We consider two sub_cases: 3564# 3565# 2a Elements written to the array are "stacked" 3566# To simulate multiple TensorArrays, we may increase the dimension of each 3567# element of the array. i.e. the i_th row of the j_th entry of the converted 3568# TensorArray corresponds to the j_th entry of the TensorArray in the i_th 3569# pfor iteration. 3570# 3571# 2b Elements written to the array are "unstacked" 3572# In this case we don't increase the dimensions to avoid redundant tiling. Each 3573# iteration is trying to write the same value. So we convert that to a single 3574# write. 3575# 3576# Here are some tricks used to implement the above: 3577# - TensorArrayV3 constructor encodes the element shape as an attr. Instead of 3578# trying to trace whether future writes are stacked or unstacked in order to set 3579# this attr, we set it to correspond to unknown shape. 3580# - We use the "flow" output of the different ops to track whether the array 3581# elements are stacked or unstacked. If a stacked write/scatter is done, we make 3582# the flow stacked as well. 3583# - We use some heuristic traversal of the graph to track whether the 3584# TensorArray handle was created inside or outside the pfor loop. 3585 3586 3587@RegisterPFor("TensorArrayV3") 3588def _convert_tensor_array_v3(pfor_input): 3589 size = pfor_input.unstacked_input(0) 3590 dtype = pfor_input.get_attr("dtype") 3591 dynamic_size = pfor_input.get_attr("dynamic_size") 3592 clear_after_read = pfor_input.get_attr("clear_after_read") 3593 identical_element_shapes = pfor_input.get_attr("identical_element_shapes") 3594 tensor_array_name = pfor_input.get_attr("tensor_array_name") 3595 handle, flow = data_flow_ops.tensor_array_v3( 3596 size, 3597 dtype=dtype, 3598 # We don't set element shape since we don't know if writes are stacked or 3599 # not yet. 3600 element_shape=None, 3601 dynamic_size=dynamic_size, 3602 clear_after_read=clear_after_read, 3603 identical_element_shapes=identical_element_shapes, 3604 tensor_array_name=tensor_array_name) 3605 # Note we keep flow unstacked for now since we don't know if writes will be 3606 # stacked or not. 3607 return wrap(handle, False), wrap(flow, False) 3608 3609 3610@RegisterPFor("TensorArraySizeV3") 3611def _convert_tensor_array_size_v3(pfor_input): 3612 handle = pfor_input.unstacked_input(0) 3613 flow, flow_stacked, _ = pfor_input.input(1) 3614 if flow_stacked: 3615 flow = _unstack_flow(flow) 3616 size = data_flow_ops.tensor_array_size_v3(handle, flow) 3617 return wrap(size, False) 3618 3619 3620def _handle_inside_pfor(pfor_input, handle): 3621 """Returns True if handle was created inside the pfor loop.""" 3622 # We use some heuristic to find the original TensorArray creation op. 3623 # The logic should handle the common cases (except cond based subgraphs). 3624 # In theory the user could perform different operations on the handle (like 3625 # Reshape, stack multiple handles, etc) which could break this logic. 3626 # TODO(agarwal): handle Switch/Merge. 3627 while handle.op.type in ("Enter", "Identity"): 3628 handle = handle.op.inputs[0] 3629 if handle.op.type not in [ 3630 "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape" 3631 ]: 3632 raise ValueError(f"Unable to find source for handle {handle}.") 3633 else: 3634 return pfor_input.pfor.op_is_inside_loop(handle.op) 3635 3636 3637def _unstack_flow(value): 3638 # TODO(agarwal): consider looking if this is a Tile op then get its input. 3639 # This may avoid running the Tile operations. 3640 return array_ops.gather(value, 0) 3641 3642 3643@RegisterPFor("TensorArrayReadV3") 3644def _convert_tensor_array_read_v3(pfor_input): 3645 handle = pfor_input.unstacked_input(0) 3646 index, index_stacked, _ = pfor_input.input(1) 3647 dtype = pfor_input.get_attr("dtype") 3648 flow, flow_stacked, _ = pfor_input.input(2) 3649 if flow_stacked: 3650 flow = _unstack_flow(flow) 3651 3652 is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3653 if is_inside_pfor: 3654 # Note that if we are inside a control flow construct inside the pfor, and 3655 # only some of the iterations are doing the read (i.e. 3656 # `all_indices_partitioned` is True), then the read operation should only 3657 # return values for the currently active pfor iterations (`all_indices` 3658 # below). Hence, whenever the returned value is stacked (i.e. `flow` is 3659 # stacked), we may need to do an extra gather after reading the values. Also 3660 # note that if `is_inside` is false, then values in the tensor array are 3661 # unstacked. So the check is only needed in this branch. 3662 all_indices = pfor_input.pfor.all_indices 3663 all_indices_partitioned = pfor_input.pfor.all_indices_partitioned 3664 # Note: flow_stacked indicates if values in the TensorArray are stacked or 3665 # not. 3666 if index_stacked: 3667 if flow_stacked: 3668 raise ValueError( 3669 "It looks like TensorArrayReadV3 was called on a TensorArray whose" 3670 " values are not loop-invariant, and the read indices were also" 3671 " not loop invariant. This is currently unsupported.") 3672 value = data_flow_ops.tensor_array_gather_v3( 3673 handle, index, flow, dtype=dtype) 3674 return wrap(value, True) 3675 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 3676 if flow_stacked and all_indices_partitioned: 3677 value = array_ops.gather(value, all_indices) 3678 return wrap(value, flow_stacked) 3679 # Values in the TensorArray should be unstacked (since different iterations 3680 # couldn't write to the same location). So whether output is stacked or not 3681 # depends on index_stacked. 3682 if index_stacked: 3683 value = data_flow_ops.tensor_array_gather_v3( 3684 handle, index, flow, dtype=dtype) 3685 else: 3686 value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype) 3687 return wrap(value, index_stacked) 3688 3689 3690@RegisterPFor("TensorArrayWriteV3") 3691def _convert_tensor_array_write_v3(pfor_input): 3692 handle = pfor_input.unstacked_input(0) 3693 index, index_stacked, _ = pfor_input.input(1) 3694 value, value_stacked, _ = pfor_input.input(2) 3695 flow, flow_stacked, _ = pfor_input.input(3) 3696 if value_stacked and pfor_input.pfor.all_indices_partitioned: 3697 # Looks like we are in a control flow in a pfor where not all iterations are 3698 # active now. We don't allow that since that could lead to different indices 3699 # having different shapes which will be hard to merge later. 3700 raise ValueError("Writing non loop invariant values to TensorArray from " 3701 "inside a while_loop/cond not supported.") 3702 if flow_stacked: 3703 flow = _unstack_flow(flow) 3704 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3705 if is_inside: 3706 if index_stacked: 3707 raise ValueError(f"Need indices for {handle} to be loop invariant.") 3708 if not flow_stacked and not value_stacked: 3709 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 3710 return wrap(flow_out, False) 3711 else: 3712 if not value_stacked: 3713 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3714 # TODO(agarwal): Note that if flow is unstacked and value is stacked, then 3715 # this may or may not be a safe situation. flow is unstacked both for a 3716 # freshly created TensorArray, as well as after unstacked values are 3717 # written to it. If it is the latter, then we cannot write a stacked value 3718 # now since that may cause runtime errors due to different shapes in the 3719 # array. At the moment we are not able to handle this gracefully and 3720 # distinguish between the two cases. That would require some heuristic 3721 # traversal of the graph to figure out whether all the writes are 3722 # unstacked or not. 3723 flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow) 3724 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3725 else: 3726 if not index_stacked: 3727 raise ValueError(f"Need indices for {handle} to be not loop invariant.") 3728 # Note that even when index_stacked is true, actual values in index may 3729 # still not be unique. However that will cause runtime error when executing 3730 # the scatter operation below. 3731 if not value_stacked: 3732 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3733 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow) 3734 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3735 3736 3737def _transpose_first_two_dims(value): 3738 # TODO(agarwal): optimize if one of the dims == 1. 3739 value_shape = array_ops.shape(value) 3740 v0 = value_shape[0] 3741 v1 = value_shape[1] 3742 value = array_ops.reshape(value, [v0, v1, -1]) 3743 value = array_ops.transpose(value, [1, 0, 2]) 3744 new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0) 3745 return array_ops.reshape(value, new_shape) 3746 3747 3748@RegisterPFor("TensorArrayGatherV3") 3749def _convert_tensor_array_gather_v3(pfor_input): 3750 handle = pfor_input.unstacked_input(0) 3751 indices, indices_stacked, _ = pfor_input.input(1) 3752 indices = array_ops.reshape(indices, [-1]) 3753 flow, flow_stacked, _ = pfor_input.input(2) 3754 if flow_stacked: 3755 flow = _unstack_flow(flow) 3756 dtype = pfor_input.get_attr("dtype") 3757 # TODO(agarwal): support element_shape attr? 3758 3759 n = pfor_input.pfor.loop_len_vector 3760 value = data_flow_ops.tensor_array_gather_v3( 3761 handle, indices, flow, dtype=dtype) 3762 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3763 if is_inside: 3764 # flow_stacked indicates if values in the TensorArray are stacked or not. 3765 if indices_stacked: 3766 if flow_stacked: 3767 raise ValueError( 3768 "It looks like TensorArrayGatherV3 was called on a TensorArray " 3769 "whose values are not loop-invariant, and the indices were also " 3770 "not loop invariant. This is currently unsupported.") 3771 else: 3772 value = _unflatten_first_dim(value, n) 3773 return wrap(value, True) 3774 else: 3775 if flow_stacked: 3776 # Since elements in this array are stacked and `value` was produced by 3777 # gather, its first two dims are "gathered elements" and "stack 3778 # dimension". Our semantics require these two to be flipped. 3779 value = _transpose_first_two_dims(value) 3780 return wrap(value, flow_stacked) 3781 else: 3782 # Values in the TensorArray should be unstacked (since different iterations 3783 # couldn't write to the same location). So whether output is stacked or not 3784 # depends on indices_stacked. 3785 if indices_stacked: 3786 value = _unflatten_first_dim(value, n) 3787 return wrap(value, indices_stacked) 3788 3789 3790@RegisterPFor("TensorArrayScatterV3") 3791def _convert_tensor_array_scatter_v3(pfor_input): 3792 handle = pfor_input.unstacked_input(0) 3793 indices, indices_stacked, _ = pfor_input.input(1) 3794 indices = array_ops.reshape(indices, [-1]) 3795 value, value_stacked, _ = pfor_input.input(2) 3796 flow, flow_stacked, _ = pfor_input.input(3) 3797 3798 if flow_stacked: 3799 flow = _unstack_flow(flow) 3800 3801 is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0]) 3802 if is_inside: 3803 if indices_stacked: 3804 raise ValueError(f"Need indices for {handle} to be loop invariant.") 3805 # Note that flow_stacked indicates if existing values in the array are 3806 # stacked or not. 3807 if not flow_stacked and not value_stacked: 3808 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 3809 flow) 3810 return wrap(flow_out, False) 3811 if not value_stacked: 3812 # TODO(agarwal): tile in the second dimension directly instead of 3813 # transposing below. 3814 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3815 3816 value = _transpose_first_two_dims(value) 3817 # TODO(agarwal): Note that if a previous write was unstacked, flow will be 3818 # unstacked, and a stacked value may be written here which may cause 3819 # runtime error due to different elements having different shape. We do 3820 # not try to prevent that. 3821 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, 3822 flow) 3823 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3824 if not indices_stacked: 3825 raise ValueError(f"Need indices for {handle} to be not loop invariant.") 3826 if not value_stacked: 3827 value = _stack(value, pfor_input.pfor.loop_len_vector).t 3828 value = _flatten_first_two_dims(value) 3829 flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow) 3830 return _stack(flow_out, pfor_input.pfor.loop_len_vector) 3831 3832 3833@RegisterPFor("TensorArrayGradV3") 3834def _convert_tensor_array_grad_v3(pfor_input): 3835 handle = pfor_input.unstacked_input(0) 3836 flow, flow_stacked, _ = pfor_input.input(1) 3837 if flow_stacked: 3838 flow = _unstack_flow(flow) 3839 source = pfor_input.get_attr("source") 3840 # TODO(agarwal): For now, we assume that gradients are stacked if the 3841 # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong 3842 # will give runtime error due to incorrect shape being written to the 3843 # accumulator. It is difficult to know in advance if gradients written will be 3844 # stacked or not. Note that flow being stacked is not indicative of the 3845 # gradient being stacked or not. Revisit this later. 3846 shape_to_prepend = pfor_input.pfor.loop_len_vector 3847 grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape( 3848 handle=handle, 3849 flow_in=flow, 3850 shape_to_prepend=shape_to_prepend, 3851 source=source) 3852 flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t 3853 return [wrap(grad_handle, False), wrap(flow_out, True)] 3854 3855 3856def _stack_tensor_list_shape(shape, first_dim): 3857 shape_value = tensor_util.constant_value(shape) 3858 # Note that negative values in the shape are used to signify unknown shapes 3859 # and are handled in a special way. 3860 if shape_value is not None: 3861 shape_value = np.asarray(shape_value) 3862 if -1 in shape_value: 3863 return constant_op.constant(-1) 3864 elif not shape_value.size: 3865 return first_dim 3866 else: 3867 shape = array_ops.reshape(shape, [-1]) 3868 return control_flow_ops.cond( 3869 math_ops.reduce_any(shape < 0), 3870 lambda: constant_op.constant(-1), 3871 lambda: array_ops.concat([first_dim, shape], axis=0)) 3872 3873 3874def _tile_variant_with_length(t, length): 3875 """stacks `t` `length` times.""" 3876 if _is_variant_with_internal_stacking(t): 3877 # The content of TensorLists is vectorized, not the variant itself. 3878 return t 3879 original_tensor = t 3880 t.set_shape([]) 3881 t = array_ops.reshape(t, [-1]) 3882 with ops.device("CPU:0"): 3883 result = array_ops.tile(t, length) 3884 # TODO(b/169968286): Should regular shape functions do handle data 3885 # propagation here? 3886 handle_data_util.copy_handle_data(original_tensor, result) 3887 return result 3888 3889 3890def _tile_variant(t, pfor_input): 3891 """stacks `t` according to its loop context.""" 3892 return _tile_variant_with_length(t, pfor_input.pfor.loop_len_vector) 3893 3894 3895def _untile_variant(t): 3896 if _is_variant_with_internal_stacking(t): 3897 # The content of TensorLists is vectorized, not the variant itself. 3898 if not t.shape.is_compatible_with([]): 3899 raise AssertionError( 3900 ("Unexpectedly saw a vectorized variant (e.g. TensorList) with " 3901 f"non-scalar shape: {t!r}")) 3902 return t 3903 return array_ops.gather(t, 0) 3904 3905 3906@RegisterPFor("OptionalFromValue") 3907def _convert_optional_from_value(pfor_input): 3908 pfor_input.stack_inputs() 3909 return wrap( 3910 gen_dataset_ops.optional_from_value([x.t for x in pfor_input.inputs]), 3911 True) 3912 3913 3914@RegisterPFor("OptionalGetValue") 3915def _convert_optional_get_value(pfor_input): 3916 handle = pfor_input.stacked_input(0) 3917 output_types = pfor_input.get_attr("output_types") 3918 original_output_shapes = pfor_input.get_attr("output_shapes") 3919 output_shapes = [] 3920 for shape in original_output_shapes: 3921 shape = tensor_shape.TensorShape(shape) 3922 loop_len_shape = tensor_shape.TensorShape( 3923 [tensor_util.constant_value(pfor_input.pfor.loop_len_vector)]) 3924 shape = loop_len_shape.concatenate(shape) 3925 output_shapes.append(shape.as_proto()) 3926 results = gen_dataset_ops.optional_get_value(handle, output_types, 3927 output_shapes) 3928 return [wrap(t, True) for t in results] 3929 3930 3931@RegisterPFor("TensorListReserve") 3932def _convert_tensor_list_reserve(pfor_input): 3933 element_shape = pfor_input.unstacked_input(0) 3934 num_elements = pfor_input.unstacked_input(1) 3935 element_dtype = pfor_input.get_attr("element_dtype") 3936 3937 # Prepend a dimension to element_shape. 3938 element_shape = _stack_tensor_list_shape(element_shape, 3939 pfor_input.pfor.loop_len_vector) 3940 handle = list_ops.tensor_list_reserve( 3941 element_shape, num_elements, element_dtype=element_dtype) 3942 3943 return wrap(_tile_variant(handle, pfor_input), True) 3944 3945 3946@RegisterPFor("TensorListElementShape") 3947def _convert_tensor_list_element_shape(pfor_input): 3948 handle = _untile_variant(pfor_input.stacked_input(0)) 3949 shape_type = pfor_input.get_attr("shape_type") 3950 shape = list_ops.tensor_list_element_shape(handle, shape_type) 3951 shape = array_ops.reshape(shape, [-1]) 3952 shape = shape[1:] 3953 return wrap(shape, False) 3954 3955 3956@RegisterPFor("TensorListLength") 3957def _convert_tensor_list_length(pfor_input): 3958 handle = _untile_variant(pfor_input.stacked_input(0)) 3959 return wrap(list_ops.tensor_list_length(handle), False) 3960 3961 3962def _stack_tensor_list(handle, dtype, loop_len_vector, element_shape=None): 3963 if element_shape is None: 3964 element_shape = list_ops.tensor_list_element_shape(handle, dtypes.int32) 3965 length = list_ops.tensor_list_length(handle) 3966 new_handle = list_ops.tensor_list_reserve( 3967 _stack_tensor_list_shape(element_shape, loop_len_vector), length, dtype) 3968 3969 def _body_fn(i, h): 3970 elem = list_ops.tensor_list_get_item(handle, i, dtype, element_shape) 3971 elem = _stack(elem, loop_len_vector).t 3972 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 3973 3974 return control_flow_ops.while_loop(lambda i, _: i < length, _body_fn, 3975 [0, new_handle])[1] 3976 3977 3978@RegisterPFor("TensorListGetItem") 3979def _convert_tensor_list_get_item(pfor_input): 3980 handle, handle_stacked, _ = pfor_input.input(0) 3981 index, index_stacked, _ = pfor_input.input(1) 3982 element_shape = pfor_input.unstacked_input(2) 3983 element_dtype = pfor_input.get_attr("element_dtype") 3984 3985 if handle_stacked: 3986 handle = _untile_variant(handle) 3987 element_shape = _stack_tensor_list_shape(element_shape, 3988 pfor_input.pfor.loop_len_vector) 3989 if index_stacked: 3990 # We use a sequential loop since that may be more efficient than first 3991 # gathering and concatenating all the element corresponding to `index`, 3992 # and then doing a gather on it. 3993 def _map_fn(i): 3994 item_i = list_ops.tensor_list_get_item( 3995 handle, 3996 index[i], 3997 element_dtype=element_dtype) 3998 return array_ops.gather(item_i, i) 3999 4000 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 4001 return wrap(output, True) 4002 else: 4003 output = list_ops.tensor_list_get_item( 4004 handle, 4005 index, 4006 element_shape=element_shape, 4007 element_dtype=element_dtype) 4008 return wrap(output, True) 4009 else: 4010 assert index_stacked 4011 return wrap( 4012 list_ops.tensor_list_gather( 4013 handle, 4014 index, 4015 element_shape=element_shape, 4016 element_dtype=element_dtype), True) 4017 4018 4019@RegisterPFor("TensorListSetItem") 4020def _convert_tensor_array_set_item(pfor_input): 4021 handle, handle_stacked, _ = pfor_input.input(0) 4022 index, index_stacked, _ = pfor_input.input(1) 4023 item, item_stacked, _ = pfor_input.input(2) 4024 4025 if not handle_stacked: 4026 # Special case where we can statically guarantee that the indices are 4027 # disjoint. 4028 if index is pfor_input.pfor.all_indices: 4029 if not item_stacked: 4030 item = _stack(item, pfor_input.pfor.loop_len_vector).t 4031 return wrap( 4032 list_ops.tensor_list_scatter(item, index, input_handle=handle), False) 4033 else: 4034 handle = _stack_tensor_list(handle, item.dtype, 4035 pfor_input.pfor.loop_len_vector) 4036 else: 4037 handle = _untile_variant(handle) 4038 4039 if index_stacked: 4040 # TODO(agarwal): handle this. 4041 raise ValueError("Vectorizing writes to a TensorList with loop " 4042 "variant indices is currently unsupported.") 4043 4044 else: 4045 if not item_stacked: 4046 item = _stack(item, pfor_input.pfor.loop_len_vector).t 4047 handle = list_ops.tensor_list_set_item(handle, index, item) 4048 return wrap(_tile_variant(handle, pfor_input), True) 4049 4050 4051@RegisterPFor("TensorListPushBack") 4052def _convert_tensor_list_push_back(pfor_input): 4053 handle, handle_stacked, _ = pfor_input.input(0) 4054 tensor, tensor_stacked, _ = pfor_input.input(1) 4055 if handle_stacked: 4056 handle = _untile_variant(handle) 4057 else: 4058 handle = _stack_tensor_list(handle, tensor.dtype, 4059 pfor_input.pfor.loop_len_vector) 4060 if not tensor_stacked: 4061 tensor = _stack(tensor, pfor_input.pfor.loop_len_vector).t 4062 handle = list_ops.tensor_list_push_back(handle, tensor) 4063 return wrap(_tile_variant(handle, pfor_input), True) 4064 4065 4066@RegisterPFor("TensorListPopBack") 4067def _convert_tensor_array_push_back(pfor_input): 4068 handle = pfor_input.stacked_input(0) 4069 element_shape = pfor_input.unstacked_input(1) 4070 handle = _untile_variant(handle) 4071 4072 if element_shape.shape.ndims == 0: 4073 # Default / unspecified 4074 vectorized_shape = -1 4075 else: 4076 # PopBack has an element shape set when it's the gradient of PushBack, only 4077 # used when the list is uninitialized. 4078 vectorized_shape = array_ops.concat( 4079 [pfor_input.pfor.loop_len_vector, element_shape], axis=0) 4080 4081 output_handle, tensor = gen_list_ops.tensor_list_pop_back( 4082 input_handle=handle, element_dtype=pfor_input.get_attr("element_dtype"), 4083 element_shape=vectorized_shape) 4084 return wrap(output_handle, True), wrap(tensor, True) 4085 4086 4087@RegisterPFor("TensorListConcatV2") 4088def _convert_tensor_list_concat_v2(pfor_input): 4089 input_handle = pfor_input.stacked_input(0) 4090 element_shape = pfor_input.unstacked_input(1) 4091 leading_dims = pfor_input.unstacked_input(2) 4092 element_dtype = pfor_input.get_attr("element_dtype") 4093 4094 handle = _untile_variant(input_handle) 4095 length = list_ops.tensor_list_length(handle) 4096 # Note that element_shape attribute can have incomplete shapes. This doesn't 4097 # seem to work well when creating another list and then doing a concat on it. 4098 # Hence we try to find the dynamic shape here. 4099 element_shape = control_flow_ops.cond( 4100 length > 0, lambda: array_ops.shape( 4101 list_ops.tensor_list_get_item(handle, 0, element_dtype, None)), 4102 lambda: constant_op.constant([0, 0], dtype=dtypes.int32)) 4103 # The code below creates a copy of the list with each elements' first two 4104 # dimensions transposed. 4105 new_element_shape = array_ops.concat( 4106 [element_shape[1:2], element_shape[0:1], element_shape[2:]], axis=0) 4107 4108 # Create a new TensorList with elements transposed. 4109 def _transpose_elem(i, h): 4110 elem = list_ops.tensor_list_get_item(handle, i, element_dtype, None) 4111 elem = _transpose_first_two_dims(elem) 4112 return i + 1, list_ops.tensor_list_set_item(h, i, elem) 4113 4114 new_handle = list_ops.tensor_list_reserve(new_element_shape, length, 4115 element_dtype) 4116 new_handle = control_flow_ops.while_loop(lambda i, _: i < length, 4117 _transpose_elem, [0, new_handle])[1] 4118 output, lengths = gen_list_ops.tensor_list_concat_v2( 4119 input_handle=new_handle, 4120 element_dtype=element_dtype, 4121 element_shape=new_element_shape, 4122 leading_dims=leading_dims) 4123 output = _transpose_first_two_dims(output) 4124 return wrap(output, True), wrap(lengths, False) 4125 4126 4127@RegisterPFor("TensorListStack") 4128def _convert_tensor_list_stack(pfor_input): 4129 handle = pfor_input.stacked_input(0) 4130 input_shape = pfor_input.unstacked_input(1) 4131 element_dtype = pfor_input.get_attr("element_dtype") 4132 num_elements = pfor_input.get_attr("num_elements") 4133 4134 handle = _untile_variant(handle) 4135 input_shape = _stack_tensor_list_shape(input_shape, 4136 pfor_input.pfor.loop_len_vector) 4137 output = list_ops.tensor_list_stack( 4138 handle, 4139 element_dtype, 4140 element_shape=input_shape, 4141 num_elements=num_elements) 4142 output = _transpose_first_two_dims(output) 4143 return wrap(output, True) 4144 4145 4146@RegisterPFor("TensorListGather") 4147def _convert_tensor_list_gather(pfor_input): 4148 handle, handle_stacked, _ = pfor_input.input(0) 4149 index, index_stacked, _ = pfor_input.input(1) 4150 element_shape = pfor_input.unstacked_input(2) 4151 element_dtype = pfor_input.get_attr("element_dtype") 4152 4153 if handle_stacked: 4154 handle = _untile_variant(handle) 4155 element_shape = _stack_tensor_list_shape(element_shape, 4156 pfor_input.pfor.loop_len_vector) 4157 if index_stacked: 4158 # We use a sequential loop since that may be more efficient than first 4159 # gathering and concatenating all the element corresponding to `index`, 4160 # and then doing a gather on it. 4161 def _map_fn(i): 4162 item_i = list_ops.tensor_list_gather( 4163 handle, 4164 index[i], 4165 element_dtype=element_dtype) 4166 axis = array_ops.rank(index) - 1 4167 return array_ops.gather(item_i, i, axis=axis) 4168 4169 output = map_fn.map_fn(_map_fn, pfor_input.pfor.all_indices) 4170 return wrap(output, True) 4171 else: 4172 output = list_ops.tensor_list_gather( 4173 handle, 4174 index, 4175 element_shape=element_shape, 4176 element_dtype=element_dtype) 4177 return wrap(output, True) 4178 else: 4179 assert index_stacked 4180 index_shape = array_ops.shape(index) 4181 index = array_ops.reshape(index, [-1]) 4182 values = list_ops.tensor_list_gather( 4183 handle, index, element_shape=element_shape, element_dtype=element_dtype) 4184 final_shape = array_ops.concat( 4185 [index_shape, array_ops.shape(values)[1:]], axis=0) 4186 return wrap(array_ops.reshape(values, final_shape), True) 4187 4188 4189@RegisterPFor("TensorListScatterIntoExistingList") 4190def _convert_tensor_list_scatter(pfor_input): 4191 pfor_input.stack_inputs([1]) 4192 handle, handle_stacked, _ = pfor_input.input(0) 4193 item = pfor_input.stacked_input(1) 4194 indices, indices_stacked, _ = pfor_input.input(2) 4195 if handle_stacked: 4196 handle = _untile_variant(handle) 4197 else: 4198 handle = _stack_tensor_list(handle, item.dtype, 4199 pfor_input.pfor.loop_len_vector) 4200 4201 item = _transpose_first_two_dims(item) 4202 if indices_stacked: 4203 # Pretend the list is a dense tensor: 4204 # list_as_dense: Tensor[list_len, loop_len, ...] 4205 # And indices are a tensor with shape (before transpose): 4206 # indices: Tensor[loop_len, num_scatters] 4207 # The item to scatter has shape (before transpose): 4208 # item: Tensor[loop_len, num_scatters, ...] 4209 # 4210 # We want list_as_dense[indices[i, j], i] = item[i, j] 4211 # 4212 # Since we're not just indexing along the first axis of `list_as_dense`, we 4213 # need to first extract the relevant list entries based on `indices`, 4214 # scatter into them according to the loop index, and re-scatter the chunks 4215 # we updated back into the list. 4216 indices = _transpose_first_two_dims(indices) 4217 indices_flat = array_ops.reshape(indices, [-1]) 4218 # In many cases `indices` will be unique across pfor iterations, but this is 4219 # not guaranteed. If there are duplicates, we need to map multiple updates 4220 # to a single chunk extracted from the list. The last update should win. 4221 unique_indices = array_ops.unique(indices_flat) 4222 gathered_items = list_ops.tensor_list_gather( 4223 handle, unique_indices.y, element_dtype=item.dtype, 4224 element_shape=array_ops.shape(item)[1:]) 4225 loop_idx = math_ops.range(pfor_input.pfor.loop_len_vector[0]) 4226 scatters_per_op = array_ops.shape(indices)[0] 4227 4228 unique_indices_loop_idx = array_ops.reshape(array_ops.tile( 4229 loop_idx[None, :], [scatters_per_op, 1]), [-1]) 4230 scatter_indices = array_ops.stack( 4231 [unique_indices.idx, unique_indices_loop_idx], 4232 axis=1) 4233 # This op does *not* guarantee last-update-wins on GPU, so semantics may not 4234 # be exactly preserved for duplicate updates there. 4235 scattered = array_ops.tensor_scatter_nd_update( 4236 tensor=gathered_items, 4237 indices=scatter_indices, 4238 updates=_flatten_first_two_dims(item)) 4239 handle = list_ops.tensor_list_scatter( 4240 scattered, unique_indices.y, input_handle=handle) 4241 else: 4242 handle = list_ops.tensor_list_scatter(item, indices, input_handle=handle) 4243 return wrap(_tile_variant(handle, pfor_input), True) 4244 4245 4246@RegisterPFor("TensorListFromTensor") 4247def _convert_tensor_list_from_tensor(pfor_input): 4248 tensor = pfor_input.stacked_input(0) 4249 element_shape = pfor_input.unstacked_input(1) 4250 tensor = _transpose_first_two_dims(tensor) 4251 element_shape = _stack_tensor_list_shape(element_shape, 4252 pfor_input.pfor.loop_len_vector) 4253 handle = list_ops.tensor_list_from_tensor(tensor, element_shape) 4254 return wrap(_tile_variant(handle, pfor_input), True) 4255 4256 4257@RegisterPFor("TensorScatterUpdate") 4258def _convert_tensor_scatter_update(pfor_input): 4259 pfor_input.stack_inputs([0, 1, 2]) 4260 tensor = pfor_input.stacked_input(0) 4261 indices = pfor_input.stacked_input(1) 4262 updates = pfor_input.stacked_input(2) 4263 4264 indices_shape = array_ops.shape(indices) 4265 indices_rank = array_ops.rank(indices) 4266 loop_length = indices_shape[0] 4267 4268 # Create a loop count range and extend its dimensions to match `indices`. 4269 loop_count_shape = array_ops.tensor_scatter_nd_update( 4270 array_ops.ones([indices_rank], dtype=dtypes.int32), [[0]], [loop_length]) 4271 loop_count = array_ops.reshape(math_ops.range(loop_length), loop_count_shape) 4272 4273 # Tile the loop count range for the batch dimensions (all except the first and 4274 # last dimensions of indices). 4275 # Rank(indices) >= 3 always for this function so we always have at least 1. 4276 tile_multiplier = array_ops.tensor_scatter_nd_update( 4277 indices_shape, [[0], [indices_rank - 1]], [1, 1]) 4278 meta_index = array_ops.tile(loop_count, tile_multiplier) 4279 4280 # Insert the loop-identifying index. 4281 indices = array_ops.concat([meta_index, indices], axis=-1) 4282 4283 result = array_ops.tensor_scatter_nd_update(tensor, indices, updates) 4284 return wrap(result, True) 4285 4286# StackV2 conversion is tricky since we don't have arrays of StackV2. So similar 4287# to TensorArrays, we convert them by changing the dimension of the elements 4288# inside the stack. 4289# 4290# We consider two cases: 4291# 4292# 1. StackV2 is constructed and used entirely inside the pfor loop. 4293# We keep a single Stack and perform the push/pop operations of all the 4294# iterations in lock-step. We also assume that all the iterations perform these 4295# operations. In case of dynamic control flow, if only some of the iterations 4296# try to perform a push/pop, then the conversion may not work correctly and may 4297# cause undefined behavior. 4298# TODO(agarwal): test StackV2 with dynamic control flow. 4299# 4300# 2. StackV2 is constructed outside the pfor loop. 4301# Performing stack push/pop in a parallel fashion is ill-defined. However given 4302# that reading stacks created externally is a common operation when computing 4303# jacobians, we provide some special semantics here as follows. 4304# - disallow push operations to the stack 4305# - pop operations are performed in lock step by all iterations, similar to the 4306# case when the stack is created inside. A single value is popped during the 4307# lock-step operation and broadcast to all the iterations. Values in the stack 4308# are assumed to be loop-invariant. 4309# 4310# Some other implementation details: 4311# We use an ugly logic to find whether values in Stack data structure are 4312# loop invariant or not. When converting push/pop operations, we keep track of 4313# whether the last conversion used a stacked value or not (see _stack_cache 4314# below). As a result if an unstacked value is written first, subsequent stacked 4315# writes are disallowed when they could have been allowed in theory. 4316 4317# Map from cache key based on StackV2 handle to a bool indicating whether values 4318# are stacked or not. 4319# TODO(agarwal): move _stack_cache inside pfor? 4320_stack_cache = {} 4321 4322 4323def _stack_cache_key(pfor_input): 4324 """Create cache key corresponding to a stack handle.""" 4325 op_type = pfor_input.op_type 4326 assert op_type in ["StackPushV2", "StackPopV2"], op_type 4327 orig_handle = pfor_input.op.inputs[0] 4328 while orig_handle.op.type in ["Identity", "Enter"]: 4329 orig_handle = orig_handle.op.inputs[0] 4330 assert orig_handle.op.type == "StackV2", orig_handle.op 4331 return ops.get_default_graph(), pfor_input.pfor, orig_handle 4332 4333 4334def _stack_handle_inside_pfor(handle, pfor_input): 4335 while handle.op.type in ["Identity", "Enter"]: 4336 handle = handle.op.inputs[0] 4337 assert handle.op.type == "StackV2", ("Unable to find StackV2 op. Got %s" % 4338 handle.op) 4339 return pfor_input.pfor.op_is_inside_loop(handle.op) 4340 4341 4342@RegisterPFor("StackPushV2") 4343def _convert_stack_push_v2(pfor_input): 4344 handle = pfor_input.unstacked_input(0) 4345 elem, elem_stacked, _ = pfor_input.input(1) 4346 swap_memory = pfor_input.get_attr("swap_memory") 4347 4348 if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): 4349 raise ValueError("StackPushV2 not allowed on stacks created outside pfor.") 4350 stack_cache_key = _stack_cache_key(pfor_input) 4351 stacked = _stack_cache.get(stack_cache_key, None) 4352 if stacked is None: 4353 stacked = elem_stacked 4354 _stack_cache[stack_cache_key] = stacked 4355 else: 4356 # If we previously made it unstacked then we can't revert to being stacked. 4357 if not stacked and elem_stacked: 4358 raise ValueError( 4359 "It looks like the stack was previously determined to be loop " 4360 "invariant, but we are now trying to push a loop dependent value " 4361 "to it. This is currently unsupported.") 4362 if stacked and not elem_stacked: 4363 elem = _stack(elem, pfor_input.pfor.loop_len_vector).t 4364 out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) 4365 return wrap(out, stacked) 4366 4367 4368# Note that inputs to this convertor will be unstacked. However it should get 4369# called since it is a stateful op. 4370@RegisterPFor("StackPopV2") 4371def _convert_stack_pop_v2(pfor_input): 4372 handle = pfor_input.unstacked_input(0) 4373 stack_cache_key = _stack_cache_key(pfor_input) 4374 stacked = _stack_cache.get(stack_cache_key, None) 4375 # If a StackPushV2 has not been converted yet, we default to unstacked since 4376 # the push could be outside of pfor, or the convertor may not be called if the 4377 # inputs are unconverted. 4378 if stacked is None: 4379 stacked = False 4380 _stack_cache[stack_cache_key] = False 4381 elem_type = pfor_input.get_attr("elem_type") 4382 out = data_flow_ops.stack_pop_v2(handle, elem_type) 4383 return wrap(out, stacked) 4384 4385 4386# parsing_ops 4387 4388 4389@RegisterPFor("DecodeCSV") 4390def _convert_decode_csv(pfor_input): 4391 lines = pfor_input.stacked_input(0) 4392 record_defaults = [ 4393 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 4394 ] 4395 field_delim = pfor_input.get_attr("field_delim") 4396 use_quote_delim = pfor_input.get_attr("use_quote_delim") 4397 select_cols = pfor_input.get_attr("select_cols") 4398 if not select_cols: 4399 select_cols = None 4400 return [ 4401 wrap(t, True) for t in parsing_ops.decode_csv( 4402 lines, 4403 record_defaults, 4404 field_delim=field_delim, 4405 use_quote_delim=use_quote_delim, 4406 select_cols=select_cols) 4407 ] 4408 4409 4410@RegisterPFor("ParseSingleExample") 4411def _convert_parse_single_example(pfor_input): 4412 serialized = pfor_input.stacked_input(0) 4413 dense_defaults = [ 4414 pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs) 4415 ] 4416 sparse_keys = pfor_input.get_attr("sparse_keys") 4417 dense_keys = pfor_input.get_attr("dense_keys") 4418 sparse_types = pfor_input.get_attr("sparse_types") 4419 dense_shapes = pfor_input.get_attr("dense_shapes") 4420 output = gen_parsing_ops.parse_example( 4421 serialized=serialized, 4422 names=[], 4423 dense_defaults=dense_defaults, 4424 sparse_keys=sparse_keys, 4425 dense_keys=dense_keys, 4426 sparse_types=sparse_types, 4427 dense_shapes=dense_shapes) 4428 return [wrap(t, True, True) for t in nest.flatten(output)] 4429 4430 4431@RegisterPFor("ParseExampleV2") 4432def _convert_parse_example_v2(pfor_input): 4433 serialized = pfor_input.stacked_input(0) 4434 sparse_keys = pfor_input.unstacked_input(2) 4435 dense_keys = pfor_input.unstacked_input(3) 4436 ragged_keys = pfor_input.unstacked_input(4) 4437 dense_defaults = [ 4438 pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs) 4439 ] 4440 num_sparse = pfor_input.get_attr("num_sparse") 4441 sparse_types = pfor_input.get_attr("sparse_types") 4442 ragged_value_types = pfor_input.get_attr("ragged_value_types") 4443 ragged_split_types = pfor_input.get_attr("ragged_split_types") 4444 dense_shapes = pfor_input.get_attr("dense_shapes") 4445 if serialized.shape.ndims not in (None, 1): 4446 raise ValueError("ParseExampleV2 can only be converted if `serialized` " 4447 f"is scalar. Received shape: {serialized.shape}.") 4448 output = gen_parsing_ops.parse_example_v2( 4449 serialized=serialized, 4450 names=[], 4451 sparse_keys=sparse_keys, 4452 dense_keys=dense_keys, 4453 ragged_keys=ragged_keys, 4454 dense_defaults=dense_defaults, 4455 num_sparse=num_sparse, 4456 sparse_types=sparse_types, 4457 ragged_value_types=ragged_value_types, 4458 ragged_split_types=ragged_split_types, 4459 dense_shapes=dense_shapes) 4460 return [wrap(t, True, True) for t in nest.flatten(output)] 4461 4462 4463# functional_ops 4464 4465 4466def _convert_function_call(func, converter, inputs): 4467 assert isinstance(func.graph, func_graph.FuncGraph), func 4468 assert isinstance(converter, PFor) 4469 4470 # TODO(agarwal): consider caching this function definition. 4471 @def_function.function 4472 def f(*args): 4473 assert all(isinstance(arg, WrappedTensor) for arg in args), args 4474 assert len(args) == len(func.graph.inputs), (args, func.graph.inputs) 4475 # Map inputs to function arguments. 4476 for inp, arg in zip(func.graph.inputs, args): 4477 converter._add_conversion(inp, arg) 4478 # Convert output tensors. 4479 return tuple( 4480 [converter._convert_helper(x).t for x in func._func_graph_outputs]) 4481 4482 call_outputs = f(*inputs) 4483 assert len(call_outputs) == len(func._func_graph_outputs) 4484 outputs = [] 4485 for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs): 4486 func_output = converter._convert_helper(output_tensor) 4487 outputs.append( 4488 wrap(call_output, func_output.is_stacked, 4489 func_output.is_sparse_stacked)) 4490 return outputs 4491 4492 4493@RegisterPFor("StatefulPartitionedCall") 4494@RegisterPFor("PartitionedCall") 4495def _convert_partitioned_call(pfor_input): 4496 func_name = pfor_input.get_attr("f").name 4497 func = pfor_input.op.graph._get_function(compat.as_bytes(func_name)) 4498 assert isinstance(func.graph, func_graph.FuncGraph), ( 4499 "Could not find FuncGraph object for %s. Got func %s" % (func_name, func)) 4500 pfor = pfor_input.pfor 4501 converter = PFor( 4502 loop_var=pfor.loop_var, 4503 loop_len=pfor.loop_len_vector[0], 4504 pfor_ops=func.graph.get_operations(), 4505 fallback_to_while_loop=pfor.fallback_to_while_loop, 4506 all_indices=pfor.all_indices, 4507 all_indices_partitioned=pfor.all_indices_partitioned, 4508 pfor_config=pfor.pfor_config) 4509 return _convert_function_call(func, converter, pfor_input.inputs) 4510 4511 4512def _partition_inputs_for_indices(inputs, indices): 4513 new_inputs = [] 4514 for inp in inputs: 4515 if inp.is_stacked: 4516 new_inputs.append(wrap(array_ops.gather(inp.t, indices), True)) 4517 else: 4518 new_inputs.append(inp) 4519 return new_inputs 4520 4521 4522def _outputs_for_branch(func_name, indices, pfor_input, inputs): 4523 if indices is None: 4524 indices = pfor_input.pfor.all_indices 4525 partitioned = pfor_input.pfor.all_indices_partitioned 4526 else: 4527 partitioned = True 4528 func = pfor_input.op.graph._get_function(func_name) 4529 converter = PFor( 4530 loop_var=pfor_input.pfor.loop_var, 4531 loop_len=array_ops.size(indices), 4532 pfor_ops=func.graph.get_operations(), 4533 fallback_to_while_loop=pfor_input.pfor.fallback_to_while_loop, 4534 all_indices=indices, 4535 all_indices_partitioned=partitioned, 4536 pfor_config=pfor_input.pfor.pfor_config) 4537 outputs = _convert_function_call(func, converter, inputs) 4538 stacked_outputs = [] 4539 for out in outputs: 4540 if not out.is_stacked: 4541 stacked_outputs.append(_stack(out.t, [array_ops.size(indices)]).t) 4542 else: 4543 stacked_outputs.append(out.t) 4544 return stacked_outputs 4545 4546 4547# TODO(agarwal): Currently the converted code aggressively tiles loop variant 4548# outputs from the then/else branches. Instead, it could do so only if at least 4549# one of the branch outputs is loop variant. 4550@RegisterPFor("StatelessIf") 4551@RegisterPFor("If") 4552def _convert_if(pfor_input): 4553 cond, cond_stacked, _ = pfor_input.input(0) 4554 inputs = pfor_input.inputs[1:] 4555 then_branch = pfor_input.get_attr("then_branch") 4556 else_branch = pfor_input.get_attr("else_branch") 4557 4558 if cond_stacked: 4559 cond_int = math_ops.cast(cond, dtypes.int32) 4560 # Compute loop indices for the different branches 4561 false_indices, true_indices = data_flow_ops.dynamic_partition( 4562 pfor_input.pfor.all_indices, cond_int, 2) 4563 # Compute indices for cond being True or False. 4564 if pfor_input.pfor.all_indices_partitioned: 4565 else_indices, then_indices = data_flow_ops.dynamic_partition( 4566 math_ops.range(pfor_input.pfor.loop_len_vector[0]), 4567 cond_int, 2) 4568 else: 4569 else_indices, then_indices = false_indices, true_indices 4570 # Partition inputs 4571 then_inputs = _partition_inputs_for_indices(inputs, then_indices) 4572 else_inputs = _partition_inputs_for_indices(inputs, else_indices) 4573 4574 # Convert "then" branch. 4575 then_outputs = _outputs_for_branch(then_branch.name, true_indices, 4576 pfor_input, then_inputs) 4577 4578 # Convert "else" branch. 4579 else_outputs = _outputs_for_branch(else_branch.name, false_indices, 4580 pfor_input, else_inputs) 4581 4582 assert len(then_outputs) == len(else_outputs) 4583 # Note that if the "then" and "else" branches are updating the same state, 4584 # and possibly reading them as well, it could lead to undefined behavior 4585 # since the ordering of those operations is not well defined. 4586 # One possibility is to order all the "then" branches to execute before all 4587 # the "else" branches so that the side-effects in the former are visible to 4588 # the latter. For now, we leave that as undefined behavior. 4589 outputs = [] 4590 # Merge outputs 4591 for then_output, else_output in zip(then_outputs, else_outputs): 4592 out = data_flow_ops.dynamic_stitch([then_indices, else_indices], 4593 [then_output, else_output]) 4594 outputs.append(wrap(out, True)) 4595 return outputs 4596 else: 4597 outputs = control_flow_ops.cond( 4598 cond, 4599 lambda: _outputs_for_branch(then_branch.name, None, pfor_input, inputs), 4600 lambda: _outputs_for_branch(else_branch.name, None, pfor_input, inputs)) 4601 return [wrap(t, True) for t in outputs] 4602 4603 4604@RegisterPFor("Case") 4605@RegisterPFor("StatelessCase") 4606def _convert_stateless_case(pfor_input): 4607 branch_idx, is_stacked, _ = pfor_input.input(0) 4608 branches = pfor_input.get_attr("branches") 4609 inputs = pfor_input.inputs[1:] 4610 4611 if is_stacked: 4612 logging.info("Running stacked flow") 4613 4614 # Compute loop indices for the different branches 4615 switch_indices = data_flow_ops.dynamic_partition( 4616 pfor_input.pfor.all_indices, branch_idx, len(branches)) 4617 if pfor_input.pfor.all_indices_partitioned: 4618 partitioned_indices = data_flow_ops.dynamic_partition( 4619 math_ops.range(pfor_input.pfor.loop_len_vector[0]), branch_idx, 4620 len(branches)) 4621 else: 4622 partitioned_indices = switch_indices 4623 # Partition inputs 4624 input_list = [] 4625 for indices in partitioned_indices: 4626 input_list.append(_partition_inputs_for_indices(inputs, indices)) 4627 4628 outputs = [] 4629 for (b, indices, inputs) in zip(branches, switch_indices, input_list): 4630 out = _outputs_for_branch(b.name, indices, pfor_input, inputs) 4631 outputs.extend(out) 4632 4633 out = data_flow_ops.dynamic_stitch(partitioned_indices, outputs) 4634 return [wrap(out, True)] 4635 else: 4636 new_branches = [] 4637 for b in branches: 4638 def new_function(func=b.name): 4639 return _outputs_for_branch(func, None, pfor_input, 4640 pfor_input.inputs[1:]) 4641 4642 new_branches.append(new_function) 4643 4644 outputs = [] 4645 outputs = control_flow_ops.switch_case(branch_idx, new_branches) 4646 return [wrap(t, True) for t in outputs] 4647 4648 4649class WhileV2: 4650 """Object for vectorizing V2 while_loop op.""" 4651 4652 def __init__(self, pfor_input): 4653 self._pfor_input = pfor_input 4654 self._pfor = pfor_input.pfor 4655 cond_func_name = pfor_input.get_attr("cond").name 4656 self._cond_func = pfor_input.op.graph._get_function(compat.as_bytes( 4657 cond_func_name)) 4658 body_func_name = pfor_input.get_attr("body").name 4659 self._body_func = pfor_input.op.graph._get_function(compat.as_bytes( 4660 body_func_name)) 4661 if self._cond_func is None or self._body_func is None: 4662 raise ValueError("Error extracting cond and body functions for op " 4663 f"{self._pfor_input.op}.") 4664 # Indices of inputs that are passed unchanged through the while loop body. 4665 # Typically these are tensors captured from outside the body context. 4666 self._body_pass_through_indices = set() 4667 for i, (inp, out) in enumerate(zip(self._body_func.graph.inputs, 4668 self._body_func.graph.outputs)): 4669 if id(inp) == id(out): 4670 self._body_pass_through_indices.add(i) 4671 self._parallel_iterations = self._pfor_input.get_attr("parallel_iterations") 4672 4673 def _output_shapes(self): 4674 # Calculate output shape for vectorized loop. This will be used as 4675 # shape_invariant. Merges shape inference outputs with the `output_shapes` 4676 # attribute of the op. 4677 output_shapes = [out.shape for out in self._pfor_input.op.outputs] 4678 shapes = self._pfor_input.get_attr("output_shapes") 4679 if not shapes: 4680 shapes = [tensor_shape.TensorShape(None) for _ in output_shapes] 4681 else: 4682 shapes = [tensor_shape.TensorShape(shape) for shape in shapes] 4683 for i, shape in enumerate(shapes): 4684 shape = shape.merge_with(output_shapes[i]) 4685 pfor_input = self._pfor_input.input(i) 4686 if pfor_input.is_stacked: 4687 if _is_variant_with_internal_stacking(pfor_input.t): 4688 shape = tensor_shape.TensorShape([]).concatenate(shape) 4689 else: 4690 shape = tensor_shape.TensorShape([None]).concatenate(shape) 4691 output_shapes[i] = shape 4692 assert len(output_shapes) == self._pfor_input.num_inputs 4693 return output_shapes 4694 4695 def _init_values(self): 4696 """Create arguments passed to converted while_loop.""" 4697 loop_len = self._pfor.loop_len_vector[0] 4698 inputs = [] 4699 # TensorArrays for outputs of converted while loop 4700 output_tas = [] 4701 4702 with ops.name_scope("while_init"): 4703 for inp in self._pfor_input.inputs: 4704 inputs.append(inp.t) 4705 variant_type_id = _variant_type_id(inp.t) 4706 if variant_type_id in _INTERNAL_STACKING_TYPE_IDS: 4707 if variant_type_id != full_type_pb2.TFT_ARRAY: 4708 raise NotImplementedError( 4709 "While loop conversion is only supported for TensorLists. Got " 4710 f"another variant {inp.t}, probably an optional. Please file " 4711 "a bug.") 4712 4713 # For TensorLists, the input format is: 4714 # 4715 # List[user_list_len, Tensor[loop_len, ...]] 4716 # 4717 # rather than the usual 4718 # 4719 # Tensor[loop_len, ...] 4720 # 4721 # The body of the loop will take and return lists in this "internal 4722 # vectorization" format, so we want to keep it that way as much as 4723 # possible. We'll accumulate finished iterations (only relevant for 4724 # pfor-loop-variant while_loop conditions) in an accumulator with 4725 # type : 4726 # 4727 # List[user_list_len, List[loop_len, Tensor[...]]] 4728 # 4729 # This means that each while_loop iteration, we'll iterate over the 4730 # length of the TensorList, dividing done/remaining pfor loop indices 4731 # and scattering the done indices into the inner nested list of the 4732 # accumulator. 4733 element_shape = list_ops.tensor_list_element_shape( 4734 inp.t, dtypes.int32) 4735 if inp.is_stacked: 4736 # Shapes may be tf.constant(-1) for fully dynamic, in which case 4737 # slicing is an error. 4738 element_shape = control_flow_ops.cond( 4739 math_ops.equal(array_ops.rank(element_shape), 0), 4740 lambda: element_shape, 4741 lambda: element_shape[1:]) 4742 dtype = _parse_variant_shapes_and_types(inp.t)[0].dtype 4743 4744 def _init_loop_body(index, output_ta): 4745 output_ta = output_ta.write( 4746 index, 4747 list_ops.tensor_list_reserve(element_shape, loop_len, dtype)) 4748 return index + 1, output_ta 4749 4750 length = list_ops.tensor_list_length(inp.t) 4751 output_ta = tensor_array_ops.TensorArray( 4752 inp.t.dtype, # Variant; this is a nested TensorList 4753 size=length, 4754 dynamic_size=True, 4755 infer_shape=False) 4756 _, output_ta = control_flow_ops.while_loop( 4757 lambda index, _: index < length, 4758 _init_loop_body, 4759 [0, output_ta]) 4760 else: 4761 output_ta = tensor_array_ops.TensorArray( 4762 inp.t.dtype, 4763 size=loop_len, 4764 dynamic_size=False, 4765 infer_shape=True) 4766 output_tas.append(output_ta) 4767 # See documentation for __call__ for the structure of init_values. 4768 indices = ( 4769 math_ops.range(self._pfor.loop_len_vector[0]) 4770 if self._pfor.all_indices_partitioned else self._pfor.all_indices) 4771 return [True, indices] + inputs + output_tas 4772 4773 def _process_cond_unstacked(self, conditions, indices, inputs, output_tas): 4774 """Handles case when condition is pfor loop invariant.""" 4775 # Note that all iterations end together. So we don't need to partition the 4776 # inputs. 4777 not_all_done = array_ops.reshape(conditions, []) 4778 return not_all_done, indices, inputs, output_tas 4779 4780 def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked, 4781 output_tas): 4782 """Handles case when condition is pfor loop dependent.""" 4783 # Compute if all iterations are done. 4784 not_all_done = math_ops.reduce_any(conditions) 4785 conditions_int = math_ops.cast(conditions, dtypes.int32) 4786 # Partition the indices. 4787 done_indices, new_indices = data_flow_ops.dynamic_partition( 4788 indices, conditions_int, 2) 4789 4790 new_inputs = [] 4791 new_output_tas = [] 4792 for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)): 4793 pass_through = i in self._body_pass_through_indices 4794 if not pass_through and _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: 4795 shape_and_type = _parse_variant_shapes_and_types(inp)[0] 4796 element_shape = list_ops.tensor_list_element_shape(inp, dtypes.int32) 4797 user_list_len = list_ops.tensor_list_length(inp) 4798 4799 def _split_vectorized_ta_element(index, new_inp, new_out_ta): 4800 elem = list_ops.tensor_list_get_item(inp, index, shape_and_type.dtype, 4801 element_shape) 4802 if stacked: 4803 done_elem, new_elem = data_flow_ops.dynamic_partition( 4804 elem, conditions_int, 2) 4805 new_inp = list_ops.tensor_list_set_item(new_inp, index, new_elem) 4806 else: 4807 done_elem = _stack(elem, [array_ops.size(done_indices)]).t 4808 done_accum = new_out_ta.read(index) 4809 done_accum = list_ops.tensor_list_scatter( 4810 tensor=done_elem, indices=done_indices, input_handle=done_accum) 4811 new_out_ta = new_out_ta.write(index, done_accum) 4812 return index + 1, new_inp, new_out_ta 4813 4814 length = list_ops.tensor_list_length(inp) 4815 new_inp = list_ops.tensor_list_reserve( 4816 tensor_shape.TensorShape([None]) 4817 + tensor_shape.TensorShape(shape_and_type.shape)[1:], 4818 user_list_len, shape_and_type.dtype) 4819 _, new_inp, out_ta = control_flow_ops.while_loop( 4820 lambda index, unused_new_inp, unused_new_out_ta: index < length, 4821 _split_vectorized_ta_element, 4822 [0, new_inp, output_tas[i]]) 4823 else: 4824 # Partition the inputs. 4825 if stacked: 4826 done_inp, new_inp = data_flow_ops.dynamic_partition( 4827 inp, conditions_int, 2) 4828 else: 4829 if not pass_through: 4830 done_inp = _stack(inp, [array_ops.size(done_indices)]).t 4831 new_inp = inp 4832 4833 out_ta = output_tas[i] 4834 if not pass_through: 4835 # Note that done_indices can be empty. done_inp should also be empty 4836 # in that case. 4837 out_ta = out_ta.scatter(done_indices, done_inp) 4838 new_inputs.append(new_inp) 4839 new_output_tas.append(out_ta) 4840 4841 assert len(new_output_tas) == len(output_tas) 4842 assert len(new_inputs) == len(inputs) 4843 return not_all_done, new_indices, new_inputs, new_output_tas 4844 4845 def _process_body(self, inputs_stacked, new_indices, cond_stacked, 4846 new_inputs, not_all_done): 4847 """Convert the body function.""" 4848 # This is used to store the indices of inputs to the while op that need to 4849 # be stacked. This stacking may be needed in cases where the input to the 4850 # while_loop is loop_invariant but the corresponding output is not. 4851 mismatching_stacked_indices = [] 4852 4853 def true_fn(): 4854 """Converts the body function for all but last iteration.""" 4855 wrapped_inputs = [wrap(inp, stacked) for inp, stacked in 4856 zip(new_inputs, inputs_stacked)] 4857 # Note the iterative process below to figure out loop invariance. 4858 # Here we iterate on vectorization process till a fixed point. The issue 4859 # is that the while body can take pfor loop invariant inputs but return 4860 # loop variant outputs. For any loop variant output, the corresponding 4861 # input has to be then made loop variant (since subsequent while 4862 # iterations will need to see loop variant values). 4863 # However once we make a new input loop variant, we might make other 4864 # outputs loop variant. Hence we need to iterate till we get fixed point. 4865 while True: 4866 if self._pfor.all_indices_partitioned: 4867 indices = array_ops.gather(self._pfor.all_indices, new_indices) 4868 else: 4869 indices = new_indices 4870 body_pfor = PFor( 4871 loop_var=self._pfor.loop_var, 4872 loop_len=array_ops.size(new_indices), 4873 pfor_ops=self._body_func.graph.get_operations(), 4874 fallback_to_while_loop=self._pfor.fallback_to_while_loop, 4875 all_indices=indices, 4876 all_indices_partitioned=(self._pfor.all_indices_partitioned or 4877 cond_stacked), 4878 pfor_config=self._pfor.pfor_config) 4879 stacking_mismatch = False 4880 outputs = _convert_function_call(self._body_func, 4881 body_pfor, 4882 wrapped_inputs) 4883 for i, (out, inp) in enumerate(zip(outputs, wrapped_inputs)): 4884 if out.is_stacked != inp.is_stacked: 4885 stacking_mismatch = True 4886 mismatching_stacked_indices.append(i) 4887 stacked = _stack(inp.t, [array_ops.size(new_indices)]) 4888 if inp.t.dtype == dtypes.variant: 4889 stacked = wrap( 4890 _tile_variant_with_length(stacked.t, 4891 [array_ops.size(new_indices)])) 4892 wrapped_inputs[i] = stacked 4893 if not stacking_mismatch: 4894 if mismatching_stacked_indices: 4895 # We needed to stack some inputs. This code will be abandoned and 4896 # should not get executed. Hence we simply return `new_inputs` to 4897 # make sure the graph construction code completes. 4898 with ops.control_dependencies([ 4899 control_flow_ops.Assert( 4900 False, ["pfor ERROR: this branch should never execute"])]): 4901 return [array_ops.identity(x) for x in new_inputs] 4902 else: 4903 return [out.t for out in outputs] 4904 4905 # If all are done, we simply return `new_inputs`. Else we need to run the 4906 # body function. 4907 return control_flow_ops.cond( 4908 not_all_done, 4909 true_fn, 4910 lambda: list(new_inputs)), mismatching_stacked_indices 4911 4912 def __call__(self): 4913 """Converter for the V2 while_loop. 4914 4915 The conversion of a while_loop is another while_loop. 4916 4917 The arguments to this converted while_loop are as follows: 4918 not_all_done: Boolean scalar Tensor indicating if all the pfor iterations 4919 are done. 4920 indices: int32 1-D Tensor storing the id of the pfor iterations that are not 4921 done. 4922 args: Remaining arguments. These can be divided into 2 categories: 4923 - The first set of arguments correspond one-to-one to the inputs to the 4924 unvectorized while_loop. 4925 - The second set are TensorArrays, corresponding one-to-one to each output 4926 of the unvectorized while_loop. Each TensorArray has `PFor.loop_len` 4927 elements, i.e. the number of pfor iterations. At the end, the i'th 4928 element of each TensorArray will contain the output computed by the i'th 4929 iteration of pfor. Note that elements can be written into these tensors 4930 arrays in any order, depending on when the corresponding pfor iteration 4931 is done. 4932 In each iteration, the while_loop body recomputes the condition for all 4933 active pfor iterations to see which of them are now done. It then partitions 4934 all the inputs and passes them along to the converted body. Values for all 4935 the iterations that are done are written to TensorArrays indexed by the pfor 4936 iteration number. When all iterations are done, the TensorArrays are stacked 4937 to get the final value. 4938 4939 Returns: 4940 List of converted outputs. 4941 """ 4942 output_shapes = self._output_shapes() 4943 # Note that we use these lists as a hack since we need the `body` to compute 4944 # these values during construction of the while_loop graph. 4945 cond_is_stacked = [None] 4946 indices_to_stack = [] 4947 4948 def cond(not_all_done, *_): 4949 return not_all_done 4950 4951 def body(not_all_done, indices, *args): 4952 # See documentation for __call__ for the structure of *args. 4953 num_inputs = self._pfor_input.num_inputs 4954 inputs = args[:num_inputs] 4955 output_tas = args[num_inputs:] 4956 inputs_stacked = [x.is_stacked for x in self._pfor_input.inputs] 4957 assert len(inputs) >= len(output_tas) 4958 assert len(inputs) == len(inputs_stacked) 4959 # Convert condition 4960 with ops.name_scope("while_cond"): 4961 # Note that we set all_indices_partitioned to True here. At this point 4962 # we don't know if indices will be partitioned. Hence we use the 4963 # conservative value. 4964 cond_pfor = PFor( 4965 loop_var=self._pfor.loop_var, 4966 loop_len=array_ops.size(indices), 4967 pfor_ops=self._cond_func.graph.get_operations(), 4968 fallback_to_while_loop=self._pfor.fallback_to_while_loop, 4969 all_indices=indices, 4970 all_indices_partitioned=True, 4971 pfor_config=self._pfor.pfor_config) 4972 4973 wrapped_inputs = [wrap(inp, stacked) for inp, stacked 4974 in zip(inputs, inputs_stacked)] 4975 conditions, cond_stacked, _ = _convert_function_call( 4976 self._cond_func, 4977 cond_pfor, 4978 wrapped_inputs)[0] 4979 cond_is_stacked[0] = cond_stacked 4980 4981 # Recompute the new condition, write outputs of done iterations, and 4982 # partition the inputs if needed. 4983 if not cond_stacked: 4984 (not_all_done, new_indices, new_inputs, 4985 new_output_tas) = self._process_cond_unstacked(conditions, indices, 4986 inputs, output_tas) 4987 else: 4988 (not_all_done, new_indices, new_inputs, 4989 new_output_tas) = self._process_cond_stacked(conditions, indices, 4990 inputs, inputs_stacked, 4991 output_tas) 4992 # Convert body 4993 with ops.name_scope("while_body"): 4994 # Compute the outputs from the body. 4995 new_outputs, mismatching_stacked_indices = self._process_body( 4996 inputs_stacked, new_indices, cond_stacked, new_inputs, not_all_done) 4997 4998 indices_to_stack[:] = mismatching_stacked_indices 4999 for i, new_output in enumerate(new_outputs): 5000 new_output.set_shape(output_shapes[i]) 5001 new_args = ([not_all_done, new_indices] + new_outputs + 5002 list(new_output_tas)) 5003 return tuple(new_args) 5004 5005 # Note that we run the code below in a function since we might abandon the 5006 # generated code in cases where the conversion dictates that some inputs be 5007 # further stacked. Hence we run the graph construction using 5008 # `get_concrete_function` and avoid calling the constructed function if not 5009 # needed. 5010 @def_function.function 5011 def while_fn(): 5012 # Create init_values that will be passed to the while_loop. 5013 init_values = self._init_values() 5014 ta_shape_invariants = [tensor_shape.TensorShape([]) for _ in 5015 self._pfor_input.outputs] 5016 shape_invariants = ( 5017 [tensor_shape.TensorShape([]), tensor_shape.TensorShape([None])] 5018 + output_shapes + ta_shape_invariants) 5019 5020 while_outputs = control_flow_ops.while_loop( 5021 cond, body, init_values, 5022 shape_invariants=shape_invariants, 5023 parallel_iterations=self._parallel_iterations) 5024 if indices_to_stack: 5025 # This function will be abandoned. 5026 return while_outputs 5027 else: 5028 num_inputs = self._pfor_input.num_inputs 5029 new_inputs = while_outputs[2:num_inputs+2] 5030 output_tas = while_outputs[num_inputs+2:] 5031 assert cond_is_stacked[0] is not None 5032 outputs = [] 5033 for i, inp in enumerate(new_inputs): 5034 if cond_is_stacked[0]: 5035 if i in self._body_pass_through_indices: 5036 outputs.append(init_values[i + 2]) 5037 else: 5038 ta = output_tas[i] 5039 if _variant_type_id(inp) == full_type_pb2.TFT_ARRAY: 5040 shape_and_type = _parse_variant_shapes_and_types(inp)[0] 5041 length = list_ops.tensor_list_length(inp) 5042 5043 # We have been accumulating values in a: 5044 # 5045 # List[user_list_len, List[loop_len, Tensor[...]]] 5046 # 5047 # We want to return an output in the same format as the input: 5048 # 5049 # List[user_list_len, Tensor[loop_len, ...]] 5050 # 5051 # So we need to loop over the list and stack its contents. 5052 def _stack_loop_body(index, output_list): 5053 current_value = ta.read(index) 5054 output_list = list_ops.tensor_list_set_item( 5055 output_list, index, 5056 list_ops.tensor_list_stack( 5057 current_value, shape_and_type.dtype)) 5058 return index + 1, output_list 5059 5060 output_list = list_ops.tensor_list_reserve( 5061 tensor_shape.TensorShape(shape_and_type.shape), length, 5062 shape_and_type.dtype) 5063 _, output_list = control_flow_ops.while_loop( 5064 lambda index, _: index < length, 5065 _stack_loop_body, 5066 [0, output_list]) 5067 outputs.append(output_list) 5068 else: 5069 outputs.append(ta.stack()) 5070 else: 5071 outputs.append(inp) 5072 return outputs 5073 5074 _ = while_fn.get_concrete_function() 5075 if indices_to_stack: 5076 # Need to abandon the current conversion, stack some inputs and restart. 5077 self._pfor_input.stack_inputs( 5078 stack_indices=indices_to_stack, tile_variants=True) 5079 # Note that this call will recurse at most one time. The first call will 5080 # do the required stacking, based on the iterative procedure in 5081 # _process_body, and the next invocation to __call__ should not need to do 5082 # any more stacking. 5083 # We invoke `self()` here as a way to discard any corrupted state. 5084 return self() 5085 else: 5086 outputs = while_fn() 5087 wrapped_outputs = [] 5088 for i, (out, inp) in enumerate(zip(outputs, self._pfor_input.inputs)): 5089 if i not in self._body_pass_through_indices and cond_is_stacked[0]: 5090 wrapped_outputs.append(wrap(out, True)) 5091 else: 5092 wrapped_outputs.append(wrap(out, inp.is_stacked)) 5093 return wrapped_outputs 5094 5095 5096@RegisterPFor("StatelessWhile") 5097@RegisterPFor("While") 5098def _convert_while(pfor_input): 5099 converter = WhileV2(pfor_input) 5100 return converter() 5101 5102 5103# spectral_ops 5104 5105 5106@RegisterPForWithArgs("FFT", gen_spectral_ops.fft) 5107@RegisterPForWithArgs("FFT2D", gen_spectral_ops.fft2d) 5108@RegisterPForWithArgs("FFT3D", gen_spectral_ops.fft3d) 5109@RegisterPForWithArgs("IFFT", gen_spectral_ops.ifft) 5110@RegisterPForWithArgs("IFFT2D", gen_spectral_ops.ifft2d) 5111@RegisterPForWithArgs("IFFT3D", gen_spectral_ops.ifft3d) 5112def _convert_fft(pfor_input, _, op_func): 5113 return wrap(op_func(pfor_input.stacked_input(0)), True) 5114 5115 5116@RegisterPForWithArgs("RFFT", gen_spectral_ops.rfft, "Tcomplex") 5117@RegisterPForWithArgs("RFFT2D", gen_spectral_ops.rfft2d, "Tcomplex") 5118@RegisterPForWithArgs("RFFT3D", gen_spectral_ops.rfft3d, "Tcomplex") 5119@RegisterPForWithArgs("IRFFT", gen_spectral_ops.irfft, "Treal") 5120@RegisterPForWithArgs("IRFFT2D", gen_spectral_ops.irfft2d, "Treal") 5121@RegisterPForWithArgs("IRFFT3D", gen_spectral_ops.irfft3d, "Treal") 5122def _convert_rfft(pfor_input, _, op_func, attr_name): 5123 inp = pfor_input.stacked_input(0) 5124 fft_length = pfor_input.unstacked_input(1) 5125 attr = pfor_input.get_attr(attr_name) 5126 return wrap(op_func(inp, fft_length, attr), True) 5127