1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================= 15"""xla is an experimental library that provides XLA support APIs.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22 23from six.moves import xrange # pylint: disable=redefined-builtin 24 25from tensorflow.compiler.jit.ops import xla_ops 26from tensorflow.compiler.jit.ops import xla_ops_grad # pylint: disable=unused-import 27from tensorflow.core.framework import attr_value_pb2 28from tensorflow.python.distribute import summary_op_util 29from tensorflow.python.eager import context 30from tensorflow.python.eager import def_function 31from tensorflow.python.framework import ops 32from tensorflow.python.ops import array_ops 33from tensorflow.python.ops import control_flow_ops 34from tensorflow.python.ops import variable_scope 35from tensorflow.python.platform import tf_logging as logging 36from tensorflow.python.util import compat 37from tensorflow.python.util import nest 38from tensorflow.python.util import tf_inspect 39from tensorflow.python.util.compat import collections_abc 40from tensorflow.python.util.deprecation import deprecated 41from tensorflow.python.util.tf_export import tf_export 42 43_XLA_COMPILE_ATTR = '_xla_compile_id' 44_MAX_WARNING_LINES = 5 45 46# Operations that indicate some error in the users graph. For example, XLA 47# computation should not have any Placeholder op. 48_DENYLISTED_OPS = set([ 49 'Placeholder', 50]) 51 52# XLA doesn't currently support reading of intermediate tensors, thus some ops 53# are not supported. 54_UNSUPPORTED_OPS = set([ 55 'AudioSummary', 56 'AudioSummaryV2', 57 'HistogramSummary', 58 'ImageSummary', 59 'MergeSummary', 60 'Print', 61 'ScalarSummary', 62 'TensorSummary', 63 'TensorSummaryV2', 64]) 65 66 67@tf_export('xla.experimental.compile') 68@deprecated( 69 None, 'xla.experimental.compile is deprecated. Consider using ' 70 'tf.function(jit_compile=True)', 71 warn_once=True) 72def compile(computation, inputs=None): # pylint: disable=redefined-builtin 73 """Builds an operator that compiles and runs `computation` with XLA. 74 75 NOTE: In eager mode, `computation` will have `@tf.function` semantics. 76 77 Args: 78 computation: A Python function that builds a computation to apply to the 79 input. If the function takes n inputs, 'inputs' should be a list of n 80 tensors. 81 82 `computation` may return a list of operations and tensors. Tensors must 83 come before operations in the returned list. The return value of 84 `compile` is a list of tensors corresponding to the tensors from the 85 output of `computation`. 86 87 All `Operation`s returned from `computation` will be executed when 88 evaluating any of the returned output tensors. 89 inputs: A list of inputs or `None` (equivalent to an empty list). Each input 90 can be a nested structure containing values that are convertible to 91 tensors. Note that passing an N-dimension list of compatible values will 92 result in a N-dimension list of scalar tensors rather than a single Rank-N 93 tensors. If you need different behavior, convert part of inputs to tensors 94 with `tf.convert_to_tensor`. 95 96 Returns: 97 Same data structure as if computation(*inputs) is called directly with some 98 exceptions for correctness. Exceptions include: 99 1) None output: a NoOp would be returned which control-depends on 100 computation. 101 2) Single value output: A tuple containing the value would be returned. 102 3) Operation-only outputs: a NoOp would be returned which 103 control-depends on computation. 104 TODO(b/121383831): Investigate into removing these special cases. 105 106 Raises: 107 RuntimeError: if called when eager execution is enabled. 108 109 Known issues: 110 When a tf.random operation is built with XLA, the implementation doesn't 111 pass the user provided seed to the XLA compiler. As such, the XLA compiler 112 generates a random number and uses it as a seed when compiling the 113 operation. This implementation causes a violation of the Tensorflow 114 defined semantics in two aspects. First, changing the value of the user 115 defined seed doesn't change the numbers generated by the operation. 116 Second, when a seed is not specified, running the program multiple times 117 will generate the same numbers. 118 119 """ 120 if context.executing_eagerly(): 121 @def_function.function 122 def xla_compile_wrapper(): 123 return _compile_internal(computation, inputs) 124 125 return xla_compile_wrapper() 126 127 return _compile_internal(computation, inputs) 128 129 130class XLACompileContext(control_flow_ops.XLAControlFlowContext): 131 """A `ControlFlowContext` for nodes inside an XLA computation cluster. 132 133 THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NO USE DIRECTLY. 134 135 The primary role of `XLACompileContext` is to mark operators inside a 136 xla.compile() computation with attribute "_xla_compile_id=XYZ", where XYZ is 137 a unique name. 138 139 `ControlFlowContext` is used to perform the annotation since it integrates 140 with Tensorflow constructs like ResourceVariables. For example, if a 141 `ResourceVariable` is constructed inside a xla.compile() block, the 142 `ResourceVariable` implementation can use 143 `with ops.control_dependencies(None)` to build the variable's definition 144 outside the compiled computation. 145 """ 146 147 def __init__(self, name, pivot): 148 """Builds a new XLACompileContext. 149 150 Args: 151 name: a unique name for the context, used to populate the 152 `_xla_compile_id` attribute. 153 pivot: a pivot node. Nodes in the XLACompileContext that do not have any 154 inputs will have a control dependency on the pivot node. This ensures 155 that nodes are correctly included in any enclosing control flow 156 contexts. 157 """ 158 super(XLACompileContext, self).__init__() 159 self._name = name 160 self._name_as_bytes = compat.as_bytes(name) 161 self._unsupported_ops = [] 162 self._pivot = pivot 163 164 def report_unsupported_operations(self): 165 if self._unsupported_ops: 166 op_str = '\n'.join([ 167 ' %s (%s)' % (op.type, op.name) 168 for op in self._unsupported_ops[:_MAX_WARNING_LINES] 169 ]) 170 logging.warning('%d unsupported operations found: \n%s', 171 len(self._unsupported_ops), op_str) 172 if len(self._unsupported_ops) > _MAX_WARNING_LINES: 173 logging.warning('... and %d more', 174 len(self._unsupported_ops) - _MAX_WARNING_LINES) 175 176 def _RemoveExternalControlEdges(self, op): 177 """Remove any external control dependency on this op.""" 178 internal_control_inputs = [] 179 external_control_inputs = [] 180 for x in op.control_inputs: 181 # pylint: disable=protected-access 182 is_internal_op = False 183 ctxt = x._get_control_flow_context() 184 while ctxt is not None: 185 if ctxt == self: 186 is_internal_op = True 187 break 188 ctxt = ctxt._outer_context 189 if is_internal_op: 190 internal_control_inputs.append(x) 191 else: 192 external_control_inputs.append(x) 193 # pylint: enable=protected-access 194 # pylint: disable=protected-access 195 op._remove_all_control_inputs() 196 op._add_control_inputs(internal_control_inputs) 197 # pylint: enable=protected-access 198 return internal_control_inputs, external_control_inputs 199 200 def AddOp(self, op): 201 """Create op in XLACompileContext and notifies outer context recursively.""" 202 # pylint: disable=protected-access 203 if op.type in _DENYLISTED_OPS: 204 logging.error( 205 'Operation of type %s (%s) is not supported in XLA. Execution will ' 206 'fail if this op is used in the graph. ', op.type, op.name) 207 208 # TODO(ycao): Automatically disable summaries instead of reporting them. 209 if op.type in _UNSUPPORTED_OPS: 210 self._unsupported_ops.append(op) 211 212 if any(x.dtype._is_ref_dtype for x in op.inputs): 213 raise NotImplementedError( 214 'Non-resource Variables are not supported inside XLA computations ' 215 '(operator name: %s)' % op.name) 216 217 if _XLA_COMPILE_ATTR in op.node_def.attr: 218 raise ValueError('XLA compiled computations cannot be nested, (operator ' 219 'name: %s)' % op.name) 220 221 op._set_attr( 222 _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes)) 223 224 op.graph.prevent_feeding(op) 225 op.graph.prevent_fetching(op) 226 227 # Remove any control edges from outer control flow contexts. These may cause 228 # mismatched frame errors. An example is when one of op's inputs is 229 # generated in a different While control flow context. 230 (internal_control_inputs, 231 external_control_inputs) = self._RemoveExternalControlEdges(op) 232 233 if not op.inputs: 234 # Add a control edge from the control pivot to this op. 235 if not internal_control_inputs: 236 # pylint: disable=protected-access 237 op._add_control_input(self._pivot) 238 # pylint: enable=protected-access 239 else: 240 for index in xrange(len(op.inputs)): 241 x = op.inputs[index] 242 real_x = self.AddValue(x) 243 if real_x is not x: 244 op._update_input(index, real_x) # pylint: disable=protected-access 245 246 if external_control_inputs: 247 # Use an identity to pull control inputs as data inputs. Note that we 248 # ignore ops which don't have outputs. TODO(phawkins): fix that. 249 with ops.control_dependencies(None): 250 self.Enter() 251 external_control_inputs = [ 252 array_ops.identity(x.outputs[0]).op 253 for x in external_control_inputs 254 if x.outputs 255 ] 256 self.Exit() 257 # pylint: disable=protected-access 258 op._add_control_inputs(external_control_inputs) 259 # pylint: enable=protected-access 260 261 # Mark op's outputs as seen by this context and any outer contexts. 262 output_names = [x.name for x in op.outputs] 263 context = self 264 while context is not None: 265 # pylint: disable=protected-access 266 context._values.update(output_names) 267 context = context._outer_context 268 # pylint: enable=protected-access 269 270 if self._outer_context: 271 self._outer_context.AddInnerOp(op) 272 273 def AddValue(self, val): 274 """Add `val` to the current context and its outer context recursively.""" 275 if val.name in self._values: 276 # Use the real value if it comes from outer context. 277 result = self._external_values.get(val.name) 278 return val if result is None else result 279 280 result = val 281 self._values.add(val.name) 282 if self._outer_context: 283 result = self._outer_context.AddValue(val) 284 self._values.add(result.name) 285 286 self._external_values[val.name] = result 287 288 return result 289 290 def AddInnerOp(self, op): 291 self.AddOp(op) 292 if self._outer_context: 293 self._outer_context.AddInnerOp(op) 294 295 @property 296 def grad_state(self): 297 # Define the gradient loop state associated with the XLACompileContext to 298 # be None as the XLACompileContext does not get nested nor does the 299 # grad_state outside the XLACompileContext affect the graph inside so the 300 # grad_state should be as if this is the top-level gradient state. 301 return None 302 303 @property 304 def back_prop(self): 305 """Forwards to the enclosing while context, if any.""" 306 if self.GetWhileContext(): 307 return self.GetWhileContext().back_prop 308 return False 309 310 311def _compile_internal(computation, inputs=None): 312 """Builds graph operators that compiles and symbolically executes computation. 313 314 Args: 315 computation: A Python function that builds the computation to compile and 316 execute. 317 inputs: A list of inputs or `None` (equivalent to an empty list). Each input 318 can be a nested structure containing values that are convertible to 319 tensors. Note that passing an N-dimension list of compatible values will 320 result in a N-dimension list of scalar tensors rather than a single Rank-N 321 tensors. If you need different behavior, convert part of inputs to tensors 322 with `tf.convert_to_tensor`. 323 324 Returns: 325 Same data structure as if computation(*inputs) is called directly with some 326 exceptions for correctness. Exceptions include: 1) None output 2) Single 327 value output 3) Operation-only outputs 328 Raises: 329 ValueError: If any element in computation outputs is neither an operations 330 or a value that can be converted to tensor. 331 ValueError: If computation outputs is non-flat and contains any Operations. 332 TypeError: If `inputs` is not a list or tuple. 333 """ 334 if inputs is None: 335 inputs = [] 336 337 if not isinstance(inputs, collections_abc.Sequence): 338 raise TypeError('inputs must be a list') 339 340 # Flatten inputs. 341 flat_inputs = nest.flatten(inputs) 342 # Converts inputs to Tensors. 343 flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs] 344 345 cluster_name = ops.get_default_graph().unique_name('cluster') 346 pivot = control_flow_ops.no_op(name=cluster_name + '/pivot') 347 context = XLACompileContext(name=cluster_name, pivot=pivot) 348 try: 349 context.Enter() 350 351 # Add identity ops so even unused inputs are 'consumed' by the 352 # computation. 353 flat_inputs = [ 354 array_ops.identity(x, name='input_{}'.format(i)) 355 for i, x in enumerate(flat_inputs) 356 ] 357 358 # Re-pack flat_inputs in same structure as 'inputs'. 359 computation_inputs = nest.pack_sequence_as( 360 structure=inputs, flat_sequence=flat_inputs) 361 362 # Only resource variables work inside an XLA computation, so turn on 363 # resource variables for the computation. 364 vscope = variable_scope.get_variable_scope() 365 saved_use_resource = vscope.use_resource 366 vscope.set_use_resource(True) 367 368 with _disable_summary_context(): 369 outputs = computation(*computation_inputs) 370 371 # Restore variable scope after computation. 372 vscope.set_use_resource(saved_use_resource) 373 374 outputs_is_flat = is_flat(outputs) 375 if outputs_is_flat: 376 output_tensors, control_deps = _postprocess_flat_outputs(outputs) 377 else: 378 output_tensors, control_deps = _postprocess_non_flat_outputs(outputs) 379 380 context.ExitResult(output_tensors) 381 finally: 382 context.report_unsupported_operations() 383 context.Exit() 384 385 # When XLA computation returns only operations and no tensors, a NoOp 386 # dependent on the operations in outputs is returned. Otherwise final 387 # outputs would be empty and there is no way to trigger returned 388 # operations. 389 if not output_tensors: 390 return control_flow_ops.group(control_deps, name='output_0') 391 392 output_tensors = [ 393 xla_ops.xla_cluster_output(o, name='output{}'.format(i)) 394 for i, o in enumerate(output_tensors) 395 ] 396 397 with ops.control_dependencies(control_deps): 398 # Wraps the outputs in identity operators that carries control 399 # dependencies. 400 output_tensors = [ 401 array_ops.identity(o, name='output_%d' % i) 402 for i, o in enumerate(output_tensors) 403 ] 404 405 # If `computation` returned non-flat output structure, pack output tensors 406 # back into same structure. 407 if not outputs_is_flat: 408 output_tensors = nest.pack_sequence_as( 409 structure=outputs, flat_sequence=output_tensors) 410 411 return output_tensors 412 413 414def is_flat(outputs): 415 """Checks if outputs is a flat structure. 416 417 Following structures and values are considered flat: 418 1) None 419 2) A single object 420 3) A list or tuple of Tensors/Operations 421 422 The only structures that this function understands are sequences, 423 dictionaries and types defined using the attrs library. E.g. this means 424 that if outputs contains a single user-defined Object, it is considered to 425 be flat. Errors are raised later on if that Object cannot be converted to a 426 Tensor. 427 428 Args: 429 outputs: Output from `computation` inside `xla.compile`. 430 431 Returns: 432 A boolean indicates whether outputs is flat. 433 """ 434 # If outputs is a list or tuple, check if it has any nested structure. If 435 # there is, then outputs is non-flat. 436 if isinstance(outputs, collections_abc.Sequence): 437 for o in outputs: 438 if (isinstance(o, collections_abc.Sequence) or 439 isinstance(o, collections_abc.Mapping) or 440 hasattr(o.__class__, '__attrs_attrs__')): 441 return False 442 443 # If outputs is a dict, it is non-flat. 444 if isinstance(outputs, collections_abc.Mapping): 445 return False 446 447 # If outputs is from the attrs library, it is non-flat. 448 if hasattr(outputs.__class__, '__attrs_attrs__'): 449 return False 450 451 # Getting here means either outputs itself is a single non-structured value 452 # or it is a flat list of single non-structured values. 453 return True 454 455 456def _postprocess_flat_outputs(outputs): 457 """Validates flat outputs and adds back device assignments. 458 459 Args: 460 outputs: Output from `computation` inside `xla.compile`. 461 462 Returns: 463 Tensors and Operations extracted from outputs. 464 """ 465 # Following code segment is to preserve legacy behavior. Previously we only 466 # supported flat outputs and thus for consistency it was nice to convert even 467 # single element into a tuple. But now that we support arbitrary output 468 # structure, this is no longer necessary. 469 # TODO(b/121383831): Migrate all legacy use cases and delete this special 470 # case. 471 # If the computation returns `None`, make it an empty tuple. 472 if outputs is None: 473 outputs = tuple() 474 # If the computation only returned one value, make it a tuple. 475 if not isinstance(outputs, collections_abc.Sequence): 476 outputs = (outputs,) 477 478 # Append `no_op` here so that return value of this function always contains 479 # at least one op that can trigger XlaLaunch node. 480 outputs += (control_flow_ops.no_op(),) 481 try: 482 outputs = [ 483 o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o) 484 for o in outputs 485 ] 486 except Exception as e: 487 raise ValueError( 488 'XLA computation function return values must all either be Operations' 489 ' or convertible to Tensors. Got error: "%s"' % str(e)) 490 491 # Separates the returned Operations and Tensors. 492 output_operations = [o for o in outputs if isinstance(o, ops.Operation)] 493 output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)] 494 495 if outputs != output_tensors + output_operations: 496 raise ValueError( 497 'XLA computation function must return zero or more Tensor values ' 498 'followed by zero or more Operations.') 499 500 new_output_tensors = [] 501 for t in output_tensors: 502 with ops.device(t.device if t.device else ''): 503 new_output_tensors.append(array_ops.identity(t)) 504 505 return new_output_tensors, output_operations 506 507 508def _postprocess_non_flat_outputs(outputs): 509 """Validates non-flat outputs and adds back device assignments. 510 511 Args: 512 outputs: Output from `computation` inside `xla.compile`. 513 514 Returns: 515 Tensors extracted from outputs and an empty list because Operations are not 516 allowed in non-flat outputs.. 517 """ 518 # Convert all non-Operation outputs to Tensors. 519 new_output_tensors = [] 520 for o in nest.flatten(outputs): 521 if isinstance(o, ops.Operation): 522 raise ValueError( 523 'xla.compile does not support Operation as return value in non-flat ' 524 'output structure. You can set returned Operations as control ' 525 'dependencies of returned Tensors so Operations are triggered when ' 526 'Tensors are evaluated. Operation found: "%s"' % o.name) 527 528 try: 529 o = ops.convert_to_tensor(o) 530 except Exception as e: 531 raise ValueError( 532 'XLA computation function return values must all either be ' 533 'Operations or convertible to Tensors. Got error: "%s"' % str(e)) 534 535 # Makes sure even pass-through inputs/outputs are touched in compile 536 # context by creating an Identity node inside compile context. 537 with ops.device(o.device if o.device else ''): 538 new_output_tensors.append(array_ops.identity(o)) 539 540 return new_output_tensors, [] 541 542 543@contextlib.contextmanager 544def _disable_summary_context(): 545 """Enters a context where all summary ops are skipped. 546 547 Summaries are not yet supported in xla.compile(). So we provide this context 548 manager that can skip creating summary ops. This is a temporary workaround due 549 to XLA not supporting summary ops. 550 551 Yields: 552 None. 553 """ 554 original_skip_summary_func = summary_op_util.skip_summary 555 summary_op_util.skip_summary = lambda: True 556 557 try: 558 yield 559 finally: 560 summary_op_util.skip_summary = original_skip_summary_func 561 562 563class _CapturedObject(object): 564 """A placeholder to capture an object.""" 565 566 def __init__(self): 567 self._object = None 568 569 def capture(self, o): 570 if self._object: 571 raise RuntimeError( 572 'InternalError: _CapturedObject can capture only once. Please file ' 573 'bug.') 574 575 self._object = o 576 577 def get(self): 578 return self._object 579 580 581def _get_scaffold(captured_scaffold_fn): 582 """Retrieves the Scaffold from `captured_scaffold_fn`.""" 583 scaffold_fn = captured_scaffold_fn.get() 584 585 if not scaffold_fn: 586 return None 587 588 scaffold = scaffold_fn() 589 if scaffold is None: 590 raise ValueError( 591 'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed') 592 593 return scaffold 594 595 596def check_function_argument_count(func, input_arity, infeed_queue): 597 """Validate the number of input arguments to an XLA function. 598 599 Args: 600 func: the Python function that will be called to generate the body of an XLA 601 computation graph. 602 input_arity: the number of explicit arguments supplied by the caller. 603 infeed_queue: if not None, the infeed queue that will supply 604 additional arguments to the function. 605 606 Returns: 607 None if function can be called with the supplied number of 608 arguments, or an error string if it cannot. 609 """ 610 def format_error(complaint, quantity): 611 return '%s %d argument%s' % (complaint, quantity, '' 612 if quantity == 1 else 's') 613 614 num_args_supplied = input_arity 615 if infeed_queue is not None: 616 num_args_supplied += infeed_queue.number_of_tuple_elements 617 arg_spec = tf_inspect.getargspec(func) 618 num_func_args = len(arg_spec.args) 619 if arg_spec.defaults is None: 620 num_func_defaults = 0 621 else: 622 num_func_defaults = len(arg_spec.defaults) 623 min_func_args = num_func_args - num_func_defaults 624 if num_args_supplied < min_func_args: 625 # The required number of arguments is not enough to call the function. 626 if num_func_defaults == 0 and arg_spec.varargs is None: 627 return format_error('exactly', num_func_args) 628 else: 629 return format_error('at least', min_func_args) 630 if arg_spec.varargs is None and num_args_supplied > num_func_args: 631 # The required number of arguments is too many to call the function. 632 if num_func_defaults == 0: 633 return format_error('exactly', num_func_args) 634 else: 635 return format_error('at most', num_func_args) 636 # Reaching here means either 637 # 1) There are varargs, func can accept any number of arguments greater than 638 # the minimum. 639 # 2) Number of supplied arguments falls in range of acceptable argument count 640 # of func. 641 return None 642