# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Compiled parallel-for loop."""
# pylint: disable=missing-docstring,g-direct-tensorflow-import

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import string
import sys
import traceback

import six

from tensorflow.compiler.tf2xla.python import xla
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.eager import execute
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import func_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gen_parsing_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import object_identity

flags.DEFINE_bool(
    "op_conversion_fallback_to_while_loop", False,
    "If true, falls back to using a while loop for ops for "
    "which a converter is not defined.")


def _stack(t, length):
  """Stacks `t` `length` times."""
  ones = array_ops.ones_like(array_ops.shape(t))
  multiples = array_ops.concat([length, ones], 0)
  t = array_ops.tile(array_ops.expand_dims(t, 0), multiples)
  return wrap(t, True)

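# For illustration, `_stack` turns a loop-invariant value into something that
# looks like a per-iteration ("stacked") value by tiling it along a new leading
# dimension of size loop_len. A minimal sketch (shapes here are hypothetical):
#
#   t = array_ops.zeros([2, 3])
#   stacked = _stack(t, constant_op.constant([4]))
#   # stacked.t has shape [4, 2, 3] and stacked.is_stacked is True.
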
# The following stateful ops can be safely called once, and with the same
# signature as the unconverted version, if their inputs are loop invariant.
# TODO(agarwal): implement a strategy for converting Variable reads/writes. The
# plan is to map each read/write in the loop_fn to a corresponding merged
# read/write in the converted graph. Writes need to be mergeable (e.g.
# AssignAdd) to be used in `pfor`. Given a certain read/write order in the
# loop_fn, doing a one-to-one conversion will simulate executing such
# instructions in lock-step across all iterations.
passthrough_stateful_ops = set([
    "VariableV2",
    "VarHandleOp",
    "ReadVariableOp",
    "StackV2",
    "TensorArrayWriteV3",
    "TensorArrayReadV3",
    "TensorArraySizeV3",
])


def _is_stateful_pfor_op(op):
  if isinstance(op, WhileOp):
    return op.is_stateful
  if op.type == "Const":
    # Const didn't have an op_def.
    return False
  if op.type in passthrough_stateful_ops:
    return False
  assert hasattr(op, "op_def") and op.op_def is not None, op
  return op.op_def.is_stateful


# pylint: disable=protected-access
class WhileOp(object):
  """Object for storing state for converting the outputs of a while_loop."""

  def __init__(self, exit_node, pfor_ops, pfor_config):
    """Initializer.

    Args:
      exit_node: A tensor output from the while_loop.
      pfor_ops: list of ops inside the current pfor loop.
      pfor_config: PForConfig object used while constructing loop body.
    """
    self._pfor_config = pfor_config
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    assert isinstance(exit_node, ops.Tensor)
    self._while_context = exit_node.op._get_control_flow_context()
    assert isinstance(self._while_context, control_flow_ops.WhileContext)
    self._context_name = self._while_context.name
    self._condition = self._while_context.pivot.op.inputs[0]
    # Parts of an external while_loop could be created inside a pfor loop.
    # However for the purpose here, we declare such loops to be external. Also
    # note that we check if the condition was created inside or outside to
    # determine if the while_loop was first created inside or outside.
    # TODO(agarwal): check that the Enter and Exit of this loop are unstacked.
    self._is_inside_loop = self.op_is_inside_loop(self._condition.op)
    if self._is_inside_loop:
      for e in self._while_context.loop_exits:
        assert self.op_is_inside_loop(e.op)

    # Note the code below tries to reverse engineer an existing while_loop
    # graph by assuming the following pattern of nodes.
    #
    #          NextIteration <---- Body <--- Enter
    #              |                ^
    #              V             ___| Y
    #    Enter -> Merge -> Switch___
    #                       ^       | N
    #                       |       V
    #                   LoopCond    Exit

    # Note that elements in the lists below correspond one-to-one with each
    # other, i.e. these lists are the same size, and the i_th entry corresponds
    # to different Operations/Tensors of a single cycle as illustrated above.
    # List of Switch ops (ops.Operation) that feed into an Exit Node.
    self._exit_switches = []
    # List of inputs (ops.Tensor) to NextIteration.
    self._body_outputs = []
    # List of list of control inputs of the NextIteration nodes.
    self._next_iter_control_inputs = []
    # List of Merge ops (ops.Operation).
    self._enter_merges = []
    # List of output (ops.Tensor) of Exit nodes.
    self._outputs = []

    # List of Enter Tensors.
    # There are two types of Enter nodes:
    # - The Enter nodes that are used in the `loop_vars` argument to
    # `while_loop` (see
    # https://www.tensorflow.org/api_docs/python/tf/while_loop). We collect
    # these Enter nodes immediately below by tracing backwards from the Exit
    # nodes via Exit <- Switch <- Merge <- Enter. You can see this chain in the
    # diagram above. This allows us to have a 1:1 correspondence between the
    # self._outputs and the first elements in self._enters.
    # - The Enter nodes that are used only by the body. They don't appear in
    # the `loop_vars` and are not returned from the `while_loop`. In Python
    # code, they are usually captured by the body lambda. We collect them below
    # by iterating over all the ops in the graph. They are appended to the end
    # of self._enters or self._direct_enters, and don't correspond to any
    # outputs in self._outputs. Note that we keep the resource/variant Enter
    # nodes in self._direct_enters and the constructed while_loop's body uses
    # them directly as opposed to passing them as loop variables. This is done
    # because the while_body cannot partition the resource/variant Tensors, so
    # it has to leave them unchanged.
    self._enters = []
    self._direct_enters = []

    for e in self._while_context.loop_exits:
      self._outputs.append(e.op.outputs[0])
      switch = e.op.inputs[0].op
      assert switch.type == "Switch", switch
      self._exit_switches.append(switch)
      merge = switch.inputs[0].op
      assert merge.type == "Merge", merge
      self._enter_merges.append(merge)
      enter = merge.inputs[0].op
      assert enter.type == "Enter", enter
      self._enters.append(enter.outputs[0])
      next_iter = merge.inputs[1].op
      assert next_iter.type == "NextIteration", next_iter
      self._body_outputs.append(next_iter.inputs[0])
      self._next_iter_control_inputs.append(next_iter.control_inputs)

    # Collect all the Enter nodes that are not part of `loop_vars`, the second
    # category described above.
    # Also track whether the loop body has any stateful ops.
    self._is_stateful = False
    for op in ops.get_default_graph().get_operations():
      # TODO(agarwal): make sure this works with nested case.
      control_flow_context = op._get_control_flow_context()
      if control_flow_context is None:
        continue
      if control_flow_context.name == self._context_name:
        self._is_stateful |= _is_stateful_pfor_op(op)
        if op.type == "Enter":
          output = op.outputs[0]
          if output not in self._enters:
            if output.dtype in (dtypes.resource, dtypes.variant):
              if output not in self._direct_enters:
                self._direct_enters.append(output)
            else:
              self._enters.append(output)

  def __str__(self):
    """String representation."""
    return "while_loop(%s)" % self.name

  @property
  def inputs(self):
    """Input to all the Enter nodes."""
    return [x.op.inputs[0] for x in self._enters + self._direct_enters]

  @property
  def control_inputs(self):
    """Control input to all the Enter nodes."""
    control_inputs = []
    for x in self._enters + self._direct_enters:
      control_inputs.extend(x.op.control_inputs)
    return control_inputs

  @property
  def outputs(self):
    """Outputs of all the Exit nodes."""
    return self._outputs

  @property
  def name(self):
    """Context name for the while loop."""
    return self._context_name

  @property
  def is_inside_loop(self):
    """Returns true if the while_loop was created inside the pfor."""
    return self._is_inside_loop

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different python
    # objects representing the same Operation node.
    return op._id in self._pfor_op_ids

  @property
  def is_stateful(self):
    return self._is_stateful

  @property
  def pfor_converter(self):
    """Return a converter for the while loop."""
    return self

  def _init_pfor(self, parent_pfor, indices, cond_stacked, inputs,
                 inputs_stacked):
    """Create a PFor object for converting parts of the while_loop.

    Args:
      parent_pfor: PFor object being used for converting the while_loop.
      indices: int32 Tensor of ids for the iterations that are still active
        (i.e. did not exit the while_loop).
      cond_stacked: True if the while_loop condition is stacked.
      inputs: list of input Tensors corresponding 1-to-1 with self._enters.
        Note that these Tensors are a subset of the loop variables for the
        generated while_loop.
      inputs_stacked: List of booleans corresponding 1-to-1 with `inputs`,
        indicating if the value is stacked or not.

    Returns:
      A PFor instance. The instance is initialized by adding conversion
      mappings of nodes that will be external to the conversion that the
      returned instance will be used for, e.g. Enter nodes as well as Merge
      and Switch outputs are mapped to converted values.
    """
    num_outputs = len(self._outputs)
    assert len(inputs) == len(self._enters)
    assert len(inputs_stacked) == len(self._enters)
    loop_var = parent_pfor.loop_var
    loop_len = array_ops.size(indices)
    pfor = PFor(
        loop_var,
        loop_len,
        pfor_ops=self._pfor_ops,
        all_indices=indices,
        all_indices_partitioned=cond_stacked,
        pfor_config=self._pfor_config)
    # Map all inputs of Enter nodes in self._direct_enters to their converted
    # values.
    for enter in self._direct_enters:
      enter_input = enter.op.inputs[0]
      converted_enter, stacked, is_sparse_stacked = parent_pfor._convert_helper(
          enter_input)
      # Since these are resources / variants, they should be unstacked.
      assert not stacked and not is_sparse_stacked, (enter, converted_enter)
      pfor._add_conversion(enter, wrap(converted_enter, False))

    # Map all Enter nodes to the inputs.
    for enter, inp, stacked in zip(self._enters, inputs, inputs_stacked):
      pfor._add_conversion(enter, wrap(inp, stacked))
    # Map outputs of Switch and Merge.
    for i in range(num_outputs):
      wrapped_inp = wrap(inputs[i], inputs_stacked[i])
      merge = self._enter_merges[i]
      pfor._add_conversion(merge.outputs[0], wrapped_inp)
      # Note that second output of Merge is typically not used, except possibly
      # as a control dependency. To avoid trying to output the correct value,
      # we employ a hack here. We output a dummy invalid value with an
      # incorrect dtype. This will allow control dependency to work but if
      # using it as an input, it should typically lead to errors during graph
      # construction due to dtype mismatch.
      # TODO(agarwal): Check in the original graph to see if there are any
      # consumers of this Tensor that use it as an input.
      pfor._add_conversion(merge.outputs[1],
                           wrap(constant_op.constant(-1.0), False))
      switch = self._exit_switches[i]
      # Don't need to worry about switch.output[0] which will feed to Exit
      # node.
      pfor._add_conversion(switch.outputs[1], wrapped_inp)
    return pfor

  def _convert_enter(self, parent_pfor, enter):
    """Converts an Enter node."""
    inp, stacked, _ = parent_pfor._convert_helper(enter.op.inputs[0])
    control_inputs = []
    for x in enter.op.control_inputs:
      converted = parent_pfor._convert_helper(x)
      if not isinstance(converted, ops.Operation):
        converted = converted.t
      control_inputs.append(converted)
    if control_inputs:
      with ops.control_dependencies(control_inputs):
        inp = array_ops.identity(inp)
    return inp, stacked

  def _maybe_stacked(self, cache, inp):
    """Heuristic to figure out if converting `inp` leads to a stacked value.

    Args:
      cache: map from Tensor to boolean indicating stacked/unstacked.
      inp: input Tensor.

    Returns:
      True if `inp` could get stacked. If the function returns False, the
      converted value should be guaranteed to be unstacked. If returning True,
      it may or may not be stacked.
    """
    if inp in cache:
      return cache[inp]
    if not self.op_is_inside_loop(inp.op):
      return False
    op = inp.op
    output = False
    if op.type in [
        "Shape",
        "Rank",
        "ShapeN",
        "ZerosLike",
        "TensorArrayV3",
        "TensorArraySizeV3",
    ]:
      output = False
    elif _is_stateful_pfor_op(op):
      # This may be fairly aggressive.
      output = True
    elif op.type == "Exit":
      # This may be fairly aggressive.
      output = True
    else:
      for t in op.inputs:
        if self._maybe_stacked(cache, t):
          output = True
          break
    cache[inp] = output
    return output

  def _create_init_values(self, pfor_input):
    """Create arguments passed to converted while_loop."""
    with ops.name_scope("while_init"):
      loop_len_vector = pfor_input.pfor.loop_len_vector
      loop_len = loop_len_vector[0]
      num_outputs = len(self._outputs)

      inputs = []
      maybe_stacked_cache = {}
      # Convert all the Enters. Need to do this before checking for stacking
      # below.
      for i, enter in enumerate(self._enters):
        inp, stacked = self._convert_enter(pfor_input.pfor, enter)
        inputs.append(inp)
        maybe_stacked_cache[enter] = stacked
        # Since this enter node is part of the `loop_vars`, it corresponds to
        # an output and its preceding switch. We mark this switch's output
        # with the same stackedness, to act as the base case for the logic
        # below. Below, we will be going through the body figuring out which
        # inputs might need to be stacked and which inputs can safely remain
        # unstacked.
        if i < num_outputs:
          maybe_stacked_cache[self._exit_switches[i].outputs[1]] = stacked

      # Shape invariants for init_values corresponding to self._enters.
      input_shape_invariants = []
      # TensorArrays for outputs of converted while loop.
      output_tas = []
      # Shape invariants for output TensorArrays.
      ta_shape_invariants = []
      # List of booleans indicating stackedness of inputs, i.e. tensors
      # corresponding to self._enters.
      inputs_stacked = []
      for i, inp in enumerate(inputs):
        enter = self._enters[i]
        inp_stacked = self._maybe_stacked(maybe_stacked_cache, enter)
        # Note that even when an input is unstacked, the body could make it
        # stacked. We use a heuristic below to figure out if the body may be
        # making it stacked.
        if i < num_outputs:
          body_output = self._body_outputs[i]
          if enter.op in self._pfor_ops:
            body_output_stacked = self._maybe_stacked(maybe_stacked_cache,
                                                      body_output)
          else:
            # If constructed outside of pfor loop, then the output would not
            # be stacked.
            body_output_stacked = False
          if body_output_stacked and not inp_stacked:
            inp = _stack(inp, loop_len_vector).t
            inputs[i] = inp
            inp_stacked = True
          # TODO(agarwal): other attributes for the TensorArray ?
          output_tas.append(tensor_array_ops.TensorArray(inp.dtype, loop_len))
          ta_shape_invariants.append(tensor_shape.TensorShape(None))

        inputs_stacked.append(inp_stacked)
        input_shape_invariants.append(tensor_shape.TensorShape(None))

      # See documentation for __call__ for the structure of init_values.
      init_values = [True, pfor_input.pfor.all_indices] + inputs + output_tas
      # TODO(agarwal): try stricter shape invariants
      shape_invariants = (
          [tensor_shape.TensorShape(None),
           tensor_shape.TensorShape(None)] + input_shape_invariants +
          ta_shape_invariants)

      return init_values, inputs_stacked, shape_invariants

  def _process_cond_unstacked(self, conditions, indices, inputs, output_tas):
    """Handles case when condition is unstacked.

    Note that all iterations end together. So we don't need to partition the
    inputs. When all iterations are done, we write the inputs to the
    TensorArrays. Note that we only write to index 0 of output_tas. Since all
    iterations end together, they can all be output together.
    """
    not_all_done = array_ops.reshape(conditions, [])
    new_output_tas = []
    # pylint: disable=cell-var-from-loop
    for i, out_ta in enumerate(output_tas):
      inp = inputs[i]
      new_output_tas.append(
          control_flow_ops.cond(not_all_done, lambda: out_ta,
                                lambda: out_ta.write(0, inp)))
    # pylint: enable=cell-var-from-loop
    return not_all_done, indices, inputs, new_output_tas

  def _process_cond_stacked(self, conditions, indices, inputs, inputs_stacked,
                            output_tas):
    num_outputs = len(self._outputs)
    # Compute if all iterations are done.
    not_all_done = math_ops.reduce_any(conditions)
    conditions_int = math_ops.cast(conditions, dtypes.int32)
    # Partition the indices.
    done_indices, new_indices = data_flow_ops.dynamic_partition(
        indices, conditions_int, 2)

    new_inputs = []
    new_output_tas = []
    for i, (inp, stacked) in enumerate(zip(inputs, inputs_stacked)):
      # Partition the inputs.
      if stacked:
        done_inp, new_inp = data_flow_ops.dynamic_partition(
            inp, conditions_int, 2)
      else:
        # TODO(agarwal): avoid this stacking. See TODO earlier in
        # _process_cond_unstacked.
        done_inp = _stack(inp, [array_ops.size(done_indices)]).t
        new_inp = inp
      new_inputs.append(new_inp)
      # For iterations that are done, write them to TensorArrays.
      if i < num_outputs:
        out_ta = output_tas[i]
        # Note that done_indices can be empty. done_inp should also be empty
        # in that case.
        new_output_tas.append(out_ta.scatter(done_indices, done_inp))
    return not_all_done, new_indices, new_inputs, new_output_tas

  def _process_body(self, pfor_input, inputs_stacked, new_indices, cond_stacked,
                    new_inputs, not_all_done):
    """Convert the body function."""

    def true_fn(control_inputs, body_pfor, body_output, stacked):
      """Converts the body function for all but last iteration.

      This essentially converts body_output. Additionally, it needs to handle
      any control dependencies on the NextIteration node. So it creates another
      Identity node with the converted dependencies.
      """
      converted_control_inp = []
      for x in control_inputs:
        for t in x.outputs:
          converted_control_inp.append(body_pfor._convert_helper(t).t)
      if stacked:
        # Note convert always does the stacking.
        output = body_pfor.convert(body_output)
      else:
        output, convert_stacked, _ = body_pfor._convert_helper(body_output)
        assert convert_stacked == stacked, body_output
      with ops.control_dependencies(converted_control_inp):
        return array_ops.identity(output)

    body_pfor = self._init_pfor(pfor_input.pfor, new_indices, cond_stacked,
                                new_inputs, inputs_stacked)
    new_outputs = []

    for i, (body_output,
            stacked) in enumerate(zip(self._body_outputs, inputs_stacked)):
      control_inp = self._next_iter_control_inputs[i]
      out_dtype = body_output.dtype
      # Note that we want to run the body only if not all pfor iterations are
      # done. If all are done, we return empty tensors since these values will
      # not be used. Notice that the value returned by the loop is based on
      # TensorArrays and not directly on these returned values.
      # pylint: disable=cell-var-from-loop
      new_output = control_flow_ops.cond(
          not_all_done,
          lambda: true_fn(control_inp, body_pfor, body_output, stacked),
          lambda: constant_op.constant([], dtype=out_dtype))
      # pylint: enable=cell-var-from-loop
      new_outputs.append(new_output)
    return new_outputs

  def __call__(self, pfor_input):
    """Converter for the while_loop.

    The conversion of a while_loop is another while_loop.

    The arguments to this converted while_loop are as follows:
    not_all_done: Boolean scalar Tensor indicating if all the pfor iterations
      are done.
    indices: int32 1-D Tensor storing the id of the iterations that are not
      done.
    args: Remaining arguments. These can be divided into 3 categories:
      - First set of arguments are the tensors that correspond to the initial
        elements of self._enters.
        These are the elements that appear in the original while loop's
        `loop_vars`.
      - The second set of arguments are the tensors that correspond to the
        remaining elements of self._enters. These are the tensors that directly
        enter the original while loop body.
      - Finally, the last set of arguments are TensorArrays. These TensorArrays
        correspond to the outputs of the original while_loop, i.e. to the
        elements in self._outputs. Each TensorArray has `PFor.loop_len`
        elements, i.e. the number of pfor iterations. At the end, the i'th
        element of each TensorArray will contain the output computed by the
        i'th iteration of pfor. Note that elements can be written into these
        tensor arrays in any order, depending on when the corresponding pfor
        iteration is done.
    If the original while_loop had `k` tensors in its `loop_vars` and its body
    directly captured `m` tensors, the `args` will contain `2 * k + m` values.

    In each iteration, the while_loop body recomputes the condition for all
    active pfor iterations to see which of them are now done. It then
    partitions all the inputs and passes them along to the converted body.
    Values for all the iterations that are done are written to TensorArrays
    indexed by the pfor iteration number. When all iterations are done, the
    TensorArrays are stacked to get the final value.

    Args:
      pfor_input: A PForInput object corresponding to the output of any Exit
        node from this while loop.

    Returns:
      List of converted outputs.
    """
    # Create init_values that will be passed to the while_loop.
    init_values, inputs_stacked, shape_invariants = self._create_init_values(
        pfor_input)
    # Note that we use a list as a hack since we need the nested function body
    # to set the value of cond_is_stacked. python2.x doesn't support nonlocal
    # variables.
    cond_is_stacked = [None]

    def cond(not_all_done, *_):
      return not_all_done

    def body(not_all_done, indices, *args):
      # See documentation for __call__ for the structure of *args.
      num_enters = len(self._enters)
      inputs = args[:num_enters]
      output_tas = args[num_enters:]
      # TODO(agarwal): see which outputs have consumers and only populate the
      # TensorArrays corresponding to those. Or do those paths get trimmed out
      # from inside the while_loop body?
      assert len(inputs) >= len(output_tas)
      assert len(inputs) == len(inputs_stacked)

      # Convert condition.
      with ops.name_scope("while_cond"):
        # Note that we set cond_stacked to True here. At this point we don't
        # know if it could be loop invariant, hence the conservative value is
        # to assume stacked.
        cond_pfor = self._init_pfor(
            pfor_input.pfor,
            indices,
            cond_stacked=True,
            inputs=inputs,
            inputs_stacked=inputs_stacked)
        conditions, cond_stacked, _ = cond_pfor._convert_helper(self._condition)
        cond_is_stacked[0] = cond_stacked

      # Recompute the new condition, write outputs of done iterations, and
      # partition the inputs if needed.
      if not cond_stacked:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_unstacked(conditions, indices,
                                                        inputs, output_tas)
      else:
        (not_all_done, new_indices, new_inputs,
         new_output_tas) = self._process_cond_stacked(conditions, indices,
                                                      inputs, inputs_stacked,
                                                      output_tas)

      # Convert body.
      with ops.name_scope("while_body"):
        # Compute the outputs from the body.
        new_outputs = self._process_body(pfor_input, inputs_stacked,
                                         new_indices, cond_stacked, new_inputs,
                                         not_all_done)

      # Note that the first num_outputs new values of inputs are computed using
      # the body. Rest of them were direct Enters into the condition/body and
      # the partitioning done earlier is sufficient to give the new value.
      num_outputs = len(self._outputs)
      new_args = ([not_all_done, new_indices] + new_outputs +
                  list(new_inputs[num_outputs:]) + new_output_tas)
      return tuple(new_args)

    while_outputs = control_flow_ops.while_loop(
        cond, body, init_values, shape_invariants=shape_invariants)
    output_tas = while_outputs[-len(self._outputs):]
    outputs = []
    assert cond_is_stacked[0] is not None
    for inp_stacked, ta in zip(inputs_stacked, output_tas):
      if cond_is_stacked[0]:
        outputs.append(wrap(ta.stack(), True))
      else:
        # Note that if while_loop condition is unstacked, all iterations exit
        # at the same time and we wrote those outputs in index 0 of the tensor
        # array.
        outputs.append(wrap(ta.read(0), inp_stacked))
    return outputs


class _PforInput(object):
  """Input object passed to registered pfor converters."""

  def __init__(self, pfor, op, inputs):
    """Creates a _PforInput object.

    Args:
      pfor: PFor converter object.
      op: the Operation object that is being converted.
      inputs: list of WrappedTensor objects representing converted values of
        the inputs of `op`.
    """
    self.pfor = pfor
    self._op = op
    self._inputs = inputs

  def stack_inputs(self, stack_indices=None):
    """Stacks unstacked inputs at `stack_indices`.

    Args:
      stack_indices: indices of inputs at which stacking is done. If None,
        stacking is done at all indices.
    """
    if stack_indices is None:
      stack_indices = range(len(self._inputs))
    length = self.pfor.loop_len_vector
    for i in stack_indices:
      inp = self._inputs[i]
      if not inp.is_stacked:
        self._inputs[i] = _stack(inp.t, length)

  def expanddim_inputs_for_broadcast(self):
    """Reshapes stacked inputs to prepare them for broadcast.

    Since stacked inputs have an extra leading dimension, automatic
    broadcasting rules could incorrectly try to expand dimensions before that
    leading dimension. To avoid that, we reshape these stacked inputs to the
    maximum rank they will need to be broadcasted to.
717 """ 718 if not self._inputs: 719 return 720 721 # Find max rank 722 def _get_rank(x): 723 rank = array_ops.rank(x.t) 724 if not x.is_stacked: 725 rank += 1 726 return rank 727 728 ranks = [_get_rank(x) for x in self._inputs] 729 max_rank = ranks[0] 730 for rank in ranks[1:]: 731 max_rank = math_ops.maximum(rank, max_rank) 732 733 for i, inp in enumerate(self._inputs): 734 if inp.is_stacked: 735 shape = array_ops.shape(inp.t) 736 rank_diff = array_ops.reshape(max_rank - ranks[i], [1]) 737 ones = array_ops.tile([1], rank_diff) 738 new_shape = array_ops.concat([shape[:1], ones, shape[1:]], axis=0) 739 self._inputs[i] = wrap(array_ops.reshape(inp.t, new_shape), True) 740 741 @property 742 def inputs(self): 743 return self._inputs 744 745 @property 746 def num_inputs(self): 747 return len(self._inputs) 748 749 def input(self, index): 750 assert len(self._inputs) > index, (index, self._inputs) 751 return self._inputs[index] 752 753 def stacked_input(self, index): 754 t, is_stacked, _ = self.input(index) 755 if not is_stacked: 756 op_type = self.op_type 757 op_def = getattr(self._op, "op_def", None) 758 if op_def is None: 759 input_name = "at index %d" % index 760 else: 761 input_name = "\"%s\"" % op_def.input_arg[index].name 762 raise ValueError( 763 "Input %s of op \"%s\" expected to be not loop invariant" % ( 764 input_name, op_type)) 765 return t 766 767 def unstacked_input(self, index): 768 t, is_stacked, _ = self.input(index) 769 if is_stacked: 770 op_type = self.op_type 771 op_def = getattr(self._op, "op_def", None) 772 if op_def is None: 773 input_name = "at index %d" % index 774 else: 775 input_name = "\"%s\"" % op_def.input_arg[index].name 776 raise ValueError("Input %s of op \"%s\" expected to be loop invariant" % ( 777 input_name, op_type)) 778 return t 779 780 @property 781 def op(self): 782 return self._op 783 784 @property 785 def op_type(self): 786 return self._op.type 787 788 def get_attr(self, attr): 789 return self._op.get_attr(attr) 790 791 @property 792 def outputs(self): 793 return self._op.outputs 794 795 def output(self, index): 796 assert index < len(self._op.outputs) 797 return self._op.outputs[index] 798 799 800_pfor_converter_registry = {} 801 802 803class RegisterPFor(object): 804 """Utility to register converters for pfor. 805 806 Usage: 807 @RegisterPFor(foo_op_type) 808 def _foo_converter(pfor_input): 809 ... 810 811 The above will register conversion function `_foo_converter` for handling 812 conversion of `foo_op_type`. These converters are called during vectorization 813 of a `pfor` loop body. For each operation node in this loop body, 814 the vectorization process will call the converter corresponding to the 815 operation type of the node. 816 817 During conversion, the registered function will be called with a single 818 argument `pfor_input`, of type `PForInput`, which will contain state needed 819 for the conversion. When the converter is called for a node, all its inputs 820 should already have been converted and these converted values are stored in 821 `pfor_input.inputs`. This registered function should output a list of 822 WrappedTensor objects with the same length as the number of outputs of the 823 node being converted. If the node had zero outputs, then it should return an 824 ops.Operation object. 
  These new sets of nodes should implement the functionality of running that
  operation for the number of iterations specified by
  `pfor_input.pfor.loop_len_vector[0]`, where the inputs of the node for each
  iteration are picked from `pfor_inputs.inputs()`.

  One tricky aspect of the conversion process is keeping track of, and
  leveraging, loop invariance of computation. Each converted input is a
  WrappedTensor which indicates whether the input was loop invariant or not. If
  the converted value is loop invariant, its rank should match the rank of the
  corresponding tensor in the loop body, else its rank is larger by 1. The
  converter should look at the loop invariance of the inputs and generate new
  nodes based on that. Note that the converter will not be called if all inputs
  are loop invariant and the operation is not stateful. The converter should
  determine if its own output is loop invariant and `wrap` its output
  accordingly.

  Example:

  Here, the converter is trying to convert a Reshape node in the loop body.
  This node will have two inputs: the tensor to reshape, and the new shape. The
  example here only handles the case where the shape is loop invariant.

  @RegisterPFor("Reshape")
  def _convert_reshape(pfor_input):
    # We assume that input is not loop invariant. Call to `stacked_input`
    # asserts that and returns the converted value. This value will have a rank
    # larger by 1 compared to the rank of the input in the loop body.
    t = pfor_input.stacked_input(0)

    # We assume that shape input is loop invariant. Call to `unstacked_input`
    # asserts that and returns the converted value.
    shape = pfor_input.unstacked_input(1)

    # We compute `new_shape` by prepending the number of iterations to the
    # original shape.
    new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape],
                                 axis=0)

    # The vectorized output involves reshaping the converted input `t` using
    # `new_shape`.
    new_output = array_ops.reshape(t, new_shape)

    # The converted output is marked as not loop invariant using the call to
    # wrap.
    return wrap(new_output, True)
  """

  def __init__(self, op_type):
    """Creates an object to register a converter for op with type `op_type`."""
    self.op_type = op_type

  def __call__(self, converter):
    name = self.op_type
    assert name not in _pfor_converter_registry, "Re-registering %s " % name
    _pfor_converter_registry[name] = converter
    return converter


class RegisterPForWithArgs(RegisterPFor):
  """Utility to register converters for pfor.

  Usage:
  @RegisterPForWithArgs(foo_op_type, foo=value, ....)
  def _foo_converter(pfor_input, foo=None, ....):
    ...

  See RegisterPFor for details on the conversion function.
  `RegisterPForWithArgs` allows binding extra arguments to the
  conversion function at registration time.
  """

  def __init__(self, op_type, *args, **kw_args):
    super(RegisterPForWithArgs, self).__init__(op_type)
    self._args = args
    self._kw_args = kw_args

  def __call__(self, converter):

    def _f(pfor_input):
      return converter(pfor_input, self.op_type, *self._args, **self._kw_args)

    super(RegisterPForWithArgs, self).__call__(_f)
    return converter

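# For illustration, `RegisterPForWithArgs` is used further below in this file
# to share one conversion function across several op types, e.g.:
#
#   @RegisterPForWithArgs("Conv2D", dims=[0])
#   @RegisterPForWithArgs("AvgPool", dims=[0])
#   def _convert_flatten_batch(pfor_input, op_type, dims):
#     ...
#
# Here `op_type` and `dims` are bound at registration time and are passed to
# the converter in addition to `pfor_input`.
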
# TODO(agarwal): call raw_ops instead of calling these low level routines.
def _create_op(op_type, inputs, op_dtypes, attrs=None):
  """Utility to create an op."""
  op = ops.get_default_graph().create_op(
      op_type, inputs, op_dtypes, attrs=attrs, compute_device=True)
  flat_attrs = nest.flatten([(str(a), op.get_attr(str(a))) for a in attrs])
  execute.record_gradient(op_type, op.inputs, tuple(flat_attrs), op.outputs[:])
  return op


WrappedTensor = collections.namedtuple("WrappedTensor",
                                       ["t", "is_stacked", "is_sparse_stacked"])
"""Wrapper around the result of a Tensor conversion.

The additional fields are useful for keeping track of the conversion state as
data flows through the ops in the loop body. For every op whose output is a
Tensor, its converter should return either a WrappedTensor or a list of
WrappedTensors.

Args:
  t: The converted tensor
  is_stacked: True if the tensor is stacked, i.e. represents the results of all
    the iterations of the loop, where each row i of the tensor corresponds to
    that op's output on iteration i of the loop. False if the tensor is not
    stacked, i.e. represents the result of the op for a single iteration of
    the loop, where the result does not vary between iterations.
  is_sparse_stacked: True if the tensor corresponds to a component tensor
    (indices, values, or dense_shape) of a sparse tensor, and has been logically
    stacked via a sparse conversion.
"""


def wrap(tensor, is_stacked=True, is_sparse_stacked=False):
  """Helper to create a WrappedTensor object."""
  assert isinstance(is_stacked, bool)
  assert isinstance(is_sparse_stacked, bool)
  assert isinstance(tensor, ops.Tensor)
  assert not is_sparse_stacked or is_stacked, ("If the wrapped tensor is "
                                               "stacked via a sparse "
                                               "conversion, it must also be "
                                               "stacked.")
  return WrappedTensor(tensor, is_stacked, is_sparse_stacked)


def _fallback_converter(pfor_input):
  logging.warn("Using a while_loop for converting %s", pfor_input.op_type)
  output_dtypes = [x.dtype for x in pfor_input.outputs]
  iters = pfor_input.pfor.loop_len_vector[0]

  def while_body(i, *ta_list):
    """Body of while loop."""
    inputs = [
        x[i, ...] if stacked else x for x, stacked, _ in pfor_input.inputs
    ]
    op_outputs = _create_op(
        pfor_input.op_type,
        inputs,
        output_dtypes,
        attrs=pfor_input.op.node_def.attr).outputs

    outputs = []
    for out, ta in zip(op_outputs, ta_list):
      assert isinstance(out, ops.Tensor)
      outputs.append(ta.write(i, array_ops.expand_dims(out, 0)))
    return tuple([i + 1] + outputs)

  ta_list = control_flow_ops.while_loop(
      lambda i, *ta: i < iters, while_body, [0] +
      [tensor_array_ops.TensorArray(dtype, iters) for dtype in output_dtypes
      ])[1:]
  return tuple([wrap(ta.concat(), True) for ta in ta_list])

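# For illustration (tensor names here are hypothetical): with 3 pfor
# iterations, a value that varies across iterations is represented "stacked"
# with an extra leading dimension of size 3, while a loop-invariant value keeps
# the rank it has inside the loop body:
#
#   wrap(stacked_t, is_stacked=True)      # e.g. stacked_t has shape [3, 2]
#   wrap(invariant_t, is_stacked=False)   # e.g. invariant_t has shape [2]
#
# `_fallback_converter` above recovers the original per-iteration behavior by
# running the op once per iteration inside a while_loop and concatenating the
# results along that leading dimension.
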
class PForConfig(object):
  """A configuration object used to communicate with loop body function."""

  def __init__(self):
    # This may be set to the number of iterations.
    self._maybe_iters = None
    # Map from reduction node, created by `reduce`, to the bundle of reduction
    # function and arguments.
    self._reduce_map = {}

  def _has_reductions(self):
    """True if some reductions were performed by loop body."""
    return len(self._reduce_map)

  def _set_iters(self, iters):
    """Set number of pfor iterations."""
    self._maybe_iters = iters

  def reduce(self, fn, *args):
    """Performs reduction `fn` on `args` vectorized across pfor iterations.

    Note that `fn` is traced once inside the loop function context. Hence any
    captures or side-effects will happen in that context. Call to the traced
    version of `fn` happens during the construction of the vectorized code.

    Note that this currently may not work inside a control flow construct.
    Args:
      fn: a reduction function. It will be called with arguments that have the
        same structure as *args but with individual values whose rank may be
        higher by 1 since they represent loop invariant vectorized versions of
        the corresponding Tensors in *args.
      *args: unvectorized Tensors.

    Returns:
      The result of running `fn` on the vectorized versions of `*args`. These
      outputs will be available as loop invariant values to all the
      iterations.
    """
    assert not context.executing_eagerly()
    # Creates a concrete function that will be used for reduction.
    tensor_specs = []
    for arg in args:
      if not isinstance(arg, ops.Tensor):
        raise ValueError("Got a non-Tensor argument %s in reduce" % arg)
      batched_shape = tensor_shape.TensorShape(
          [self._maybe_iters]).concatenate(arg.shape)
      tensor_specs.append(
          tensor_spec.TensorSpec(shape=batched_shape, dtype=arg.dtype))
    concrete_function = def_function.function(fn).get_concrete_function(
        *tensor_specs)

    # Creates PlaceholderWithDefault and IdentityN nodes corresponding to the
    # reduction.
    pl_outputs = []
    with ops.control_dependencies(args):
      for output in concrete_function.outputs:
        if not isinstance(output, ops.Tensor):
          raise ValueError("Got a non-Tensor output %s while running reduce" %
                           output)
        # Note that we use placeholder_with_default just to make XLA happy
        # since it does not like placeholder ops.
        if output.shape.is_fully_defined():
          dummy = array_ops.zeros(output.shape.as_list(), dtype=output.dtype)
          pl_outputs.append(
              array_ops.placeholder_with_default(dummy, shape=output.shape))
        else:
          # TODO(agarwal): support case when under XLA and output.shape is not
          # fully defined.
          pl_outputs.append(
              array_ops.placeholder(output.dtype, shape=output.shape))

      reduction_op = array_ops.identity_n(pl_outputs)[0].op
    self._reduce_map[reduction_op] = (concrete_function, args)
    if len(reduction_op.outputs) == 1:
      return reduction_op.outputs[0]
    else:
      return tuple(reduction_op.outputs)

  # TODO(agarwal): handle reductions inside control flow constructs.
  def reduce_concat(self, x):
    """Performs a concat reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has rank one higher than `x`. The value is the vectorized
      version of `x`, i.e. stacking the value of `x` across different pfor
      iterations.
    """
    return self.reduce(lambda y: y, x)

  def reduce_mean(self, x):
    """Performs a mean reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the mean of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_mean(y, axis=0), x)

  def reduce_sum(self, x):
    """Performs a sum reduction on `x` across pfor iterations.

    Note that this currently may not work inside a control flow construct.
    Args:
      x: an unvectorized Tensor.

    Returns:
      A Tensor that has same rank as `x`. The value is the sum of the values
      of `x` across the pfor iterations.
    """
    return self.reduce(lambda y: math_ops.reduce_sum(y, axis=0), x)

  def _lookup_reduction(self, t):
    """Looks up Tensor `t` in the reduction maps."""
    assert isinstance(t, ops.Tensor), t
    return self._reduce_map.get(t.op)


class PFor(object):
  """Implementation of rewrite of parallel-for loops.

  This class takes a DAG or a set of DAGs representing the body of a
  parallel-for loop, and adds new operations to the graph that implement
  functionality equivalent to running that loop body for a specified number of
  iterations. This new set of nodes may or may not use a tensorflow loop
  construct.

  The process of conversion does not delete or change any existing operations.
  It only adds operations that efficiently implement the equivalent
  functionality. We refer to the added ops as "converted ops".

  The conversion process uses a simple greedy heuristic. It walks the loop body
  and tries to express the functionality of running each node in a loop with a
  new set of nodes. When converting an op several cases are possible:
  - The op is not inside the loop body. Hence it can be used as is.
  - The op does not depend on the iteration number and is stateless. In this
    case, it can be used as is.
  - The op is not stateful, and depends on iteration number only through
    control dependencies. In this case, we can create a single op with same
    inputs and attributes, but with "converted" control dependencies.
  - The op is not stateful, and all its inputs are loop invariant. In this
    case, similar to above, we can create a single op with same inputs and
    attributes, but with "converted" control dependencies.
  - The op is stateful or at least one of the inputs is not loop invariant. In
    this case, we run the registered converter for that op to create a set of
    converted ops. All nodes in the set will have converted control
    dependencies corresponding to control dependencies of the original op. If
    the op returned multiple outputs, "converted outputs" could be produced by
    different ops in this set.
  """

  def __init__(self,
               loop_var,
               loop_len,
               pfor_ops,
               all_indices=None,
               all_indices_partitioned=False,
               pfor_config=None):
    """Creates an object to rewrite a parallel-for loop.

    Args:
      loop_var: ops.Tensor output of a Placeholder operation. The value should
        be an int32 scalar representing the loop iteration number.
      loop_len: A scalar or scalar Tensor representing the number of iterations
        the loop is run for.
      pfor_ops: List of all ops inside the loop body.
      all_indices: If not None, an int32 vector with size `loop_len`
        representing the iteration ids that are still active.
        These values should be unique and sorted. However they may not be
        contiguous. This is typically the case when inside a control flow
        construct which has partitioned the indices of the iterations that are
        being converted.
      all_indices_partitioned: If True, this object is being constructed from a
        control flow construct where not all the pfor iterations are guaranteed
        to be active.
      pfor_config: PForConfig object used while constructing the loop body.
    """
    assert isinstance(loop_var, ops.Tensor)
    assert loop_var.op.type == "PlaceholderWithDefault"
    self._loop_var = loop_var
    loop_len_value = tensor_util.constant_value(loop_len)
    if loop_len_value is not None:
      loop_len = loop_len_value
    self._loop_len_vector = array_ops.reshape(loop_len, [1])
    self._all_indices_partitioned = all_indices_partitioned
    if all_indices_partitioned:
      assert all_indices is not None
    self.all_indices = (
        math_ops.range(loop_len) if all_indices is None else all_indices)

    self._conversion_map = object_identity.ObjectIdentityDictionary()
    self._conversion_map[loop_var] = wrap(self.all_indices, True)
    self._pfor_ops = set(pfor_ops)
    self._pfor_op_ids = set(x._id for x in pfor_ops)
    self._pfor_config = pfor_config

  def op_is_inside_loop(self, op):
    """True if op was created inside the pfor loop body."""
    assert isinstance(op, ops.Operation)
    # Note that we use self._pfor_op_ids for the check and not self._pfor_ops
    # since it appears the TensorFlow API could return different python
    # objects representing the same Operation node.
    return op._id in self._pfor_op_ids

  def _convert_sparse(self, y):
    """Returns the converted value corresponding to SparseTensor y.

    For SparseTensors, instead of stacking the component tensors separately,
    resulting in component tensors with shapes (N, m, rank), (N, m), and (N,
    rank) respectively for indices, values, and dense_shape (where N is the
    loop length and m is the number of sparse tensor values per loop iter), we
    want to logically stack the SparseTensors, to create a SparseTensor whose
    components are size (N * m, rank + 1), (N * m,), and (rank + 1,)
    respectively.

    Here, we try to get the conversion of each component tensor.
    If the tensors are stacked via a sparse conversion, return the resulting
    SparseTensor composed of the converted components. Otherwise, the component
    tensors are either unstacked or stacked naively. In the latter case, we
    unstack the component tensors to reform loop_len SparseTensor elements,
    then correctly batch them.

    The unstacked tensors must have the same rank. Each dimension of each
    SparseTensor will expand to be the largest among all SparseTensor elements
    for that dimension. For example, if there are N SparseTensors of rank 3
    being stacked, with N dense shapes, where the i_th shape is (x_i, y_i,
    z_i), the new dense shape will be (N, max_i(x_i), max_i(y_i), max_i(z_i)).

    Args:
      y: A tf.SparseTensor.

    Returns:
      A tf.SparseTensor that is the converted value corresponding to y.
1220 """ 1221 outputs = [ 1222 self._convert_helper(t) for t in (y.indices, y.values, y.dense_shape) 1223 ] 1224 assert all(isinstance(o, WrappedTensor) for o in outputs) 1225 1226 if all(w.is_sparse_stacked for w in outputs): 1227 return sparse_tensor.SparseTensor(*[w.t for w in outputs]) 1228 1229 assert not any(w.is_sparse_stacked for w in outputs), ( 1230 "Error converting SparseTensor. All components should be logically " 1231 "stacked, or none.") 1232 1233 # If component tensors were not sparsely stacked, they are either unstacked 1234 # or stacked without knowledge that they are components of sparse tensors. 1235 # In this case, we have to restack them. 1236 return self._restack_sparse_tensor_logically( 1237 *[self._unwrap_or_tile(w) for w in outputs]) 1238 1239 def _restack_sparse_tensor_logically(self, indices, values, shape): 1240 sparse_tensor_rank = indices.get_shape().dims[-1].value 1241 if sparse_tensor_rank is not None: 1242 sparse_tensor_rank += 1 1243 1244 def fn(args): 1245 res = gen_sparse_ops.serialize_sparse( 1246 args[0], args[1], args[2], out_type=dtypes.variant) 1247 return res 1248 1249 # Applies a map function to the component tensors to serialize each 1250 # sparse tensor element and batch them all, then deserializes the batch. 1251 # TODO(rachelim): Try to do this without map_fn -- add the right offsets 1252 # to shape and indices tensors instead. 1253 result = map_fn.map_fn(fn, [indices, values, shape], dtype=dtypes.variant) 1254 return sparse_ops.deserialize_sparse( 1255 result, dtype=values.dtype, rank=sparse_tensor_rank) 1256 1257 def _unwrap_or_tile(self, wrapped_tensor): 1258 """Given a wrapped tensor, unwrap if stacked. Otherwise, tiles it.""" 1259 output, is_stacked = wrapped_tensor.t, wrapped_tensor.is_stacked 1260 if is_stacked: 1261 return output 1262 else: 1263 return _stack(output, self._loop_len_vector).t 1264 1265 def convert(self, y): 1266 """Returns the converted value corresponding to y. 1267 1268 Args: 1269 y: A ops.Tensor or a ops.Operation object. If latter, y should not have 1270 any outputs. 1271 1272 Returns: 1273 If y does not need to be converted, it returns y as is. Else it returns 1274 the "converted value" corresponding to y. 1275 """ 1276 if y is None: 1277 return None 1278 if isinstance(y, sparse_tensor.SparseTensor): 1279 return self._convert_sparse(y) 1280 assert isinstance(y, (ops.Tensor, ops.Operation)), y 1281 output = self._convert_helper(y) 1282 if isinstance(output, WrappedTensor): 1283 assert isinstance(y, ops.Tensor) 1284 return self._unwrap_or_tile(output) 1285 else: 1286 assert isinstance(y, ops.Operation) 1287 assert not y.outputs 1288 assert isinstance(output, ops.Operation) 1289 return output 1290 1291 def _was_converted(self, t): 1292 """True if t is not a conversion of itself.""" 1293 converted_t = self._conversion_map[t] 1294 return converted_t.t is not t 1295 1296 def _add_conversion(self, old_output, new_output): 1297 assert isinstance(old_output, (ops.Tensor, ops.Operation)), old_output 1298 assert isinstance(new_output, (WrappedTensor, ops.Operation)), new_output 1299 self._conversion_map[old_output] = new_output 1300 1301 def _convert_reduction(self, y): 1302 # Handle reductions. 
    if self._pfor_config is None:
      return None
    reduction = self._pfor_config._lookup_reduction(y)
    if reduction is None:
      return None
    (reduction_fn, reduction_args) = reduction
    batched_args = []
    for reduction_arg in reduction_args:
      assert isinstance(reduction_arg, ops.Tensor), reduction_arg
      # Tensor being reduced should already be converted due to a control
      # dependency on the created placeholder.
      # Note that in cases where reduction_arg is in an outer context, one
      # needs to locate the corresponding Enter node and use that to lookup
      # the conversion.
      # TODO(agarwal): handle reductions inside control flow constructs.
      assert reduction_arg in self._conversion_map, (
          "Unable to handle reduction of %s, possibly as it was used "
          "inside a control flow construct. Note that reductions across "
          "pfor iterations are currently not supported inside control flow "
          "constructs." % reduction_arg)
      batched_arg = self._conversion_map[reduction_arg]
      batched_args.append(self._unwrap_or_tile(batched_arg))
    outputs = reduction_fn(*batched_args)
    return [wrap(output, False) for output in nest.flatten(outputs)]

  def _convert_helper(self, op_or_tensor):
    stack = [op_or_tensor]
    while stack:
      y = stack[0]
      if y in self._conversion_map:
        assert isinstance(self._conversion_map[y],
                          (WrappedTensor, ops.Operation))
        stack.pop(0)
        continue
      if isinstance(y, ops.Operation):
        assert not y.outputs, (
            "We only support converting Operation objects with no outputs. "
            "Got %s", y)
        y_op = y
      else:
        assert isinstance(y, ops.Tensor), y
        y_op = y.op

      is_while_loop = y_op.type == "Exit"
      if is_while_loop:
        while_op = WhileOp(
            y, pfor_ops=self._pfor_ops, pfor_config=self._pfor_config)
        is_inside_loop = while_op.is_inside_loop
        # If all nodes in the while_loop graph were created inside the pfor,
        # we treat the whole loop subgraph as a single op (y_op) and try to
        # convert it. For while_loops that are created completely or partially
        # outside, we treat them as external and should be able to simply
        # return the Exit node output as is without needing any conversion.
        # Note that for while_loops that are partially constructed inside, we
        # assume they will be loop invariant. If that is not the case, it will
        # create runtime errors since the converted graph would depend on the
        # self._loop_var placeholder.
        if is_inside_loop:
          y_op = while_op
      else:
        is_inside_loop = self.op_is_inside_loop(y_op)

      # If this op was not created inside the loop body, we will return as is.
      # 1. Convert inputs and control inputs.

      def _add_to_stack(x):
        if x not in self._conversion_map:
          stack.insert(0, x)
          return True
        else:
          return False

      if is_inside_loop:
        added_to_stack = False
        for inp in y_op.inputs:
          added_to_stack |= _add_to_stack(inp)
        for cinp in y_op.control_inputs:
          if cinp.outputs:
            for t in cinp.outputs:
              added_to_stack |= _add_to_stack(t)
          else:
            added_to_stack |= _add_to_stack(cinp)
        if added_to_stack:
          continue

        converted_inputs = [self._conversion_map[inp] for inp in y_op.inputs]
        some_input_converted = any(self._was_converted(x) for x in y_op.inputs)
        some_input_stacked = any(x.is_stacked for x in converted_inputs)

        converted_control_ops = set()
        some_control_input_converted = False
        for cinp in y_op.control_inputs:
          if cinp.outputs:
            for t in cinp.outputs:
              converted_t = self._conversion_map[t]
              if self._was_converted(t):
                some_control_input_converted = True
              converted_control_ops.add(converted_t.t.op)
          else:
            converted_cinp = self._conversion_map[cinp]
            assert isinstance(converted_cinp, ops.Operation)
            if converted_cinp != cinp:
              some_control_input_converted = True
              converted_control_ops.add(converted_cinp)
        converted_control_ops = list(converted_control_ops)
        is_stateful = _is_stateful_pfor_op(y_op)
      else:
        converted_inputs = []
        converted_control_ops = []
      logging.vlog(3, "converting op:%s\ninputs:%s\ncontrol_inputs:%s", y_op,
                   converted_inputs, converted_control_ops)

      # 2. Convert y_op
      # If converting a while_loop, we let the while_loop converter deal with
      # putting the control dependencies appropriately.
      control_dependencies = [] if is_while_loop else converted_control_ops
      with ops.control_dependencies(control_dependencies), ops.name_scope(
          y_op.name + "/pfor/"), ops.get_default_graph()._original_op(y_op):
        # Op is a placeholder for a reduction.
        reduce_output = self._convert_reduction(y)
        if reduce_output is not None:
          new_outputs = reduce_output
        # None of the inputs and control inputs were converted.
        elif ((not is_inside_loop or
               (not is_stateful and not some_input_converted and
                not some_control_input_converted)) and
              y.graph == ops.get_default_graph()):
          if y is y_op:
            assert not isinstance(y_op, WhileOp)
            new_outputs = y_op
          else:
            new_outputs = [wrap(x, False) for x in y_op.outputs]
        elif not (is_stateful or is_while_loop or some_input_stacked):
          # All inputs are unstacked or unconverted but some control inputs
          # are converted.
          # TODO(rachelim): Handle the case where some inputs are sparsely
          # stacked (i.e. any(x.is_sparse_stacked for x in converted_inputs))
          new_op = _create_op(y_op.type, [x.t for x in converted_inputs],
                              [x.dtype for x in y_op.outputs],
                              y_op.node_def.attr)
          if y is y_op:
            new_outputs = new_op
          else:
            new_outputs = [wrap(x, False) for x in new_op.outputs]
        else:
          # Either some inputs are not loop invariant or op is stateful.
          if hasattr(y_op, "pfor_converter"):
            converter = y_op.pfor_converter
          else:
            converter = _pfor_converter_registry.get(y_op.type, None)
          if converter is None:
            if flags.FLAGS.op_conversion_fallback_to_while_loop:
              converter = _fallback_converter
            else:
              raise ValueError("No converter defined for %s\n%s\ninputs: %s. "
" 1458 "\nEither add a converter or set " 1459 "--op_conversion_fallback_to_while_loop=True, " 1460 "which may run slower" % 1461 (y_op.type, y_op, converted_inputs)) 1462 # TODO(rachelim): Handle the case where some inputs are sparsely 1463 # stacked. We should only call the converter if it supports handling 1464 # those inputs. 1465 pfor_inputs = _PforInput(self, y_op, converted_inputs) 1466 try: 1467 new_outputs = converter(pfor_inputs) 1468 except Exception as e: # pylint: disable=broad-except 1469 logging.error("Got error while pfor was converting op %s" 1470 "with inputs %s\n, converted inputs %s\n" 1471 "%s\n" 1472 "Here are the pfor conversion stack traces:" % ( 1473 y_op, 1474 y_op.inputs[:], 1475 pfor_inputs.inputs, 1476 str(e))) 1477 original_op = y_op 1478 while isinstance(original_op, ops.Operation): 1479 logging.error("%s\ncreated at:\n %s" % ( 1480 original_op, 1481 " ".join(traceback.format_list(original_op.traceback)))) 1482 original_op = original_op._original_op 1483 six.reraise(e.__class__, e, sys.exc_info()[2]) 1484 1485 if isinstance(new_outputs, WrappedTensor): 1486 new_outputs = [new_outputs] 1487 assert isinstance(new_outputs, 1488 (list, tuple, ops.Operation)), new_outputs 1489 logging.vlog(2, "converted %s %s", y_op, new_outputs) 1490 1491 # Insert into self._conversion_map 1492 if y is y_op: 1493 assert isinstance(new_outputs, ops.Operation) 1494 self._add_conversion(y_op, new_outputs) 1495 else: 1496 assert len(y_op.outputs) == len(new_outputs), (y_op, y_op.outputs, 1497 new_outputs) 1498 for old_output, new_output in zip(y_op.outputs, new_outputs): 1499 assert isinstance(new_output, WrappedTensor), (new_output, y, y_op) 1500 assert old_output.dtype == new_output.t.dtype, (new_output, y, y_op) 1501 # Set shape for converted output. 1502 output_shape = old_output.shape 1503 if not new_output.is_sparse_stacked: 1504 if new_output.is_stacked: 1505 loop_len = tensor_util.constant_value(self.loop_len_vector) 1506 if loop_len is None: 1507 batch_dim = tensor_shape.TensorShape([None]) 1508 else: 1509 batch_dim = tensor_shape.TensorShape(loop_len) 1510 output_shape = batch_dim.concatenate(output_shape) 1511 new_output.t.set_shape(output_shape) 1512 self._add_conversion(old_output, new_output) 1513 stack.pop(0) 1514 1515 return self._conversion_map[op_or_tensor] 1516 1517 @property 1518 def loop_len_vector(self): 1519 """Returns a single element vector whose value is number of iterations.""" 1520 return self._loop_len_vector 1521 1522 @property 1523 def loop_var(self): 1524 """Returns placeholder loop variable.""" 1525 return self._loop_var 1526 1527 @property 1528 def pfor_ops(self): 1529 return self._pfor_ops 1530 1531 @property 1532 def pfor_config(self): 1533 return self._pfor_config 1534 1535 @property 1536 def all_indices_partitioned(self): 1537 """all_indices_partitioned property. 1538 1539 Returns: 1540 True if we are inside a control flow construct and not all pfor iterations 1541 may be active. 1542 """ 1543 return self._all_indices_partitioned 1544 1545 1546# The code below defines converters for different operations. Please see comment 1547# for RegisterPFor to see how converters should be defined. 
1548 1549# nn_ops 1550 1551 1552def _flatten_first_two_dims(x): 1553 """Merges first two dimensions.""" 1554 old_shape = array_ops.shape(x) 1555 new_shape = array_ops.concat([[-1], old_shape[2:]], axis=0) 1556 return array_ops.reshape(x, new_shape) 1557 1558 1559def _unflatten_first_dim(x, first_dim): 1560 """Splits first dimension into [first_dim, -1].""" 1561 old_shape = array_ops.shape(x) 1562 new_shape = array_ops.concat([first_dim, [-1], old_shape[1:]], axis=0) 1563 return array_ops.reshape(x, new_shape) 1564 1565 1566def _inputs_with_flattening(pfor_input, input_indices): 1567 """Stacks and flattens first dim of inputs at indices `input_indices`.""" 1568 if input_indices is None: 1569 input_indices = [] 1570 pfor_input.stack_inputs(stack_indices=input_indices) 1571 inputs = [] 1572 for i in range(pfor_input.num_inputs): 1573 if i in input_indices: 1574 inp = pfor_input.stacked_input(i) 1575 inp = _flatten_first_two_dims(inp) 1576 else: 1577 inp = pfor_input.unstacked_input(i) 1578 inputs.append(inp) 1579 return inputs 1580 1581 1582@RegisterPForWithArgs("Conv2D", dims=[0]) 1583@RegisterPForWithArgs("AvgPool", dims=[0]) 1584@RegisterPForWithArgs("MaxPool", dims=[0]) 1585@RegisterPForWithArgs("MaxPool3D", dims=[0]) 1586@RegisterPForWithArgs("MaxPool3DGrad", dims=[0, 1, 2]) 1587@RegisterPForWithArgs("MaxPoolGrad", dims=[0, 1, 2]) 1588@RegisterPForWithArgs("MaxPool3DGradGrad", dims=[0, 1, 2]) 1589@RegisterPForWithArgs("MaxPoolGradGrad", dims=[0, 1, 2]) 1590@RegisterPForWithArgs("SoftmaxCrossEntropyWithLogits", dims=[0, 1]) 1591def _convert_flatten_batch(pfor_input, op_type, dims): 1592 del op_type 1593 inputs = _inputs_with_flattening(pfor_input, dims) 1594 outputs = _create_op( 1595 pfor_input.op_type, 1596 inputs, [x.dtype for x in pfor_input.outputs], 1597 attrs=pfor_input.op.node_def.attr).outputs 1598 n = pfor_input.pfor.loop_len_vector 1599 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1600 return [wrap(x, True) for x in outputs] 1601 1602 1603_channel_flatten_input_cache = {} 1604 1605 1606def _channel_flatten_input(x, data_format): 1607 """Merge the stack dimension with the channel dimension. 1608 1609 If S is pfor's stacking dimension, then, 1610 - for SNCHW, we transpose to NSCHW. If N dimension has size 1, the transpose 1611 should be cheap. 1612 - for SNHWC, we transpose to NHWCS. 1613 We then merge the S and C dimension. 1614 1615 Args: 1616 x: ops.Tensor to transform. 1617 data_format: "NCHW" or "NHWC". 1618 1619 Returns: 1620 A 3-element tuple with the transformed value, along with the shape for 1621 reshape and order for transpose required to transform back. 1622 """ 1623 1624 graph = ops.get_default_graph() 1625 cache_key = (graph, x.experimental_ref(), data_format) 1626 if cache_key not in _channel_flatten_input_cache: 1627 x_shape = array_ops.shape(x) 1628 if data_format == b"NCHW": 1629 order = [1, 0, 2, 3, 4] 1630 shape = array_ops.concat([x_shape[1:2], [-1], x_shape[3:]], axis=0) 1631 reverse_order = order 1632 else: 1633 order = [1, 2, 3, 0, 4] 1634 shape = array_ops.concat([x_shape[1:4], [-1]], axis=0) 1635 reverse_order = [3, 0, 1, 2, 4] 1636 # Move S dimension next to C dimension. 1637 x = array_ops.transpose(x, order) 1638 reverse_shape = array_ops.shape(x) 1639 # Reshape to merge the S and C dimension. 
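# For instance, with data_format="NHWC" and a stacked input of shape
# [S, N, H, W, C], the transpose above moves S next to C and the reshape
# below merges them, giving shape [N, H, W, S * C]; for "NCHW" the result is
# [N, S * C, H, W]. reverse_shape and reverse_order let callers undo the
# reshape and transpose on the op's outputs.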
1640 x = array_ops.reshape(x, shape)
1641 outputs = x, reverse_order, reverse_shape
1642 _channel_flatten_input_cache[cache_key] = outputs
1643 else:
1644 outputs = _channel_flatten_input_cache[cache_key]
1645 return outputs
1646
1647
1648 # Note that with training=True, running FusedBatchNormV3 on individual examples
1649 # is very different from running FusedBatchNormV3 on a batch of those examples.
1650 # This is because, in the latter case, the operation can be considered as first
1651 # computing the mean and variance over all the examples and then using these
1652 # to scale all those examples. This creates a data dependency between these
1653 # different "iterations" since the inputs to the scaling step depend on the
1654 # statistics coming from all these inputs.
1655 # As with other kernels, the conversion here effectively runs the kernel
1656 # independently for each iteration, and returns outputs by stacking outputs from
1657 # each of those iterations.
1658 @RegisterPFor("FusedBatchNormV3")
1659 def _convert_fused_batch_norm(pfor_input):
1660 is_training = pfor_input.get_attr("is_training")
1661 # When BatchNorm is used with training=False, mean and variance are provided
1662 # externally and used as is by the op. Thus, we can merge the S and N
1663 # dimensions as we do for regular operations.
1664 # When BatchNorm is used with training=True, mean and variance are computed
1665 # for each channel across the batch dimension (first one). If we merge S and N
1666 # dimensions, the mean and variance will be computed over a larger set. So, we
1667 # merge the S and C dimensions instead.
1668 if not is_training:
1669 # We return zeros for batch_mean and batch_variance output. Note that CPU
1670 # and GPU seem to have different behavior for those two outputs. CPU outputs
1671 # zero because these values are not used during inference. GPU outputs
1672 # something, probably real means and variances.
1673 inputs = _inputs_with_flattening(pfor_input, [0])
1674 outputs = _create_op(
1675 pfor_input.op_type,
1676 inputs, [x.dtype for x in pfor_input.outputs],
1677 attrs=pfor_input.op.node_def.attr).outputs
1678 y = outputs[0]
1679 n = pfor_input.pfor.loop_len_vector
1680 y = _unflatten_first_dim(y, n)
1681 mean = pfor_input.unstacked_input(3)
1682 zeros = array_ops.zeros_like(mean)
1683 return [wrap(y, True)] + [wrap(zeros, False)] * 5
1684
1685 pfor_input.stack_inputs()
1686 data_format = pfor_input.get_attr("data_format")
1687 # We merge the first dimension with the "C" dimension, run FusedBatchNormV3,
1688 # and then transpose back.
1689 x = pfor_input.stacked_input(0)
1690 x, reverse_order, reverse_shape = _channel_flatten_input(x, data_format)
1691 # Note that we stack all the other inputs as well so that their sizes match
1692 # the new size of the channel dimension.
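# For example, a per-iteration scale of shape [C], stacked to [S, C] and
# reshaped to [-1] below, becomes a vector of length S * C whose entry for
# (iteration s, channel c) lines up with merged channel index s * C + c of x.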
1693 inputs = [x] + [ 1694 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1695 for i in range(1, pfor_input.num_inputs) 1696 ] 1697 outputs = _create_op( 1698 pfor_input.op_type, 1699 inputs, [x.dtype for x in pfor_input.outputs], 1700 attrs=pfor_input.op.node_def.attr).outputs 1701 y = outputs[0] 1702 y = array_ops.reshape(y, reverse_shape) 1703 y = array_ops.transpose(y, reverse_order) 1704 n = pfor_input.pfor.loop_len_vector 1705 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1706 outputs = [y] + outputs 1707 return [wrap(x, True) for x in outputs] 1708 1709 1710@RegisterPFor("FusedBatchNormGradV3") 1711def _convert_fused_batch_norm_grad(pfor_input): 1712 pfor_input.stack_inputs() 1713 data_format = pfor_input.get_attr("data_format") 1714 y_backprop = pfor_input.stacked_input(0) 1715 y_backprop, _, _ = _channel_flatten_input(y_backprop, data_format) 1716 x = pfor_input.stacked_input(1) 1717 x, x_reverse_order, x_reverse_shape = _channel_flatten_input(x, data_format) 1718 inputs = [y_backprop, x] + [ 1719 array_ops.reshape(pfor_input.stacked_input(i), [-1]) 1720 for i in range(2, pfor_input.num_inputs) 1721 ] 1722 outputs = _create_op( 1723 pfor_input.op_type, 1724 inputs, [x.dtype for x in pfor_input.outputs], 1725 attrs=pfor_input.op.node_def.attr).outputs 1726 x_backprop = outputs[0] 1727 x_backprop = array_ops.reshape(x_backprop, x_reverse_shape) 1728 x_backprop = array_ops.transpose(x_backprop, x_reverse_order) 1729 n = pfor_input.pfor.loop_len_vector 1730 outputs = [_unflatten_first_dim(x, n) for x in outputs[1:]] 1731 outputs = [x_backprop] + outputs 1732 return [wrap(output, True) for output in outputs] 1733 1734 1735@RegisterPForWithArgs("Conv2DBackpropInput", flatten_dims=[2], shape_dim=0) 1736@RegisterPForWithArgs("AvgPoolGrad", flatten_dims=[1], shape_dim=0) 1737def _convert_flatten_batch_shape_input(pfor_input, op_type, flatten_dims, 1738 shape_dim): 1739 del op_type 1740 inputs = _inputs_with_flattening(pfor_input, flatten_dims) 1741 n = pfor_input.pfor.loop_len_vector 1742 # Adjust the `input_sizes` input. 1743 ones = array_ops.ones([array_ops.shape(inputs[shape_dim])[0] - 1], 1744 dtype=n.dtype) 1745 inputs[shape_dim] *= array_ops.concat([n, ones], axis=0) 1746 outputs = _create_op( 1747 pfor_input.op_type, 1748 inputs, [x.dtype for x in pfor_input.outputs], 1749 attrs=pfor_input.op.node_def.attr).outputs 1750 outputs = [_unflatten_first_dim(x, n) for x in outputs] 1751 return [wrap(x, True) for x in outputs] 1752 1753 1754@RegisterPFor("Conv2DBackpropFilter") 1755def _convert_conv2d_backprop_filter(pfor_input): 1756 pfor_input.stack_inputs(stack_indices=[2]) 1757 inputs, inputs_stacked, _ = pfor_input.input(0) 1758 filter_sizes = pfor_input.unstacked_input(1) 1759 grads = pfor_input.stacked_input(2) 1760 strides = pfor_input.get_attr("strides") 1761 padding = pfor_input.get_attr("padding") 1762 use_cudnn_on_gpu = pfor_input.get_attr("use_cudnn_on_gpu") 1763 data_format = pfor_input.get_attr("data_format") 1764 dilations = pfor_input.get_attr("dilations") 1765 if inputs_stacked: 1766 # TODO(agarwal): Implement this efficiently. 1767 logging.warn("Conv2DBackpropFilter uses a while_loop. Fix that!") 1768 1769 def while_body(i, ta): 1770 inp_i = inputs[i, ...] 1771 grad_i = grads[i, ...] 
1772 output = nn_ops.conv2d_backprop_filter( 1773 inp_i, 1774 filter_sizes, 1775 grad_i, 1776 strides=strides, 1777 padding=padding, 1778 use_cudnn_on_gpu=use_cudnn_on_gpu, 1779 data_format=data_format, 1780 dilations=dilations) 1781 return i + 1, ta.write(i, array_ops.expand_dims(output, 0)) 1782 1783 n = array_ops.reshape(pfor_input.pfor.loop_len_vector, []) 1784 _, ta = control_flow_ops.while_loop( 1785 lambda i, ta: i < n, while_body, 1786 (0, tensor_array_ops.TensorArray(inputs.dtype, n))) 1787 output = ta.concat() 1788 return wrap(output, True) 1789 else: 1790 # We merge the stack dimension with the channel dimension of the gradients 1791 # and pretend we had a larger filter (see change to filter_sizes below). 1792 # Once the filter backprop is computed, we reshape and transpose back 1793 # appropriately. 1794 grads, _, _ = _channel_flatten_input(grads, data_format) 1795 n = pfor_input.pfor.loop_len_vector 1796 old_filter_sizes = filter_sizes 1797 filter_sizes *= array_ops.concat([[1, 1, 1], n], axis=0) 1798 output = nn_ops.conv2d_backprop_filter( 1799 inputs, 1800 filter_sizes, 1801 grads, 1802 strides=strides, 1803 padding=padding, 1804 use_cudnn_on_gpu=use_cudnn_on_gpu, 1805 data_format=data_format, 1806 dilations=dilations) 1807 new_filter_shape = array_ops.concat([old_filter_sizes[:3], n, [-1]], axis=0) 1808 output = array_ops.reshape(output, new_filter_shape) 1809 output = array_ops.transpose(output, [3, 0, 1, 2, 4]) 1810 return wrap(output, True) 1811 1812 1813@RegisterPForWithArgs("LogSoftmax", gen_nn_ops.log_softmax) 1814@RegisterPForWithArgs("Softmax", gen_nn_ops.softmax) 1815def _convert_softmax(pfor_input, op_type, op_func): 1816 del op_type 1817 return wrap(op_func(pfor_input.stacked_input(0)), True) 1818 1819 1820# array_ops 1821 1822 1823@RegisterPForWithArgs("Identity", array_ops.identity) 1824@RegisterPForWithArgs("StopGradient", array_ops.stop_gradient) 1825@RegisterPForWithArgs("MatrixDiag", array_ops.matrix_diag) 1826@RegisterPForWithArgs("MatrixDiagPart", array_ops.matrix_diag_part) 1827def _convert_identity(pfor_input, op_type, op_func): 1828 del op_type 1829 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 1830 1831 1832@RegisterPFor("IdentityN") 1833def _convert_identity_n(pfor_input): 1834 outputs = array_ops.identity_n([x.t for x in pfor_input.inputs]) 1835 return [ 1836 wrap(out, inp.is_stacked) for out, inp in zip(outputs, pfor_input.inputs) 1837 ] 1838 1839 1840@RegisterPFor("Reshape") 1841def _convert_reshape(pfor_input): 1842 t = pfor_input.stacked_input(0) 1843 shape = pfor_input.unstacked_input(1) 1844 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 1845 return wrap(array_ops.reshape(t, new_shape), True) 1846 1847 1848@RegisterPFor("BroadcastTo") 1849def _convert_broadcast_to(pfor_input): 1850 t = pfor_input.stacked_input(0) 1851 shape = pfor_input.unstacked_input(1) 1852 new_shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 1853 1854 # Expand dims of stacked t to broadcast against the new shape. 1855 # TODO(davmre): consider factoring out common code with 1856 # `expanddim_inputs_for_broadcast`, which has similar logic but with 1857 # implicit shapes (of input Tensors) rather than explicit shapes. 
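# Worked example (illustrative shapes): if t is stacked with shape [S, 3] and
# each pfor iteration broadcasts it to [2, 3], then new_shape above is
# [S, 2, 3], rank_diff below is 1, and t is reshaped to [S, 1, 3] before the
# broadcast, so the stacking dimension stays in front.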
1858 rank_diff = array_ops.shape(new_shape)[0] - array_ops.rank(t) 1859 ones = array_ops.tile([1], array_ops.reshape(rank_diff, [1])) 1860 t_shape = array_ops.shape(t) 1861 t_expanded_shape = array_ops.concat([t_shape[:1], ones, t_shape[1:]], axis=0) 1862 1863 return wrap( 1864 array_ops.broadcast_to(array_ops.reshape(t, t_expanded_shape), new_shape), 1865 True) 1866 1867 1868@RegisterPFor("ExpandDims") 1869def _convert_expanddims(pfor_input): 1870 t = pfor_input.stacked_input(0) 1871 dim = pfor_input.unstacked_input(1) 1872 dim += math_ops.cast(dim >= 0, dtypes.int32) 1873 return wrap(array_ops.expand_dims(t, axis=dim), True) 1874 1875 1876@RegisterPForWithArgs("LowerBound", gen_array_ops.lower_bound) 1877@RegisterPForWithArgs("UpperBound", gen_array_ops.upper_bound) 1878def _convert_searchsorted(pfor_input, _, op_func): 1879 pfor_input.stack_inputs() 1880 sorted_inputs = _flatten_first_two_dims(pfor_input.stacked_input(0)) 1881 values = _flatten_first_two_dims(pfor_input.stacked_input(1)) 1882 out_type = pfor_input.get_attr("out_type") 1883 output = op_func(sorted_inputs, values, out_type) 1884 return wrap( 1885 _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector), True) 1886 1887 1888@RegisterPFor("MatrixBandPart") 1889def _convert_matrix_band_part(pfor_input): 1890 t = pfor_input.stacked_input(0) 1891 num_lower = pfor_input.unstacked_input(1) 1892 num_upper = pfor_input.unstacked_input(2) 1893 return wrap( 1894 array_ops.matrix_band_part(t, num_lower=num_lower, num_upper=num_upper), 1895 True) 1896 1897 1898@RegisterPFor("MatrixSetDiag") 1899def _convert_matrix_set_diag(pfor_input): 1900 pfor_input.stack_inputs() 1901 t = pfor_input.stacked_input(0) 1902 diag = pfor_input.stacked_input(1) 1903 return wrap(array_ops.matrix_set_diag(t, diag), True) 1904 1905 1906# Registrations for Matrix{Diag,DiagPart,SetDiag}V2-3. 1907# The input orders defined in the OpKernel and the actual python API are 1908# different (for compatibility with V1), so we cannot use _convert_identity. 1909# v2 is not compatible with v3 and is never exposed on the public API. 
1910@RegisterPFor("MatrixDiagV2") 1911@RegisterPFor("MatrixDiagV3") 1912def _convert_matrix_diag_v2(pfor_input): 1913 params = { 1914 "diagonal": pfor_input.stacked_input(0), 1915 "k": pfor_input.unstacked_input(1), 1916 "num_rows": pfor_input.unstacked_input(2), 1917 "num_cols": pfor_input.unstacked_input(3), 1918 "padding_value": pfor_input.unstacked_input(4) 1919 } 1920 if pfor_input.op_type == "MatrixDiagV2": 1921 return wrap(array_ops.matrix_diag_v2(**params), True) 1922 params["align"] = pfor_input.get_attr("align") 1923 return wrap(array_ops.matrix_diag(**params), True) 1924 1925 1926# See notes for MatrixDiagV2 1927@RegisterPFor("MatrixDiagPartV2") 1928@RegisterPFor("MatrixDiagPartV3") 1929def _convert_matrix_diag_part_v2(pfor_input): 1930 params = { 1931 "input": pfor_input.stacked_input(0), 1932 "k": pfor_input.unstacked_input(1), 1933 "padding_value": pfor_input.unstacked_input(2) 1934 } 1935 if pfor_input.op_type == "MatrixDiagPartV2": 1936 return wrap(array_ops.matrix_diag_part_v2(**params), True) 1937 params["align"] = pfor_input.get_attr("align") 1938 return wrap(array_ops.matrix_diag_part(**params), True) 1939 1940 1941# See notes for MatrixDiagV2 1942@RegisterPFor("MatrixSetDiagV2") 1943@RegisterPFor("MatrixSetDiagV3") 1944def _convert_matrix_set_diag_v2(pfor_input): 1945 pfor_input.stack_inputs([0, 1]) 1946 params = { 1947 "input": pfor_input.stacked_input(0), 1948 "diagonal": pfor_input.stacked_input(1), 1949 "k": pfor_input.unstacked_input(2) 1950 } 1951 if pfor_input.op_type == "MatrixSetDiagV2": 1952 return wrap(array_ops.matrix_set_diag_v2(**params), True) 1953 params["align"] = pfor_input.get_attr("align") 1954 return wrap(array_ops.matrix_set_diag(**params), True) 1955 1956 1957@RegisterPFor("OneHot") 1958def _convert_one_hot(pfor_input): 1959 indices = pfor_input.stacked_input(0) 1960 depth = pfor_input.unstacked_input(1) 1961 on_value = pfor_input.unstacked_input(2) 1962 off_value = pfor_input.unstacked_input(3) 1963 axis = pfor_input.get_attr("axis") 1964 if axis >= 0: 1965 axis += 1 1966 return wrap( 1967 array_ops.one_hot(indices, depth, on_value, off_value, axis), True) 1968 1969 1970@RegisterPFor("Slice") 1971def _convert_slice(pfor_input): 1972 t = pfor_input.stacked_input(0) 1973 begin = pfor_input.unstacked_input(1) 1974 size = pfor_input.unstacked_input(2) 1975 begin = array_ops.concat([[0], begin], axis=0) 1976 size = array_ops.concat([[-1], size], axis=0) 1977 return wrap(array_ops.slice(t, begin, size), True) 1978 1979 1980@RegisterPFor("Tile") 1981def _convert_tile(pfor_input): 1982 t = pfor_input.stacked_input(0) 1983 multiples = pfor_input.unstacked_input(1) 1984 multiples = array_ops.concat([[1], multiples], 0) 1985 return wrap(array_ops.tile(t, multiples), True) 1986 1987 1988@RegisterPFor("Pack") 1989def _convert_pack(pfor_input): 1990 pfor_input.stack_inputs() 1991 axis = pfor_input.get_attr("axis") 1992 if axis >= 0: 1993 axis += 1 1994 return wrap( 1995 array_ops.stack([x.t for x in pfor_input.inputs], axis=axis), True) 1996 1997 1998@RegisterPFor("Unpack") 1999def _convert_unpack(pfor_input): 2000 value = pfor_input.stacked_input(0) 2001 axis = pfor_input.get_attr("axis") 2002 if axis >= 0: 2003 axis += 1 2004 num = pfor_input.get_attr("num") 2005 return [wrap(x, True) for x in array_ops.unstack(value, axis=axis, num=num)] 2006 2007 2008@RegisterPFor("Pad") 2009def _convert_pad(pfor_input): 2010 t = pfor_input.stacked_input(0) 2011 paddings = pfor_input.unstacked_input(1) 2012 paddings = array_ops.concat([[[0, 0]], paddings], 0) 2013 return 
wrap(array_ops.pad(t, paddings, mode="CONSTANT"), True) 2014 2015 2016@RegisterPFor("Split") 2017def _convert_split(pfor_input): 2018 split_dim = pfor_input.unstacked_input(0) 2019 t = pfor_input.stacked_input(1) 2020 num_split = pfor_input.get_attr("num_split") 2021 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2022 return [wrap(x, True) for x in array_ops.split(t, num_split, axis=split_dim)] 2023 2024 2025@RegisterPFor("SplitV") 2026def _convert_split_v(pfor_input): 2027 t = pfor_input.stacked_input(0) 2028 splits = pfor_input.unstacked_input(1) 2029 split_dim = pfor_input.unstacked_input(2) 2030 split_dim += math_ops.cast(split_dim >= 0, dtypes.int32) 2031 return [wrap(x, True) for x in array_ops.split(t, splits, axis=split_dim)] 2032 2033 2034@RegisterPFor("Squeeze") 2035def _convert_squeeze(pfor_input): 2036 t = pfor_input.stacked_input(0) 2037 squeeze_dims = pfor_input.get_attr("squeeze_dims") 2038 squeeze_dims = [i + 1 if i >= 0 else i for i in squeeze_dims] 2039 return wrap(array_ops.squeeze(t, axis=squeeze_dims), True) 2040 2041 2042@RegisterPFor("Transpose") 2043def _convert_transpose(pfor_input): 2044 t = pfor_input.stacked_input(0) 2045 perm = pfor_input.unstacked_input(1) 2046 new_perm = array_ops.concat([[0], perm + 1], axis=0) 2047 return wrap(array_ops.transpose(t, new_perm), True) 2048 2049 2050@RegisterPFor("ZerosLike") 2051def _convert_zeroslike(pfor_input): 2052 t = pfor_input.stacked_input(0) 2053 shape = array_ops.shape(t)[1:] 2054 return wrap(array_ops.zeros(shape, dtype=t.dtype), False) 2055 2056 2057@RegisterPFor("Gather") 2058@RegisterPFor("GatherV2") 2059def _convert_gather(pfor_input): 2060 param, param_stacked, _ = pfor_input.input(0) 2061 indices, indices_stacked, _ = pfor_input.input(1) 2062 op_type = pfor_input.op_type 2063 if op_type == "Gather": 2064 validate_indices = pfor_input.get_attr("validate_indices") 2065 axis = 0 2066 else: 2067 validate_indices = None 2068 # Assume we will never have a Tensor with rank > 2**32. 2069 axis = math_ops.cast(pfor_input.unstacked_input(2), dtypes.int32) 2070 axis_value = tensor_util.constant_value(axis) 2071 if axis_value is not None: 2072 axis = axis_value 2073 if indices_stacked and not param_stacked: 2074 if indices is pfor_input.pfor.all_indices and axis == 0: 2075 param_shape0 = tensor_shape.dimension_value(param.shape[0]) 2076 indices_shape0 = tensor_shape.dimension_value(indices.shape[0]) 2077 if param_shape0 is not None and indices_shape0 == param_shape0: 2078 # Note that with loops and conditionals, indices may not be contiguous. 2079 # However they will be sorted and unique. So if the shape matches, then 2080 # it must be picking up all the rows of param. 2081 return wrap(param, True) 2082 # TODO(agarwal): use array_ops.slice here. 
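# For example (illustrative shapes): with an unstacked param of shape [4, 5],
# stacked indices of shape [S] and axis=1, the gather below produces shape
# [4, S]. The loop dimension then sits at position `axis`, so the transpose
# that follows moves it to the front, giving a stacked result of shape [S, 4].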
2083 output = array_ops.gather( 2084 param, indices, validate_indices=validate_indices, axis=axis) 2085 if axis != 0: 2086 axis = control_flow_ops.cond(axis < 0, 2087 lambda: axis + array_ops.rank(param), 2088 lambda: axis) 2089 order = array_ops.concat( 2090 [[axis], 2091 math_ops.range(axis), 2092 math_ops.range(axis + 1, array_ops.rank(output))], 2093 axis=0) 2094 output = control_flow_ops.cond( 2095 math_ops.equal(axis, 0), lambda: output, 2096 lambda: array_ops.transpose(output, order)) 2097 return wrap(output, True) 2098 if param_stacked: 2099 loop_len_vector = pfor_input.pfor.loop_len_vector 2100 pfor_input.stack_inputs(stack_indices=[1]) 2101 indices = pfor_input.stacked_input(1) 2102 param_flat = _flatten_first_two_dims(param) 2103 2104 # Recompute indices to handle stacked param. 2105 indices_offset = (math_ops.range(math_ops.cast(loop_len_vector[0], 2106 dtype=indices.dtype)) * 2107 math_ops.cast(array_ops.shape(param)[1], indices.dtype)) 2108 # Reshape indices_offset to allow broadcast addition 2109 ones = array_ops.ones([array_ops.rank(indices) - 1], dtype=dtypes.int32) 2110 new_shape = array_ops.concat([loop_len_vector, ones], axis=0) 2111 indices_offset = array_ops.reshape(indices_offset, new_shape) 2112 indices += indices_offset 2113 2114 # TODO(agarwal): handle axis != 0. May need to transpose param or 2115 # array_ops.gather_nd. 2116 if isinstance(axis, ops.Tensor): 2117 axis_value = tensor_util.constant_value(axis) 2118 else: 2119 try: 2120 axis_value = int(axis) 2121 except TypeError: 2122 axis_value = None 2123 msg = ("Gather, where indices and param are both loop dependent, currently " 2124 "requires axis=0") 2125 if axis_value is not None and axis_value != 0: 2126 raise ValueError("Error while converting %s. %s. Got axis=%d" % 2127 (pfor_input.op, msg, axis)) 2128 with ops.control_dependencies( 2129 [check_ops.assert_equal(axis, 0, message=msg)]): 2130 output = array_ops.gather(param_flat, indices) 2131 return wrap(output, True) 2132 2133 2134@RegisterPFor("GatherNd") 2135def _convert_gather_nd(pfor_input): 2136 # TODO(jmenick): Add support for unstacked params. 
2137 pfor_input.stack_inputs(stack_indices=[1]) 2138 params = pfor_input.stacked_input(0) 2139 indices = pfor_input.stacked_input(1) 2140 stacked_result = array_ops.gather_nd(params, indices, batch_dims=1) 2141 return wrap(stacked_result, True) 2142 2143 2144@RegisterPFor("ConcatV2") 2145def _convert_concatv2(pfor_input): 2146 n = pfor_input.num_inputs 2147 pfor_input.stack_inputs(stack_indices=range(n - 1)) 2148 axis = pfor_input.unstacked_input(n - 1) 2149 axis += math_ops.cast(axis >= 0, axis.dtype) 2150 return wrap( 2151 array_ops.concat([x.t for x in pfor_input.inputs[:n - 1]], axis=axis), 2152 True) 2153 2154 2155@RegisterPFor("StridedSlice") 2156def _convert_strided_slice(pfor_input): 2157 inp = pfor_input.stacked_input(0) 2158 begin = pfor_input.unstacked_input(1) 2159 end = pfor_input.unstacked_input(2) 2160 strides = pfor_input.unstacked_input(3) 2161 begin_mask = pfor_input.get_attr("begin_mask") 2162 end_mask = pfor_input.get_attr("end_mask") 2163 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 2164 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2165 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2166 2167 begin = array_ops.concat([[0], begin], axis=0) 2168 end = array_ops.concat([[0], end], axis=0) 2169 strides = array_ops.concat([[1], strides], axis=0) 2170 begin_mask = begin_mask << 1 | 1 2171 end_mask = end_mask << 1 | 1 2172 ellipsis_mask <<= 1 2173 new_axis_mask <<= 1 2174 shrink_axis_mask <<= 1 2175 return wrap( 2176 array_ops.strided_slice( 2177 inp, 2178 begin, 2179 end, 2180 strides, 2181 begin_mask=begin_mask, 2182 end_mask=end_mask, 2183 ellipsis_mask=ellipsis_mask, 2184 new_axis_mask=new_axis_mask, 2185 shrink_axis_mask=shrink_axis_mask), True) 2186 2187 2188@RegisterPFor("StridedSliceGrad") 2189def _convert_strided_slice_grad(pfor_input): 2190 shape = pfor_input.unstacked_input(0) 2191 begin = pfor_input.unstacked_input(1) 2192 end = pfor_input.unstacked_input(2) 2193 strides = pfor_input.unstacked_input(3) 2194 dy = pfor_input.stacked_input(4) 2195 begin_mask = pfor_input.get_attr("begin_mask") 2196 end_mask = pfor_input.get_attr("end_mask") 2197 ellipsis_mask = pfor_input.get_attr("ellipsis_mask") 2198 new_axis_mask = pfor_input.get_attr("new_axis_mask") 2199 shrink_axis_mask = pfor_input.get_attr("shrink_axis_mask") 2200 2201 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2202 begin = array_ops.concat([[0], begin], axis=0) 2203 end = array_ops.concat([[0], end], axis=0) 2204 strides = array_ops.concat([[1], strides], axis=0) 2205 begin_mask = begin_mask << 1 | 1 2206 end_mask = end_mask << 1 | 1 2207 ellipsis_mask <<= 1 2208 new_axis_mask <<= 1 2209 shrink_axis_mask <<= 1 2210 return wrap( 2211 array_ops.strided_slice_grad( 2212 shape, 2213 begin, 2214 end, 2215 strides, 2216 dy, 2217 begin_mask=begin_mask, 2218 end_mask=end_mask, 2219 ellipsis_mask=ellipsis_mask, 2220 new_axis_mask=new_axis_mask, 2221 shrink_axis_mask=shrink_axis_mask), True) 2222 2223 2224# math_ops 2225 2226 2227@RegisterPFor("MatMul") 2228def _convert_matmul(pfor_input): 2229 # TODO(agarwal): Check if tiling is faster than two transposes. 
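# Shape sketch (illustrative): if only `a` is stacked, with shape [S, m, k],
# and `b` is loop invariant with shape [k, n], the branch below reshapes `a`
# to [S * m, k], performs a single MatMul giving [S * m, n], and reshapes the
# result back to [S, m, n]. The b-stacked branch mirrors this by merging the
# loop and trailing dimensions of `b` instead.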
2230 a, a_stacked, _ = pfor_input.input(0) 2231 b, b_stacked, _ = pfor_input.input(1) 2232 tr_a = pfor_input.get_attr("transpose_a") 2233 tr_b = pfor_input.get_attr("transpose_b") 2234 if a_stacked and b_stacked: 2235 output = wrap(math_ops.matmul(a, b, adjoint_a=tr_a, adjoint_b=tr_b), True) 2236 return output 2237 elif a_stacked: 2238 if tr_a: 2239 a = array_ops.transpose(a, [0, 2, 1]) 2240 if a.shape.is_fully_defined(): 2241 x, y, z = a.shape 2242 else: 2243 x, y, z = [ 2244 array_ops.reshape(i, []) 2245 for i in array_ops.split(array_ops.shape(a), 3) 2246 ] 2247 a = array_ops.reshape(a, [x * y, z]) 2248 prod = math_ops.matmul(a, b, transpose_b=tr_b) 2249 return wrap(array_ops.reshape(prod, [x, y, -1]), True) 2250 else: 2251 assert b_stacked 2252 if tr_b: 2253 perm = [2, 0, 1] 2254 b = array_ops.transpose(b, perm) 2255 else: 2256 # As an optimization, if one of the first two dimensions is 1, then we can 2257 # reshape instead of transpose. 2258 # TODO(agarwal): This check can be done inside Transpose kernel. 2259 b_shape = array_ops.shape(b) 2260 min_dim = math_ops.minimum(b_shape[0], b_shape[1]) 2261 perm = control_flow_ops.cond( 2262 math_ops.equal(min_dim, 1), lambda: [0, 1, 2], lambda: [1, 0, 2]) 2263 new_shape = array_ops.stack([b_shape[1], b_shape[0], b_shape[2]]) 2264 b = array_ops.transpose(b, perm) 2265 b = array_ops.reshape(b, new_shape) 2266 2267 if b.shape.is_fully_defined(): 2268 x, y, z = b.shape 2269 else: 2270 x, y, z = [ 2271 array_ops.reshape(i, []) 2272 for i in array_ops.split(array_ops.shape(b), 3) 2273 ] 2274 b = array_ops.reshape(b, [x, y * z]) 2275 prod = math_ops.matmul(a, b, transpose_a=tr_a) 2276 prod = array_ops.reshape(prod, [-1, y, z]) 2277 prod = array_ops.transpose(prod, [1, 0, 2]) 2278 return wrap(prod, True) 2279 2280 2281# TODO(rmlarsen): Use the converter of BatchMatMulV2 once compatibility window 2282# is met. 2283@RegisterPFor("BatchMatMul") 2284def _convert_batch_mat_mul(pfor_input): 2285 # TODO(agarwal): There may be a more efficient way to do this instead of 2286 # stacking the inputs. 2287 pfor_input.stack_inputs() 2288 x = pfor_input.stacked_input(0) 2289 y = pfor_input.stacked_input(1) 2290 adj_x = pfor_input.get_attr("adj_x") 2291 adj_y = pfor_input.get_attr("adj_y") 2292 2293 x = _flatten_first_two_dims(x) 2294 y = _flatten_first_two_dims(y) 2295 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2296 output = _unflatten_first_dim(output, pfor_input.pfor.loop_len_vector) 2297 return wrap(output, True) 2298 2299 2300@RegisterPFor("BatchMatMulV2") 2301def _convert_batch_mat_mul_v2(pfor_input): 2302 pfor_input.expanddim_inputs_for_broadcast() 2303 x = pfor_input.input(0)[0] 2304 y = pfor_input.input(1)[0] 2305 adj_x = pfor_input.get_attr("adj_x") 2306 adj_y = pfor_input.get_attr("adj_y") 2307 2308 output = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y) 2309 return wrap(output, True) 2310 2311 2312@RegisterPForWithArgs("Sum", math_ops.reduce_sum) 2313@RegisterPForWithArgs("Prod", math_ops.reduce_prod) 2314@RegisterPForWithArgs("Max", math_ops.reduce_max) 2315@RegisterPForWithArgs("Min", math_ops.reduce_min) 2316@RegisterPForWithArgs("Mean", math_ops.reduce_mean) 2317@RegisterPForWithArgs("All", math_ops.reduce_all) 2318@RegisterPForWithArgs("Any", math_ops.reduce_any) 2319def _convert_reduction(pfor_input, _, op_func): 2320 t = pfor_input.stacked_input(0) 2321 indices = pfor_input.unstacked_input(1) 2322 # Shift positive indices by one to account for the extra dimension. 
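# For example, a per-iteration reduction over axes [0, -1] of a [3, 4] input
# becomes a reduction over axes [1, -1] of the stacked [S, 3, 4] input:
# non-negative axes shift by one, while negative axes already refer to the
# same trailing dimensions.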
2323 indices += math_ops.cast(indices >= 0, dtypes.int32) 2324 keep_dims = pfor_input.get_attr("keep_dims") 2325 return wrap(op_func(t, indices, keepdims=keep_dims), True) 2326 2327 2328@RegisterPForWithArgs("Cumsum", math_ops.cumsum) 2329@RegisterPForWithArgs("Cumprod", math_ops.cumprod) 2330def _convert_cumfoo(pfor_input, _, op_func): 2331 t = pfor_input.stacked_input(0) 2332 axis = pfor_input.unstacked_input(1) 2333 # Shift positive indices by one to account for the extra dimension. 2334 axis += math_ops.cast(axis >= 0, dtypes.int32) 2335 exclusive = pfor_input.get_attr("exclusive") 2336 reverse = pfor_input.get_attr("reverse") 2337 return wrap(op_func(t, axis, exclusive=exclusive, reverse=reverse), True) 2338 2339 2340@RegisterPFor("BiasAdd") 2341def _convert_biasadd(pfor_input): 2342 t, t_stacked, _ = pfor_input.input(0) 2343 bias, bias_stacked, _ = pfor_input.input(1) 2344 data_format = pfor_input.get_attr("data_format").decode() 2345 if bias_stacked: 2346 # BiasAdd only supports 1-D biases, so cast bias to match value and use Add. 2347 pfor_input.expanddim_inputs_for_broadcast() 2348 t, _, _ = pfor_input.input(0) 2349 bias = math_ops.cast(pfor_input.stacked_input(1), t.dtype) 2350 if compat.as_bytes(data_format) == b"NCHW": 2351 b_shape = array_ops.shape(bias) 2352 new_b_shape = array_ops.concat( 2353 [b_shape[:-3], b_shape[-1:], b_shape[-3:-1]], axis=0) 2354 bias = array_ops.reshape(bias, new_b_shape) 2355 return wrap(math_ops.add(t, bias), True) 2356 else: 2357 assert t_stacked, "At least one input to BiasAdd should be loop variant." 2358 if compat.as_bytes(data_format) == b"NCHW": 2359 shape = array_ops.shape(t) 2360 flattened_shape = array_ops.concat([[-1], shape[2:]], axis=0) 2361 t = array_ops.reshape(t, flattened_shape) 2362 t = nn_ops.bias_add(t, bias, data_format="NCHW") 2363 t = array_ops.reshape(t, shape) 2364 return wrap(t, True) 2365 return wrap(nn_ops.bias_add(t, bias, data_format=data_format), True) 2366 2367 2368@RegisterPFor("UnsortedSegmentSum") 2369def _convert_unsortedsegmentsum(pfor_input): 2370 pfor_input.stack_inputs([0, 1]) 2371 data = pfor_input.stacked_input(0) 2372 segment_ids = pfor_input.stacked_input(1) 2373 # TODO(agarwal): handle stacked? 2374 num_segments = pfor_input.unstacked_input(2) 2375 if segment_ids.dtype != num_segments.dtype: 2376 segment_ids = math_ops.cast(segment_ids, dtypes.int64) 2377 num_segments = math_ops.cast(num_segments, dtypes.int64) 2378 dtype = segment_ids.dtype 2379 segment_shape = array_ops.shape(segment_ids, out_type=dtype) 2380 n = segment_shape[0] 2381 ones = array_ops.ones_like(segment_shape, dtype=dtype)[1:] 2382 segment_offset = num_segments * math_ops.range(n, dtype=dtype) 2383 segment_offset = array_ops.reshape(segment_offset, 2384 array_ops.concat([[n], ones], axis=0)) 2385 segment_ids += segment_offset 2386 num_segments = math_ops.cast(num_segments, dtypes.int64) * math_ops.cast( 2387 n, dtypes.int64) 2388 output = math_ops.unsorted_segment_sum(data, segment_ids, num_segments) 2389 new_output_shape = array_ops.concat( 2390 [[n, -1], array_ops.shape(output)[1:]], axis=0) 2391 output = array_ops.reshape(output, new_output_shape) 2392 return wrap(output, True) 2393 2394 2395def _flatten_array_with_offset(ids, offset_delta, num_rows): 2396 """Flattens a rank 2 tensor, adding an offset to each row.""" 2397 # Note that if `ids` is rank 1, it is broadcast to rank 2. 
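# Worked example (illustrative values): with ids=[[0, 1], [0, 2]],
# offset_delta=3 and num_rows=2, the offsets are [[0], [3]], ids becomes
# [[0, 1], [3, 5]], and the flattened result is [0, 1, 3, 5]. Each row's ids
# are thus mapped into a disjoint range of size offset_delta.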
2398 offset_delta = math_ops.cast(offset_delta, ids.dtype) 2399 n = math_ops.cast(num_rows, dtype=ids.dtype) 2400 offsets = math_ops.range( 2401 start=0, limit=n * offset_delta, delta=offset_delta, dtype=ids.dtype) 2402 offsets = array_ops.expand_dims(offsets, -1) 2403 ids += offsets 2404 return array_ops.reshape(ids, [-1]) 2405 2406 2407@RegisterPForWithArgs("SparseSegmentSum", math_ops.sparse_segment_sum_v2) 2408@RegisterPForWithArgs("SparseSegmentMean", math_ops.sparse_segment_mean_v2) 2409@RegisterPForWithArgs("SparseSegmentSqrtN", math_ops.sparse_segment_sqrt_n_v2) 2410@RegisterPForWithArgs("SparseSegmentSumWithNumSegments", 2411 math_ops.sparse_segment_sum_v2) 2412@RegisterPForWithArgs("SparseSegmentMeanWithNumSegments", 2413 math_ops.sparse_segment_mean_v2) 2414@RegisterPForWithArgs("SparseSegmentSqrtNWithNumSegments", 2415 math_ops.sparse_segment_sqrt_n_v2) 2416def _convert_sparse_segment(pfor_input, _, op_func): 2417 _, segment_ids_stacked, _ = pfor_input.input(2) 2418 if segment_ids_stacked: 2419 pfor_input.stack_inputs([1]) 2420 data, data_stacked, _ = pfor_input.input(0) 2421 indices, _, _ = pfor_input.input(1) 2422 num_inputs = len(pfor_input.inputs) 2423 assert num_inputs in (3, 4) 2424 if num_inputs == 3: 2425 # `segment_ids` needs to be unstacked since otherwise output sizes could 2426 # differ across pfor iterations. 2427 segment_ids = pfor_input.unstacked_input(2) 2428 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2429 else: 2430 segment_ids, _, _ = pfor_input.input(2) 2431 num_segments = pfor_input.unstacked_input(3) 2432 2433 n = pfor_input.pfor.loop_len_vector[0] 2434 if data_stacked: 2435 indices = _flatten_array_with_offset(indices, array_ops.shape(data)[1], n) 2436 data = _flatten_first_two_dims(data) 2437 else: 2438 indices = array_ops.reshape(indices, [-1]) 2439 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2440 2441 if num_inputs == 3: 2442 num_segments = None 2443 else: 2444 num_segments *= n 2445 output = op_func(data, indices, segment_ids, num_segments=num_segments) 2446 output = _unflatten_first_dim(output, [n]) 2447 return wrap(output, True) 2448 2449 2450@RegisterPForWithArgs("SparseSegmentMeanGrad", 2451 math_ops.sparse_segment_mean_grad) 2452@RegisterPForWithArgs("SparseSegmentSqrtNGrad", 2453 math_ops.sparse_segment_sqrt_n_grad) 2454def _convert_sparse_segment_grad(pfor_input, _, op_func): 2455 grad = pfor_input.stacked_input(0) 2456 indices = pfor_input.unstacked_input(1) 2457 segment_ids = pfor_input.unstacked_input(2) 2458 dim0 = pfor_input.unstacked_input(3) 2459 2460 n = pfor_input.pfor.loop_len_vector[0] 2461 indices = _flatten_array_with_offset(indices, dim0, n) 2462 num_segments = nn_ops.relu(math_ops.reduce_max(segment_ids) + 1) 2463 segment_ids = _flatten_array_with_offset(segment_ids, num_segments, n) 2464 grad = _flatten_first_two_dims(grad) 2465 dim0 *= n 2466 output = op_func(grad, indices, segment_ids, dim0) 2467 output = _unflatten_first_dim(output, [n]) 2468 return wrap(output, True) 2469 2470 2471@RegisterPFor("Cast") 2472def _convert_cast(pfor_input): 2473 inp = pfor_input.stacked_input(0) 2474 dtype = pfor_input.get_attr("DstT") 2475 return wrap(math_ops.cast(inp, dtype), True) 2476 2477 2478@RegisterPForWithArgs("Abs", math_ops.abs) 2479@RegisterPForWithArgs("Acos", math_ops.acos) 2480@RegisterPForWithArgs("Acosh", math_ops.acosh) 2481@RegisterPForWithArgs("Add", math_ops.add) 2482@RegisterPForWithArgs("AddV2", math_ops.add_v2) 2483@RegisterPForWithArgs("Angle", math_ops.angle) 
2484@RegisterPForWithArgs("Asin", math_ops.asin) 2485@RegisterPForWithArgs("Asinh", math_ops.asinh) 2486@RegisterPForWithArgs("Atan", math_ops.atan) 2487@RegisterPForWithArgs("Atan2", math_ops.atan2) 2488@RegisterPForWithArgs("Atanh", math_ops.atanh) 2489@RegisterPForWithArgs("BesselI0e", math_ops.bessel_i0e) 2490@RegisterPForWithArgs("BesselI1e", math_ops.bessel_i1e) 2491@RegisterPForWithArgs("BitwiseAnd", bitwise_ops.bitwise_and) 2492@RegisterPForWithArgs("BitwiseOr", bitwise_ops.bitwise_or) 2493@RegisterPForWithArgs("BitwiseXor", bitwise_ops.bitwise_xor) 2494@RegisterPForWithArgs("Ceil", math_ops.ceil) 2495@RegisterPForWithArgs("Complex", math_ops.complex) 2496@RegisterPForWithArgs("ComplexAbs", math_ops.complex_abs) 2497@RegisterPForWithArgs("Conj", math_ops.conj) 2498@RegisterPForWithArgs("Cos", math_ops.cos) 2499@RegisterPForWithArgs("Cosh", math_ops.cosh) 2500@RegisterPForWithArgs("Dawsn", special_math_ops.dawsn) 2501@RegisterPForWithArgs("Digamma", math_ops.digamma) 2502@RegisterPForWithArgs("Div", math_ops.div) 2503@RegisterPForWithArgs("DivNoNan", math_ops.div_no_nan) 2504@RegisterPForWithArgs("Elu", nn_ops.elu) 2505@RegisterPForWithArgs("Erf", math_ops.erf) 2506@RegisterPForWithArgs("Erfc", math_ops.erfc) 2507@RegisterPForWithArgs("Erfinv", math_ops.erfinv) 2508@RegisterPForWithArgs("Exp", math_ops.exp) 2509@RegisterPForWithArgs("Expint", special_math_ops.expint) 2510@RegisterPForWithArgs("Expm1", math_ops.expm1) 2511@RegisterPForWithArgs("Floor", math_ops.floor) 2512@RegisterPForWithArgs("FloorDiv", math_ops.floor_div) 2513@RegisterPForWithArgs("FloorMod", math_ops.floor_mod) 2514@RegisterPForWithArgs("FresnelCos", special_math_ops.fresnel_cos) 2515@RegisterPForWithArgs("FresnelSin", special_math_ops.fresnel_sin) 2516@RegisterPForWithArgs("Greater", math_ops.greater) 2517@RegisterPForWithArgs("GreaterEqual", math_ops.greater_equal) 2518@RegisterPForWithArgs("Igamma", math_ops.igamma) 2519@RegisterPForWithArgs("IgammaGradA", math_ops.igamma_grad_a) 2520@RegisterPForWithArgs("Igammac", math_ops.igammac) 2521@RegisterPForWithArgs("Imag", math_ops.imag) 2522@RegisterPForWithArgs("Inv", math_ops.inv) 2523@RegisterPForWithArgs("Invert", bitwise_ops.invert) 2524@RegisterPForWithArgs("IsFinite", math_ops.is_finite) 2525@RegisterPForWithArgs("IsInf", math_ops.is_inf) 2526@RegisterPForWithArgs("IsNan", math_ops.is_nan) 2527@RegisterPForWithArgs("LeftShift", bitwise_ops.left_shift) 2528@RegisterPForWithArgs("Less", math_ops.less) 2529@RegisterPForWithArgs("LessEqual", math_ops.less_equal) 2530@RegisterPForWithArgs("Lgamma", math_ops.lgamma) 2531@RegisterPForWithArgs("Log", math_ops.log) 2532@RegisterPForWithArgs("Log1p", math_ops.log1p) 2533@RegisterPForWithArgs("LogicalAnd", math_ops.logical_and) 2534@RegisterPForWithArgs("LogicalNot", math_ops.logical_not) 2535@RegisterPForWithArgs("LogicalOr", math_ops.logical_or) 2536@RegisterPForWithArgs("LogicalXor", math_ops.logical_xor) 2537@RegisterPForWithArgs("Maximum", math_ops.maximum) 2538@RegisterPForWithArgs("Minimum", math_ops.minimum) 2539@RegisterPForWithArgs("Mod", math_ops.mod) 2540@RegisterPForWithArgs("Mul", math_ops.multiply) 2541@RegisterPForWithArgs("MulNoNan", math_ops.mul_no_nan) 2542@RegisterPForWithArgs("Ndtri", math_ops.ndtri) 2543@RegisterPForWithArgs("Neg", math_ops.negative) 2544@RegisterPForWithArgs("Polygamma", math_ops.polygamma) 2545@RegisterPForWithArgs("Pow", math_ops.pow) 2546@RegisterPForWithArgs("Real", math_ops.real) 2547@RegisterPForWithArgs("RealDiv", math_ops.divide) 2548@RegisterPForWithArgs("Reciprocal", 
math_ops.reciprocal) 2549@RegisterPForWithArgs("Relu", nn_ops.relu) 2550@RegisterPForWithArgs("Relu6", nn_ops.relu6) 2551@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift) 2552@RegisterPForWithArgs("Rint", math_ops.rint) 2553@RegisterPForWithArgs("Round", math_ops.round) 2554@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt) 2555@RegisterPForWithArgs("Selu", nn_ops.selu) 2556@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid) 2557@RegisterPForWithArgs("Sign", math_ops.sign) 2558@RegisterPForWithArgs("Sin", math_ops.sin) 2559@RegisterPForWithArgs("Sinh", math_ops.sinh) 2560@RegisterPForWithArgs("Softplus", nn_ops.softplus) 2561@RegisterPForWithArgs("Softsign", nn_ops.softsign) 2562@RegisterPForWithArgs("Spence", special_math_ops.spence) 2563@RegisterPForWithArgs("Sqrt", math_ops.sqrt) 2564@RegisterPForWithArgs("Square", math_ops.square) 2565@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference) 2566@RegisterPForWithArgs("Sub", math_ops.subtract) 2567@RegisterPForWithArgs("Tan", math_ops.tan) 2568@RegisterPForWithArgs("Tanh", math_ops.tanh) 2569@RegisterPForWithArgs("TruncateDiv", math_ops.truncate_div) 2570@RegisterPForWithArgs("TruncateMod", math_ops.truncate_mod) 2571@RegisterPForWithArgs("Xdivy", math_ops.xdivy) 2572@RegisterPForWithArgs("Xlogy", math_ops.xlogy) 2573@RegisterPForWithArgs("Xlog1py", math_ops.xlog1py) 2574@RegisterPForWithArgs("Zeta", math_ops.zeta) 2575def _convert_cwise(pfor_input, op_type, op_func): 2576 # Note that ops handled here do not have attributes except those listed below 2577 # and hence don't need extra arguments passed to the cwise_op call below. 2578 for attr in pfor_input.op.node_def.attr.keys(): 2579 assert attr in [u"T", u"Tout", u"_xla_compile_id"], (op_type, attr) 2580 pfor_input.expanddim_inputs_for_broadcast() 2581 return wrap(op_func(*[x.t for x in pfor_input.inputs]), True) 2582 2583 2584@RegisterPFor("Equal") 2585def _convert_equal(pfor_input): 2586 pfor_input.expanddim_inputs_for_broadcast() 2587 x = pfor_input.input(0)[0] 2588 y = pfor_input.input(1)[0] 2589 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 2590 assert incompatible_shape_error 2591 return wrap(math_ops.equal(x, y), True) 2592 2593 2594@RegisterPFor("NotEqual") 2595def _convert_not_equal(pfor_input): 2596 pfor_input.expanddim_inputs_for_broadcast() 2597 x = pfor_input.input(0)[0] 2598 y = pfor_input.input(1)[0] 2599 incompatible_shape_error = pfor_input.get_attr("incompatible_shape_error") 2600 assert incompatible_shape_error 2601 return wrap(math_ops.not_equal(x, y), True) 2602 2603 2604@RegisterPFor("ApproximateEqual") 2605def _convert_approximate_equal(pfor_input): 2606 pfor_input.expanddim_inputs_for_broadcast() 2607 x = pfor_input.input(0)[0] 2608 y = pfor_input.input(1)[0] 2609 tolerance = pfor_input.get_attr("tolerance") 2610 return wrap(math_ops.approximate_equal(x, y, tolerance=tolerance), True) 2611 2612 2613@RegisterPFor("Shape") 2614def _convert_shape(pfor_input): 2615 out_type = pfor_input.get_attr("out_type") 2616 return wrap( 2617 array_ops.shape(pfor_input.stacked_input(0), out_type=out_type)[1:], 2618 False) 2619 2620 2621@RegisterPFor("ShapeN") 2622def _convert_shape_n(pfor_input): 2623 out_type = pfor_input.get_attr("out_type") 2624 shapes = [ 2625 array_ops.shape(x, out_type=out_type)[1:] if stacked else array_ops.shape( 2626 x, out_type=out_type) for x, stacked, _ in pfor_input.inputs 2627 ] 2628 return [wrap(x, False) for x in shapes] 2629 2630 2631@RegisterPFor("Size") 2632def _convert_size(pfor_input): 2633 
out_type = pfor_input.get_attr("out_type") 2634 n = math_ops.cast(pfor_input.pfor.loop_len_vector[0], out_type) 2635 return wrap( 2636 array_ops.size(pfor_input.stacked_input(0), out_type=out_type) // n, 2637 False) 2638 2639 2640@RegisterPFor("Rank") 2641def _convert_rank(pfor_input): 2642 return wrap(array_ops.rank(pfor_input.stacked_input(0)) - 1, False) 2643 2644 2645@RegisterPFor("AddN") 2646def _convert_addn(pfor_input): 2647 # AddN does not support broadcasting. 2648 pfor_input.stack_inputs() 2649 return wrap(math_ops.add_n([x.t for x in pfor_input.inputs]), True) 2650 2651 2652@RegisterPFor("Cross") 2653def _convert_cross(pfor_input): 2654 pfor_input.stack_inputs() 2655 a = pfor_input.stacked_input(0) 2656 b = pfor_input.stacked_input(1) 2657 return wrap(math_ops.cross(a, b), True) 2658 2659 2660@RegisterPFor("BiasAddGrad") 2661def _convert_biasaddgrad(pfor_input): 2662 grad = pfor_input.stacked_input(0) 2663 fmt = pfor_input.get_attr("data_format") 2664 if fmt == b"NCHW": 2665 output = math_ops.reduce_sum(grad, axis=[1, 3, 4], keepdims=False) 2666 else: 2667 grad_shape = array_ops.shape(grad) 2668 last_dim_shape = grad_shape[-1] 2669 first_dim_shape = grad_shape[0] 2670 output = array_ops.reshape(grad, [first_dim_shape, -1, last_dim_shape]) 2671 output = math_ops.reduce_sum(output, axis=[1], keepdims=False) 2672 return wrap(output, True) 2673 2674 2675# Some required ops are not exposed under the tf namespace. Hence relying on 2676# _create_op to create them. 2677@RegisterPForWithArgs("EluGrad") 2678@RegisterPForWithArgs("Relu6Grad") 2679@RegisterPForWithArgs("ReluGrad") 2680@RegisterPForWithArgs("SeluGrad") 2681@RegisterPForWithArgs("SigmoidGrad") 2682@RegisterPForWithArgs("SoftplusGrad") 2683@RegisterPForWithArgs("SoftsignGrad") 2684@RegisterPForWithArgs("TanhGrad") 2685@RegisterPForWithArgs("SqrtGrad") 2686@RegisterPForWithArgs("RsqrtGrad") 2687@RegisterPForWithArgs("ReciprocalGrad") 2688def _convert_grads(pfor_input, op_type, *args, **kw_args): 2689 del args 2690 del kw_args 2691 # TODO(agarwal): Looks like these ops don't support broadcasting. Hence we 2692 # have to use tiling here. 
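# stack_inputs() below tiles any loop-invariant input along a new leading
# dimension whose size is the number of pfor iterations (e.g. an unstacked
# input of shape [N] becomes [S, N]), so all inputs to the recreated gradient
# op have matching leading dimensions.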
2693 pfor_input.stack_inputs() 2694 outputs = _create_op( 2695 op_type, [x.t for x in pfor_input.inputs], 2696 [x.dtype for x in pfor_input.outputs], 2697 attrs=pfor_input.op.node_def.attr).outputs 2698 return [wrap(x, True) for x in outputs] 2699 2700 2701@RegisterPFor("Select") 2702def _convert_select(pfor_input): 2703 pfor_input.stack_inputs() 2704 cond = pfor_input.stacked_input(0) 2705 t = pfor_input.stacked_input(1) 2706 e = pfor_input.stacked_input(2) 2707 cond_rank = array_ops.rank(cond) 2708 cond, t, e = control_flow_ops.cond( 2709 cond_rank > 1, lambda: _inputs_with_flattening(pfor_input, [0, 1, 2]), 2710 lambda: [cond, t, e]) 2711 outputs = _create_op( 2712 pfor_input.op_type, [cond, t, e], [x.dtype for x in pfor_input.outputs], 2713 attrs=pfor_input.op.node_def.attr).outputs 2714 n = pfor_input.pfor.loop_len_vector 2715 out = control_flow_ops.cond(cond_rank > 1, 2716 lambda: _unflatten_first_dim(outputs[0], n), 2717 lambda: outputs[0]) 2718 return [wrap(out, True) for x in outputs] 2719 2720 2721@RegisterPFor("SelectV2") 2722def _convert_selectv2(pfor_input): 2723 pfor_input.expanddim_inputs_for_broadcast() 2724 cond = pfor_input.input(0)[0] 2725 t = pfor_input.input(1)[0] 2726 e = pfor_input.input(2)[0] 2727 out = array_ops.where_v2(cond, t, e) 2728 return wrap(out, True) 2729 2730 2731# random_ops 2732 2733 2734def _transpose_dim_to_front(x, dim): 2735 rank = array_ops.rank(x) 2736 return array_ops.transpose( 2737 x, 2738 perm=array_ops.concat( 2739 [[dim], math_ops.range(0, dim), 2740 math_ops.range(dim + 1, rank)], 2741 axis=0)) 2742 2743 2744@RegisterPForWithArgs("RandomUniform") 2745@RegisterPForWithArgs("RandomUniformInt") 2746@RegisterPForWithArgs("RandomStandardNormal") 2747@RegisterPForWithArgs("TruncatedNormal") 2748def _convert_random(pfor_input, op_type, *args, **kw_args): 2749 del args 2750 del kw_args 2751 inputs = [pfor_input.unstacked_input(i) for i in range(pfor_input.num_inputs)] 2752 # inputs[0] is "shape" 2753 inputs[0] = array_ops.concat([pfor_input.pfor.loop_len_vector, inputs[0]], 2754 axis=0) 2755 logging.warning( 2756 "Note that %s inside pfor op may not give same output as " 2757 "inside a sequential loop.", op_type) 2758 outputs = _create_op( 2759 op_type, 2760 inputs, [x.dtype for x in pfor_input.outputs], 2761 attrs=pfor_input.op.node_def.attr).outputs 2762 return [wrap(x, True) for x in outputs] 2763 2764 2765@RegisterPFor("RandomGamma") 2766@RegisterPFor("RandomPoissonV2") 2767def _convert_random_with_param(pfor_input): 2768 shape = pfor_input.unstacked_input(0) 2769 # param is lam (Poisson rate) or alpha (Gamma shape). 
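# Shape sketch (illustrative): with shape=[2] and a stacked param of shape
# [S, 3], the sampling op returns samples of shape [2, S, 3]; the loop
# dimension then sits at position len(shape), and _transpose_dim_to_front
# moves it back to the front to give a stacked [S, 2, 3] result. If param is
# loop invariant, S is instead prepended to `shape` before sampling.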
2770 param, param_stacked, _ = pfor_input.input(1) 2771 logging.warning( 2772 "Note that %s inside pfor op may not give same output as " 2773 "inside a sequential loop.", pfor_input.op_type) 2774 2775 if param_stacked: 2776 samples = _create_op( 2777 pfor_input.op_type, 2778 inputs=[shape, param], 2779 op_dtypes=[x.dtype for x in pfor_input.outputs], 2780 attrs=pfor_input.op.node_def.attr).outputs[0] 2781 loop_dim = array_ops.shape(shape)[0] 2782 stacked_samples = _transpose_dim_to_front(samples, loop_dim) 2783 else: 2784 shape = array_ops.concat([pfor_input.pfor.loop_len_vector, shape], axis=0) 2785 stacked_samples = _create_op( 2786 pfor_input.op_type, 2787 inputs=[shape, param], 2788 op_dtypes=[x.dtype for x in pfor_input.outputs], 2789 attrs=pfor_input.op.node_def.attr).outputs[0] 2790 2791 return wrap(stacked_samples, True) 2792 2793 2794@RegisterPFor("Multinomial") 2795def _convert_multinomial(pfor_input): 2796 logits, logits_stacked, _ = pfor_input.input(0) 2797 num_samples = pfor_input.unstacked_input(1) 2798 seed = pfor_input.get_attr("seed") 2799 seed2 = pfor_input.get_attr("seed2") 2800 output_dtype = pfor_input.get_attr("output_dtype") 2801 logging.warning( 2802 "Note that Multinomial inside pfor op may not give same output as " 2803 "inside a sequential loop.") 2804 2805 n = pfor_input.pfor.loop_len_vector[0] 2806 if logits_stacked: 2807 flattened_logits = _flatten_first_two_dims(logits) 2808 samples = gen_random_ops.multinomial( 2809 flattened_logits, 2810 num_samples, 2811 seed=seed, 2812 seed2=seed2, 2813 output_dtype=output_dtype) 2814 stacked_samples = _unflatten_first_dim(samples, [n]) 2815 else: 2816 samples = gen_random_ops.multinomial( 2817 logits, 2818 num_samples * n, 2819 seed=seed, 2820 seed2=seed2, 2821 output_dtype=output_dtype) 2822 stacked_samples = array_ops.transpose( 2823 array_ops.reshape(samples, [-1, n, num_samples]), [1, 0, 2]) 2824 2825 return wrap(stacked_samples, True) 2826 2827 2828# linalg_ops 2829 2830 2831# TODO(jmenick) - the same logic applies to other einsums. Generalize this 2832# in a future CL. 2833@RegisterPFor("XlaEinsum") 2834def _convert_einsum(pfor_input): 2835 first_input, first_input_stacked, _ = pfor_input.input(0) 2836 second_input, second_input_stacked, _ = pfor_input.input(1) 2837 2838 # Parse the einsum equation. 
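# For example (hypothetical equation): if the equation is "ij,jk->ik" and
# only the first operand is stacked, the rewrite below picks an unused
# symbol, say "a", and builds "aij,jk->aik", so the stacking dimension is
# carried through the einsum unchanged.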
2839 equation = pfor_input.get_attr("equation").decode("utf-8") 2840 input_expr, output_expr = equation.split("->") 2841 input_a_expr, input_b_expr = input_expr.split(",") 2842 2843 # pick a placeholder symbol to use for the new axis 2844 chosen_symbol = None 2845 for s in string.ascii_letters: 2846 if s in equation: 2847 continue 2848 else: 2849 chosen_symbol = s 2850 break 2851 2852 if chosen_symbol is None: 2853 raise ValueError("Could not figure out what symbol to use for new axis.") 2854 2855 assert first_input_stacked or second_input_stacked 2856 if first_input_stacked: 2857 input_a_expr = "{}{}".format(chosen_symbol, input_a_expr) 2858 if second_input_stacked: 2859 input_b_expr = "{}{}".format(chosen_symbol, input_b_expr) 2860 output_expr = "{}{}".format(chosen_symbol, output_expr) 2861 2862 new_equation = "{},{}->{}".format(input_a_expr, input_b_expr, output_expr) 2863 result = xla.einsum(equation=new_equation, a=first_input, b=second_input) 2864 return wrap(result, True) 2865 2866 2867@RegisterPFor("Cholesky") 2868def _convert_cholesky(pfor_input): 2869 t = pfor_input.stacked_input(0) 2870 return wrap(linalg_ops.cholesky(t), True) 2871 2872 2873@RegisterPFor("LogMatrixDeterminant") 2874def _convert_log_matrix_determinant(pfor_input): 2875 t = pfor_input.stacked_input(0) 2876 return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)] 2877 2878 2879@RegisterPFor("MatrixTriangularSolve") 2880def _convert_matrix_triangular_solve(pfor_input): 2881 pfor_input.expanddim_inputs_for_broadcast() 2882 matrix = pfor_input.input(0)[0] 2883 rhs = pfor_input.input(1)[0] 2884 lower = pfor_input.get_attr("lower") 2885 adjoint = pfor_input.get_attr("adjoint") 2886 output = linalg_ops.matrix_triangular_solve( 2887 matrix, rhs, lower=lower, adjoint=adjoint) 2888 return wrap(output, True) 2889 2890 2891@RegisterPFor("SelfAdjointEigV2") 2892def _convert_self_adjoint_eig(pfor_input): 2893 t = pfor_input.stacked_input(0) 2894 compute_v = pfor_input.get_attr("compute_v") 2895 e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v) 2896 # If compute_v is False, v will have shape [0]. 2897 return wrap(e, True), wrap(v, compute_v) 2898 2899 2900# logging_ops 2901 2902 2903@RegisterPFor("Assert") 2904def _convert_assert(pfor_input): 2905 cond, cond_stacked, _ = pfor_input.input(0) 2906 if cond_stacked: 2907 cond = math_ops.reduce_all(cond) 2908 2909 data_list = [x.t for x in pfor_input.inputs][1:] 2910 return _create_op( 2911 "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr) 2912 2913 2914@RegisterPFor("Print") 2915def _convert_print(pfor_input): 2916 # Note that we don't stack all the inputs. Hence unstacked values are printed 2917 # once here vs multiple times in a while_loop. 2918 pfor_input.stack_inputs([0]) 2919 outputs = _create_op( 2920 "Print", [x.t for x in pfor_input.inputs], 2921 [x.dtype for x in pfor_input.outputs], 2922 attrs=pfor_input.op.node_def.attr).outputs 2923 return [wrap(x, True) for x in outputs] 2924 2925 2926# data_flow_ops 2927 2928# TensorArray conversion is tricky since we don't support arrays of 2929# TensorArrays. For converting them, we consider two distinct cases: 2930# 2931# 1. The array is constructed outside the pfor call, and read/written inside the 2932# loop. 2933# This is an easier case since we don't need to make an array of TensorArrays. 2934# A correctness requirement is that these parallel iterations shouldn't attempt 2935# to write to the same location. 
@RegisterPFor("Cholesky")
def _convert_cholesky(pfor_input):
  t = pfor_input.stacked_input(0)
  return wrap(linalg_ops.cholesky(t), True)


@RegisterPFor("LogMatrixDeterminant")
def _convert_log_matrix_determinant(pfor_input):
  t = pfor_input.stacked_input(0)
  return [wrap(x, True) for x in linalg_ops.log_matrix_determinant(t)]


@RegisterPFor("MatrixTriangularSolve")
def _convert_matrix_triangular_solve(pfor_input):
  pfor_input.expanddim_inputs_for_broadcast()
  matrix = pfor_input.input(0)[0]
  rhs = pfor_input.input(1)[0]
  lower = pfor_input.get_attr("lower")
  adjoint = pfor_input.get_attr("adjoint")
  output = linalg_ops.matrix_triangular_solve(
      matrix, rhs, lower=lower, adjoint=adjoint)
  return wrap(output, True)


@RegisterPFor("SelfAdjointEigV2")
def _convert_self_adjoint_eig(pfor_input):
  t = pfor_input.stacked_input(0)
  compute_v = pfor_input.get_attr("compute_v")
  e, v = gen_linalg_ops.self_adjoint_eig_v2(t, compute_v=compute_v)
  # If compute_v is False, v will have shape [0].
  return wrap(e, True), wrap(v, compute_v)


# logging_ops


@RegisterPFor("Assert")
def _convert_assert(pfor_input):
  cond, cond_stacked, _ = pfor_input.input(0)
  if cond_stacked:
    cond = math_ops.reduce_all(cond)

  data_list = [x.t for x in pfor_input.inputs][1:]
  return _create_op(
      "Assert", [cond] + data_list, [], attrs=pfor_input.op.node_def.attr)


@RegisterPFor("Print")
def _convert_print(pfor_input):
  # Note that we don't stack all the inputs. Hence unstacked values are printed
  # once here vs multiple times in a while_loop.
  pfor_input.stack_inputs([0])
  outputs = _create_op(
      "Print", [x.t for x in pfor_input.inputs],
      [x.dtype for x in pfor_input.outputs],
      attrs=pfor_input.op.node_def.attr).outputs
  return [wrap(x, True) for x in outputs]


# data_flow_ops

# TensorArray conversion is tricky since we don't support arrays of
# TensorArrays. For converting them, we consider two distinct cases:
#
# 1. The array is constructed outside the pfor call, and read/written inside
# the loop.
# This is an easier case since we don't need to make an array of TensorArrays.
# A correctness requirement is that these parallel iterations shouldn't attempt
# to write to the same location. Hence at conversion time we disallow indices
# to be loop-invariant as that would guarantee a collision. Even if the indices
# are not loop-invariant, they could still conflict, which would trigger
# runtime errors.
#
# 2. The array is constructed and used entirely inside each pfor iteration.
# For simplicity, here we require that the indices used for write/scatter are
# "unstacked". Otherwise it becomes hard to merge the TensorArrays created in
# different pfor iterations. We consider two sub-cases:
#
# 2a Elements written to the array are "stacked".
# To simulate multiple TensorArrays, we may increase the dimension of each
# element of the array, i.e. the i_th row of the j_th entry of the converted
# TensorArray corresponds to the j_th entry of the TensorArray in the i_th
# pfor iteration.
#
# 2b Elements written to the array are "unstacked".
# In this case we don't increase the dimensions to avoid redundant tiling. Each
# iteration is trying to write the same value. So we convert that to a single
# write.
#
# Here are some tricks used to implement the above:
# - The TensorArrayV3 constructor encodes the element shape as an attr. Instead
# of trying to trace whether future writes are stacked or unstacked in order to
# set this attr, we set it to correspond to an unknown shape.
# - We use the "flow" output of the different ops to track whether the array
# elements are stacked or unstacked. If a stacked write/scatter is done, we
# make the flow stacked as well.
# - We use some heuristic traversal of the graph to track whether the
# TensorArray handle was created inside or outside the pfor loop.
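
# Illustrative note (not part of the converter logic): for case 2a above, if a
# pfor loop with n = 3 iterations builds a TensorArray whose writes are
# "stacked" vectors of shape [5], the converted graph keeps a single
# TensorArray whose entries have shape [3, 5]; row i of entry j holds the value
# that iteration i wrote at index j.
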
@RegisterPFor("TensorArrayV3")
def _convert_tensor_array_v3(pfor_input):
  size = pfor_input.unstacked_input(0)
  dtype = pfor_input.get_attr("dtype")
  dynamic_size = pfor_input.get_attr("dynamic_size")
  clear_after_read = pfor_input.get_attr("clear_after_read")
  identical_element_shapes = pfor_input.get_attr("identical_element_shapes")
  tensor_array_name = pfor_input.get_attr("tensor_array_name")
  handle, flow = data_flow_ops.tensor_array_v3(
      size,
      dtype=dtype,
      # We don't set element shape since we don't know if writes are stacked or
      # not yet.
      element_shape=None,
      dynamic_size=dynamic_size,
      clear_after_read=clear_after_read,
      identical_element_shapes=identical_element_shapes,
      tensor_array_name=tensor_array_name)
  # Note we keep flow unstacked for now since we don't know if writes will be
  # stacked or not.
  return wrap(handle, False), wrap(flow, False)


@RegisterPFor("TensorArraySizeV3")
def _convert_tensor_array_size_v3(pfor_input):
  handle = pfor_input.unstacked_input(0)
  flow, flow_stacked, _ = pfor_input.input(1)
  if flow_stacked:
    flow = _unstack_flow(flow)
  size = data_flow_ops.tensor_array_size_v3(handle, flow)
  return wrap(size, False)


def _handle_inside_pfor(pfor_input, handle):
  """Returns True if handle was created inside the pfor loop."""
  # We use some heuristic to find the original TensorArray creation op.
  # The logic should handle the common cases (except cond based subgraphs).
  # In theory the user could perform different operations on the handle (like
  # Reshape, stack multiple handles, etc) which could break this logic.
  # TODO(agarwal): handle Switch/Merge.
  while handle.op.type in ("Enter", "Identity"):
    handle = handle.op.inputs[0]
  if handle.op.type not in [
      "TensorArrayV3", "TensorArrayGradV3", "TensorArrayGradWithShape"
  ]:
    raise ValueError("Unable to find source for handle %s" % handle)
  else:
    return pfor_input.pfor.op_is_inside_loop(handle.op)


def _unstack_flow(value):
  # TODO(agarwal): consider looking if this is a Tile op then get its input.
  # This may avoid running the Tile operations.
  return array_ops.gather(value, 0)


@RegisterPFor("TensorArrayReadV3")
def _convert_tensor_array_read_v3(pfor_input):
  handle = pfor_input.unstacked_input(0)
  index, index_stacked, _ = pfor_input.input(1)
  dtype = pfor_input.get_attr("dtype")
  flow, flow_stacked, _ = pfor_input.input(2)
  if flow_stacked:
    flow = _unstack_flow(flow)

  is_inside_pfor = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
  if is_inside_pfor:
    # Note that if we are inside a control flow construct inside the pfor, and
    # only some of the iterations are doing the read (i.e.
    # `all_indices_partitioned` is True), then the read operation should only
    # return values for the currently active pfor iterations (`all_indices`
    # below). Hence, whenever the returned value is stacked (i.e. `flow` is
    # stacked), we may need to do an extra gather after reading the values.
    # Also note that if `is_inside_pfor` is False, then values in the tensor
    # array are unstacked, so the check is only needed in this branch.
    all_indices = pfor_input.pfor.all_indices
    all_indices_partitioned = pfor_input.pfor.all_indices_partitioned
    # Note: flow_stacked indicates if values in the TensorArray are stacked or
    # not.
    if index_stacked:
      if flow_stacked:
        raise ValueError(
            "It looks like TensorArrayReadV3 was called on a TensorArray whose"
            " values are not loop-invariant, and the read indices were also"
            " not loop invariant. This is currently unsupported.")
      value = data_flow_ops.tensor_array_gather_v3(
          handle, index, flow, dtype=dtype)
      return wrap(value, True)
    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
    if flow_stacked and all_indices_partitioned:
      value = array_ops.gather(value, all_indices)
    return wrap(value, flow_stacked)
  # Values in the TensorArray should be unstacked (since different iterations
  # couldn't write to the same location). So whether output is stacked or not
  # depends on index_stacked.
  if index_stacked:
    value = data_flow_ops.tensor_array_gather_v3(
        handle, index, flow, dtype=dtype)
  else:
    value = data_flow_ops.tensor_array_read_v3(handle, index, flow, dtype=dtype)
  return wrap(value, index_stacked)


@RegisterPFor("TensorArrayWriteV3")
def _convert_tensor_array_write_v3(pfor_input):
  handle = pfor_input.unstacked_input(0)
  index, index_stacked, _ = pfor_input.input(1)
  value, value_stacked, _ = pfor_input.input(2)
  flow, flow_stacked, _ = pfor_input.input(3)
  if value_stacked and pfor_input.pfor.all_indices_partitioned:
    # Looks like we are in a control flow in a pfor where not all iterations
    # are active now. We don't allow that since that could lead to different
    # indices having different shapes, which will be hard to merge later.
    raise ValueError("Writing non loop invariant values to TensorArray from "
                     "inside a while_loop/cond is not supported.")
  if flow_stacked:
    flow = _unstack_flow(flow)
  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
  if is_inside:
    if index_stacked:
      raise ValueError("Need indices for %s to be loop invariant" % handle)
    if not flow_stacked and not value_stacked:
      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
      return wrap(flow_out, False)
    else:
      if not value_stacked:
        value = _stack(value, pfor_input.pfor.loop_len_vector).t
      # TODO(agarwal): Note that if flow is unstacked and value is stacked, then
      # this may or may not be a safe situation. flow is unstacked both for a
      # freshly created TensorArray, as well as after unstacked values are
      # written to it. If it is the latter, then we cannot write a stacked value
      # now since that may cause runtime errors due to different shapes in the
      # array. At the moment we are not able to handle this gracefully and
      # distinguish between the two cases. That would require some heuristic
      # traversal of the graph to figure out whether all the writes are
      # unstacked or not.
      flow_out = data_flow_ops.tensor_array_write_v3(handle, index, value, flow)
      return _stack(flow_out, pfor_input.pfor.loop_len_vector)
  else:
    if not index_stacked:
      raise ValueError("Need indices for %s to be not loop invariant" % handle)
    # Note that even when index_stacked is true, actual values in index may
    # still not be unique. However, that will cause a runtime error when
    # executing the scatter operation below.
    if not value_stacked:
      value = _stack(value, pfor_input.pfor.loop_len_vector).t
    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, index, value, flow)
    return _stack(flow_out, pfor_input.pfor.loop_len_vector)


def _transpose_first_two_dims(value):
  # TODO(agarwal): optimize if one of the dims == 1.
  value_shape = array_ops.shape(value)
  v0 = value_shape[0]
  v1 = value_shape[1]
  value = array_ops.reshape(value, [v0, v1, -1])
  value = array_ops.transpose(value, [1, 0, 2])
  new_shape = array_ops.concat([[v1, v0], value_shape[2:]], axis=0)
  return array_ops.reshape(value, new_shape)


@RegisterPFor("TensorArrayGatherV3")
def _convert_tensor_array_gather_v3(pfor_input):
  handle = pfor_input.unstacked_input(0)
  indices, indices_stacked, _ = pfor_input.input(1)
  indices = array_ops.reshape(indices, [-1])
  flow, flow_stacked, _ = pfor_input.input(2)
  if flow_stacked:
    flow = _unstack_flow(flow)
  dtype = pfor_input.get_attr("dtype")
  # TODO(agarwal): support element_shape attr?

  n = pfor_input.pfor.loop_len_vector
  value = data_flow_ops.tensor_array_gather_v3(
      handle, indices, flow, dtype=dtype)
  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
  if is_inside:
    # flow_stacked indicates if values in the TensorArray are stacked or not.
    if indices_stacked:
      if flow_stacked:
        raise ValueError(
            "It looks like TensorArrayGatherV3 was called on a TensorArray "
            "whose values are not loop-invariant, and the indices were also "
            "not loop invariant. This is currently unsupported.")
      else:
        value = _unflatten_first_dim(value, n)
      return wrap(value, True)
    else:
      if flow_stacked:
        # Since elements in this array are stacked and `value` was produced by
        # gather, its first two dims are "gathered elements" and "stack
        # dimension". Our semantics require these two to be flipped.
        value = _transpose_first_two_dims(value)
      return wrap(value, flow_stacked)
  else:
    # Values in the TensorArray should be unstacked (since different iterations
    # couldn't write to the same location). So whether output is stacked or not
    # depends on indices_stacked.
    if indices_stacked:
      value = _unflatten_first_dim(value, n)
    return wrap(value, indices_stacked)


@RegisterPFor("TensorArrayScatterV3")
def _convert_tensor_array_scatter_v3(pfor_input):
  handle = pfor_input.unstacked_input(0)
  indices, indices_stacked, _ = pfor_input.input(1)
  indices = array_ops.reshape(indices, [-1])
  value, value_stacked, _ = pfor_input.input(2)
  flow, flow_stacked, _ = pfor_input.input(3)

  if flow_stacked:
    flow = _unstack_flow(flow)

  is_inside = _handle_inside_pfor(pfor_input, pfor_input.op.inputs[0])
  if is_inside:
    if indices_stacked:
      raise ValueError("Need indices for %s to be loop invariant" % handle)
    # Note that flow_stacked indicates if existing values in the array are
    # stacked or not.
    if not flow_stacked and not value_stacked:
      flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
                                                       flow)
      return wrap(flow_out, False)
    if not value_stacked:
      # TODO(agarwal): tile in the second dimension directly instead of
      # transposing below.
      value = _stack(value, pfor_input.pfor.loop_len_vector).t

    value = _transpose_first_two_dims(value)
    # TODO(agarwal): Note that if a previous write was unstacked, flow will be
    # unstacked, and a stacked value may be written here, which may cause a
    # runtime error due to different elements having different shapes. We do
    # not try to prevent that.
    flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value,
                                                     flow)
    return _stack(flow_out, pfor_input.pfor.loop_len_vector)
  if not indices_stacked:
    raise ValueError("Need indices for %s to be not loop invariant" % handle)
  if not value_stacked:
    value = _stack(value, pfor_input.pfor.loop_len_vector).t
  value = _flatten_first_two_dims(value)
  flow_out = data_flow_ops.tensor_array_scatter_v3(handle, indices, value, flow)
  return _stack(flow_out, pfor_input.pfor.loop_len_vector)


@RegisterPFor("TensorArrayGradV3")
def _convert_tensor_array_grad_v3(pfor_input):
  handle = pfor_input.unstacked_input(0)
  flow, flow_stacked, _ = pfor_input.input(1)
  if flow_stacked:
    flow = _unstack_flow(flow)
  source = pfor_input.get_attr("source")
  # TODO(agarwal): For now, we assume that gradients are stacked if the
  # TensorArrayGradV3 call is being done inside the pfor. Getting that wrong
  # will give a runtime error due to an incorrect shape being written to the
  # accumulator. It is difficult to know in advance if gradients written will
  # be stacked or not. Note that flow being stacked is not indicative of the
  # gradient being stacked or not. Revisit this later.
  shape_to_prepend = pfor_input.pfor.loop_len_vector
  grad_handle, flow_out = data_flow_ops.tensor_array_grad_with_shape(
      handle=handle,
      flow_in=flow,
      shape_to_prepend=shape_to_prepend,
      source=source)
  flow_out = _stack(flow_out, pfor_input.pfor.loop_len_vector).t
  return [wrap(grad_handle, False), wrap(flow_out, True)]


# StackV2 conversion is tricky since we don't have arrays of StackV2. So,
# similar to TensorArrays, we convert them by changing the dimension of the
# elements inside the stack.
#
# We consider two cases:
#
# 1. StackV2 is constructed and used entirely inside the pfor loop.
# We keep a single Stack and perform the push/pop operations of all the
# iterations in lock-step. We also assume that all the iterations perform these
# operations. In case of dynamic control flow, if only some of the iterations
# try to perform a push/pop, then the conversion may not work correctly and may
# cause undefined behavior.
# TODO(agarwal): test StackV2 with dynamic control flow.
#
# 2. StackV2 is constructed outside the pfor loop.
# Performing stack push/pop in a parallel fashion is ill-defined. However,
# given that reading stacks created externally is a common operation when
# computing jacobians, we provide some special semantics here as follows:
# - disallow push operations to the stack.
# - pop operations are performed in lock-step by all iterations, similar to the
# case when the stack is created inside. A single value is popped during the
# lock-step operation and broadcast to all the iterations. Values in the stack
# are assumed to be loop-invariant.
#
# Some other implementation details:
# We use some ugly logic to find whether values in the Stack data structure are
# loop invariant or not. When converting push/pop operations, we keep track of
# whether the last conversion used a stacked value or not (see _stack_cache
# below). As a result, if an unstacked value is written first, subsequent
# stacked writes are disallowed when they could have been allowed in theory.
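
# Illustrative note (not part of the converter logic): under case 1 above, if a
# pfor loop with n = 4 iterations pushes one scalar per iteration, the
# converted graph pushes a single length-4 vector (one element per iteration)
# onto one shared stack, and the matching pop returns that vector, which is
# then treated as a stacked value.
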
Got %s" % 3287 handle.op) 3288 return pfor_input.pfor.op_is_inside_loop(handle.op) 3289 3290 3291@RegisterPFor("StackPushV2") 3292def _convert_stack_push_v2(pfor_input): 3293 handle = pfor_input.unstacked_input(0) 3294 elem, elem_stacked, _ = pfor_input.input(1) 3295 swap_memory = pfor_input.get_attr("swap_memory") 3296 3297 if not _stack_handle_inside_pfor(pfor_input.op.inputs[0], pfor_input): 3298 raise ValueError("StackPushV2 not allowed on stacks created outside pfor") 3299 stack_cache_key = _stack_cache_key(pfor_input) 3300 stacked = _stack_cache.get(stack_cache_key, None) 3301 if stacked is None: 3302 stacked = elem_stacked 3303 _stack_cache[stack_cache_key] = stacked 3304 else: 3305 # If we previously made it unstacked then we can't revert to being stacked. 3306 if not stacked and elem_stacked: 3307 raise ValueError( 3308 "It looks like the stack was previously determined to be loop" 3309 " invariant, but we are now trying to push a loop dependent value" 3310 " to it. This is currently unsupported.") 3311 if stacked and not elem_stacked: 3312 elem = _stack(elem, pfor_input.pfor.loop_len_vector).t 3313 out = data_flow_ops.stack_push_v2(handle, elem, swap_memory=swap_memory) 3314 return wrap(out, stacked) 3315 3316 3317# Note that inputs to this convertor will be unstacked. However it should get 3318# called since it is a stateful op. 3319@RegisterPFor("StackPopV2") 3320def _convert_stack_pop_v2(pfor_input): 3321 handle = pfor_input.unstacked_input(0) 3322 stack_cache_key = _stack_cache_key(pfor_input) 3323 stacked = _stack_cache.get(stack_cache_key, None) 3324 # If a StackPushV2 has not been converted yet, we default to unstacked since 3325 # the push could be outside of pfor, or the covertor may not be called if the 3326 # inputs are unconverted. 
  if stacked is None:
    stacked = False
    _stack_cache[stack_cache_key] = False
  elem_type = pfor_input.get_attr("elem_type")
  out = data_flow_ops.stack_pop_v2(handle, elem_type)
  return wrap(out, stacked)


# parsing_ops


@RegisterPFor("DecodeCSV")
def _convert_decode_csv(pfor_input):
  lines = pfor_input.stacked_input(0)
  record_defaults = [
      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
  ]
  field_delim = pfor_input.get_attr("field_delim")
  use_quote_delim = pfor_input.get_attr("use_quote_delim")
  select_cols = pfor_input.get_attr("select_cols")
  if not select_cols:
    select_cols = None
  return [
      wrap(t, True) for t in parsing_ops.decode_csv(
          lines,
          record_defaults,
          field_delim=field_delim,
          use_quote_delim=use_quote_delim,
          select_cols=select_cols)
  ]


@RegisterPFor("ParseSingleExample")
def _convert_parse_single_example(pfor_input):
  serialized = pfor_input.stacked_input(0)
  dense_defaults = [
      pfor_input.unstacked_input(i) for i in range(1, pfor_input.num_inputs)
  ]
  sparse_keys = pfor_input.get_attr("sparse_keys")
  dense_keys = pfor_input.get_attr("dense_keys")
  sparse_types = pfor_input.get_attr("sparse_types")
  dense_shapes = pfor_input.get_attr("dense_shapes")
  output = gen_parsing_ops.parse_example(
      serialized=serialized,
      names=[],
      dense_defaults=dense_defaults,
      sparse_keys=sparse_keys,
      dense_keys=dense_keys,
      sparse_types=sparse_types,
      dense_shapes=dense_shapes)
  return [wrap(t, True, True) for t in nest.flatten(output)]


@RegisterPFor("ParseExampleV2")
def _convert_parse_example_v2(pfor_input):
  serialized = pfor_input.stacked_input(0)
  sparse_keys = pfor_input.unstacked_input(2)
  dense_keys = pfor_input.unstacked_input(3)
  ragged_keys = pfor_input.unstacked_input(4)
  dense_defaults = [
      pfor_input.unstacked_input(i) for i in range(5, pfor_input.num_inputs)
  ]
  num_sparse = pfor_input.get_attr("num_sparse")
  sparse_types = pfor_input.get_attr("sparse_types")
  ragged_value_types = pfor_input.get_attr("ragged_value_types")
  ragged_split_types = pfor_input.get_attr("ragged_split_types")
  dense_shapes = pfor_input.get_attr("dense_shapes")
  if serialized.shape.ndims not in (None, 1):
    raise ValueError("ParseExampleV2 can only be converted if `serialized` "
                     "is scalar.")
  output = gen_parsing_ops.parse_example_v2(
      serialized=serialized,
      names=[],
      sparse_keys=sparse_keys,
      dense_keys=dense_keys,
      ragged_keys=ragged_keys,
      dense_defaults=dense_defaults,
      num_sparse=num_sparse,
      sparse_types=sparse_types,
      ragged_value_types=ragged_value_types,
      ragged_split_types=ragged_split_types,
      dense_shapes=dense_shapes)
  return [wrap(t, True, True) for t in nest.flatten(output)]
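
# Illustrative note (not part of the converter logic): the parsing converters
# above rely on the batched kernels already accepting a vector of serialized
# protos. E.g. a loop body that calls ParseSingleExample on one serialized
# tf.Example per iteration is converted into a single ParseExample call over
# the stacked [loop_len] vector of serialized protos, and each parsed output is
# wrapped as stacked (and as sparse-stacked for the sparse outputs).
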
Got func %s" % (func_name, func)) 3422 pfor = pfor_input.pfor 3423 converter = PFor( 3424 loop_var=pfor.loop_var, 3425 loop_len=pfor.loop_len_vector[0], 3426 pfor_ops=func.graph.get_operations(), 3427 all_indices=pfor.all_indices, 3428 all_indices_partitioned=pfor.all_indices_partitioned, 3429 pfor_config=pfor.pfor_config) 3430 3431 # TODO(agarwal): consider caching this function definition. 3432 @def_function.function 3433 def f(*args): 3434 assert all(isinstance(arg, WrappedTensor) for arg in args), args 3435 assert len(args) == len(func.graph.inputs), (args, func.graph.inputs) 3436 # Map inputs to function arguments. 3437 for inp, arg in zip(func.graph.inputs, args): 3438 converter._add_conversion(inp, arg) 3439 # Convert output tensors. 3440 return tuple( 3441 [converter._convert_helper(x).t for x in func._func_graph_outputs]) 3442 3443 call_outputs = f(*pfor_input.inputs) 3444 assert len(call_outputs) == len(func._func_graph_outputs) 3445 outputs = [] 3446 for call_output, output_tensor in zip(call_outputs, func._func_graph_outputs): 3447 func_output = converter._convert_helper(output_tensor) 3448 outputs.append( 3449 wrap(call_output, func_output.is_stacked, 3450 func_output.is_sparse_stacked)) 3451 return outputs 3452