# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""xla is an experimental library that provides XLA support APIs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import contextlib

from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.compiler.jit.ops import xla_ops
from tensorflow.compiler.jit.ops import xla_ops_grad  # pylint: disable=unused-import
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.distribute import summary_op_util
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export

_XLA_COMPILE_ATTR = '_xla_compile_id'
_MAX_WARNING_LINES = 5

# Operations that indicate an error in the user's graph. For example, an XLA
# computation should not have any Placeholder op.
_BLACKLISTED_OPS = set([
    'Placeholder',
])

# XLA doesn't currently support reading intermediate tensors, so some ops are
# not supported.
_UNSUPPORTED_OPS = set([
    'AudioSummary',
    'AudioSummaryV2',
    'HistogramSummary',
    'ImageSummary',
    'MergeSummary',
    'Print',
    'ScalarSummary',
    'TensorSummary',
    'TensorSummaryV2',
])


@tf_export('xla.experimental.compile')
def compile(computation, inputs=None):  # pylint: disable=redefined-builtin
  """Builds an operator that compiles and runs `computation` with XLA.

  NOTE: In eager mode, `computation` will have `@tf.function` semantics.

  Args:
    computation: A Python function that builds a computation to apply to the
      input. If the function takes n inputs, `inputs` should be a list of n
      tensors.

      `computation` may return a list of operations and tensors.  Tensors must
      come before operations in the returned list.  The return value of
      `compile` is a list of tensors corresponding to the tensors from the
      output of `computation`.

      All `Operation`s returned from `computation` will be executed when
      evaluating any of the returned output tensors.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that are convertible
      to tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single rank-N tensor. If you need different behavior, convert parts of
      `inputs` to tensors with `tf.convert_to_tensor`.

  Returns:
    The same data structure as if `computation(*inputs)` were called directly,
    with some exceptions for correctness:
      1) None output: a NoOp that control-depends on `computation` is
         returned.
      2) Single-value output: a tuple containing the value is returned.
      3) Operation-only outputs: a NoOp that control-depends on `computation`
         is returned.
      TODO(b/121383831): Investigate removing these special cases.

  Raises:
    ValueError: If any element of `computation` outputs is neither an
      `Operation` nor a value convertible to a tensor.
    ValueError: If the `computation` output is non-flat and contains any
      `Operation`s.
    TypeError: If `inputs` is not a list or tuple.

  Known issues:
    When a tf.random operation is built with XLA, the implementation doesn't
      pass the user-provided seed to the XLA compiler. As such, the XLA
      compiler generates a random number and uses it as a seed when compiling
      the operation. This implementation violates the TensorFlow-defined
      semantics in two ways. First, changing the value of the user-defined
      seed doesn't change the numbers generated by the operation. Second, when
      a seed is not specified, running the program multiple times will
      generate the same numbers.

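  Example:
    A minimal sketch of typical usage; `model_fn` and its inputs here are
    hypothetical:

      def model_fn(x, y):
        return x + y, x * y

      sum_t, prod_t = tf.xla.experimental.compile(
          model_fn, inputs=[tf.constant(3.0), tf.constant(4.0)])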
  """
  if context.executing_eagerly():
    @def_function.function
    def xla_compile_wrapper():
      return _compile_internal(computation, inputs)

    return xla_compile_wrapper()

  return _compile_internal(computation, inputs)


class XLACompileContext(control_flow_ops.XLAControlFlowContext):
  """A `ControlFlowContext` for nodes inside an XLA computation cluster.

  THIS IS ONLY FOR TENSORFLOW INTERNAL IMPLEMENTATION, DO NOT USE DIRECTLY.

  The primary role of `XLACompileContext` is to mark operators inside an
  xla.compile() computation with the attribute "_xla_compile_id=XYZ", where
  XYZ is a unique name.

  `ControlFlowContext` is used to perform the annotation since it integrates
  with TensorFlow constructs like ResourceVariables. For example, if a
  `ResourceVariable` is constructed inside an xla.compile() block, the
  `ResourceVariable` implementation can use
  `with ops.control_dependencies(None)` to build the variable's definition
  outside the compiled computation.
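
  As a rough illustration (the cluster name below is made up), an op created
  inside an xla.compile() block ends up annotated as:

    op.node_def.attr['_xla_compile_id'].s == b'cluster'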
  """

  def __init__(self, name, pivot):
    """Builds a new XLACompileContext.

    Args:
      name: a unique name for the context, used to populate the
        `_xla_compile_id` attribute.
      pivot: a pivot node. Nodes in the XLACompileContext that do not have any
        inputs will have a control dependency on the pivot node. This ensures
        that nodes are correctly included in any enclosing control flow
        contexts.
    """
    super(XLACompileContext, self).__init__()
    self._name = name
    self._name_as_bytes = compat.as_bytes(name)
    self._unsupported_ops = []
    self._pivot = pivot

  def report_unsupported_operations(self):
    """Logs a warning if any unsupported ops were added to this context."""
    if self._unsupported_ops:
      op_str = '\n'.join([
          '  %s (%s)' % (op.type, op.name)
          for op in self._unsupported_ops[:_MAX_WARNING_LINES]
      ])
      logging.warning('%d unsupported operations found:\n%s',
                      len(self._unsupported_ops), op_str)
      if len(self._unsupported_ops) > _MAX_WARNING_LINES:
        logging.warning('... and %d more',
                        len(self._unsupported_ops) - _MAX_WARNING_LINES)

  def _RemoveExternalControlEdges(self, op):
    """Remove any external control dependency on this op."""
    internal_control_inputs = []
    external_control_inputs = []
    for x in op.control_inputs:
      # pylint: disable=protected-access
      is_internal_op = False
      ctxt = x._get_control_flow_context()
      while ctxt is not None:
        if ctxt == self:
          is_internal_op = True
          break
        ctxt = ctxt._outer_context
      if is_internal_op:
        internal_control_inputs.append(x)
      else:
        external_control_inputs.append(x)
      # pylint: enable=protected-access
    # pylint: disable=protected-access
    op._remove_all_control_inputs()
    op._add_control_inputs(internal_control_inputs)
    # pylint: enable=protected-access
    return internal_control_inputs, external_control_inputs

  def AddOp(self, op):
    """Creates op in XLACompileContext and notifies outer context recursively."""
    # pylint: disable=protected-access
    if op.type in _BLACKLISTED_OPS:
      logging.error(
          'Operation of type %s (%s) is not supported in XLA. Execution will '
          'fail if this op is used in the graph.', op.type, op.name)

    # TODO(ycao): Automatically disable summaries instead of reporting them.
    if op.type in _UNSUPPORTED_OPS:
      self._unsupported_ops.append(op)

    if any(x.dtype._is_ref_dtype for x in op.inputs):
      raise NotImplementedError(
          'Non-resource Variables are not supported inside XLA computations '
          '(operator name: %s)' % op.name)

    if _XLA_COMPILE_ATTR in op.node_def.attr:
      raise ValueError('XLA compiled computations cannot be nested (operator '
                       'name: %s)' % op.name)

    op._set_attr(
        _XLA_COMPILE_ATTR, attr_value_pb2.AttrValue(s=self._name_as_bytes))

    op.graph.prevent_feeding(op)
    op.graph.prevent_fetching(op)

    # Remove any control edges from outer control flow contexts. These may
    # cause mismatched frame errors. An example is when one of op's inputs is
    # generated in a different While control flow context.
    (internal_control_inputs,
     external_control_inputs) = self._RemoveExternalControlEdges(op)

    if not op.inputs:
      # Add a control edge from the control pivot to this op.
      if not internal_control_inputs:
        # pylint: disable=protected-access
        op._add_control_input(self._pivot)
        # pylint: enable=protected-access
    else:
      for index in xrange(len(op.inputs)):
        x = op.inputs[index]
        real_x = self.AddValue(x)
        if real_x is not x:
          op._update_input(index, real_x)  # pylint: disable=protected-access

    if external_control_inputs:
      # Use an identity to pull control inputs as data inputs. Note that we
      # ignore ops which don't have outputs. TODO(phawkins): fix that.
      with ops.control_dependencies(None):
        self.Enter()
        external_control_inputs = [
            array_ops.identity(x.outputs[0]).op
            for x in external_control_inputs
            if x.outputs
        ]
        self.Exit()
      # pylint: disable=protected-access
      op._add_control_inputs(external_control_inputs)
      # pylint: enable=protected-access

    # Mark op's outputs as seen by this context and any outer contexts.
    output_names = [x.name for x in op.outputs]
    context = self
    while context is not None:
      # pylint: disable=protected-access
      context._values.update(output_names)
      context = context._outer_context
      # pylint: enable=protected-access

    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  def AddValue(self, val):
    """Add `val` to the current context and its outer context recursively."""
    if val.name in self._values:
      # Use the real value if it comes from an outer context.
      result = self._external_values.get(val.name)
      return val if result is None else result

    result = val
    self._values.add(val.name)
    if self._outer_context:
      result = self._outer_context.AddValue(val)
      self._values.add(result.name)

    self._external_values[val.name] = result

    return result

  def AddInnerOp(self, op):
    self.AddOp(op)
    if self._outer_context:
      self._outer_context.AddInnerOp(op)

  @property
  def grad_state(self):
    # Define the gradient loop state associated with the XLACompileContext to
    # be None, because an XLACompileContext is never nested and the grad_state
    # outside the XLACompileContext does not affect the graph inside, so the
    # grad_state should be as if this were the top-level gradient state.
    return None

  @property
  def back_prop(self):
    """Forwards to the enclosing while context, if any."""
    if self.GetWhileContext():
      return self.GetWhileContext().back_prop
    return False


def _compile_internal(computation, inputs=None):
  """Builds graph operators that compile and symbolically execute computation.

  Args:
    computation: A Python function that builds the computation to compile and
      execute.
    inputs: A list of inputs or `None` (equivalent to an empty list). Each
      input can be a nested structure containing values that are convertible
      to tensors. Note that passing an N-dimensional list of compatible values
      will result in an N-dimensional list of scalar tensors rather than a
      single rank-N tensor. If you need different behavior, convert parts of
      `inputs` to tensors with `tf.convert_to_tensor`.

  Returns:
    The same data structure as if `computation(*inputs)` were called directly,
    with some exceptions for correctness: 1) None output, 2) single-value
    output, 3) operation-only outputs.

  Raises:
    ValueError: If any element of `computation` outputs is neither an
      `Operation` nor a value that can be converted to a tensor.
    ValueError: If the `computation` output is non-flat and contains any
      `Operation`s.
    TypeError: If `inputs` is not a list or tuple.
  """
  if inputs is None:
    inputs = []

  if not isinstance(inputs, collections.Sequence):
    raise TypeError('inputs must be a list or tuple')

  # Flatten inputs.
  flat_inputs = nest.flatten(inputs)
  # Converts inputs to Tensors.
  flat_inputs = [ops.convert_to_tensor(x) for x in flat_inputs]

  cluster_name = ops.get_default_graph().unique_name('cluster')
  pivot = control_flow_ops.no_op(name=cluster_name + '/pivot')
  context = XLACompileContext(name=cluster_name, pivot=pivot)
  try:
    context.Enter()

    # Add identity ops so even unused inputs are 'consumed' by the
    # computation.
    flat_inputs = [
        array_ops.identity(x, name='input_{}'.format(i))
        for i, x in enumerate(flat_inputs)
    ]

    # Re-pack flat_inputs in same structure as 'inputs'.
    computation_inputs = nest.pack_sequence_as(
        structure=inputs, flat_sequence=flat_inputs)

    # Only resource variables work inside an XLA computation, so turn on
    # resource variables for the computation.
    vscope = variable_scope.get_variable_scope()
    saved_use_resource = vscope.use_resource
    vscope.set_use_resource(True)

    with _disable_summary_context():
      outputs = computation(*computation_inputs)

    # Restore variable scope after computation.
    vscope.set_use_resource(saved_use_resource)

    outputs_is_flat = is_flat(outputs)
    if outputs_is_flat:
      output_tensors, control_deps = _postprocess_flat_outputs(outputs)
    else:
      output_tensors, control_deps = _postprocess_non_flat_outputs(outputs)

    context.ExitResult(output_tensors)
  finally:
    context.report_unsupported_operations()
    context.Exit()

  # When XLA computation returns only operations and no tensors, a NoOp
  # dependent on the operations in outputs is returned. Otherwise final
  # outputs would be empty and there is no way to trigger returned
  # operations.
  if not output_tensors:
    return control_flow_ops.group(control_deps, name='output_0')

  output_tensors = [
      xla_ops.xla_cluster_output(o, name='output{}'.format(i))
      for i, o in enumerate(output_tensors)
  ]

  with ops.control_dependencies(control_deps):
    # Wrap the outputs in identity operators that carry control
    # dependencies.
    output_tensors = [
        array_ops.identity(o, name='output_%d' % i)
        for i, o in enumerate(output_tensors)
    ]

  # If `computation` returned a non-flat output structure, pack output tensors
  # back into the same structure.
  if not outputs_is_flat:
    output_tensors = nest.pack_sequence_as(
        structure=outputs, flat_sequence=output_tensors)

  return output_tensors


def is_flat(outputs):
  """Checks if outputs is a flat structure.

  The following structures and values are considered flat:
  1) None
  2) A single object
  3) A list or tuple of Tensors/Operations

  The only structures that this function understands are sequences and
  dictionaries.  E.g. this means that if outputs contains a single
  user-defined Object, it is considered to be flat. Errors are raised later on
  if that Object cannot be converted to a Tensor.

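  A few illustrative cases (a sketch; `t1`, `t2`, and `an_op` stand for
  tensors and an operation):

    is_flat(None)               # True: None is flat.
    is_flat(t1)                 # True: a single object is flat.
    is_flat([t1, t2, an_op])    # True: a flat list of Tensors/Operations.
    is_flat([t1, [t2]])         # False: contains a nested sequence.
    is_flat({'a': t1})          # False: a dict is non-flat.
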
  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    A boolean indicating whether outputs is flat.
  """
  # If outputs is a list or tuple, check if it has any nested structure. If
  # there is, then outputs is non-flat.
  if isinstance(outputs, collections.Sequence):
    for o in outputs:
      if isinstance(o, (collections.Sequence, dict)):
        return False

  # If outputs is a dict, it is non-flat.
  if isinstance(outputs, dict):
    return False

  # Getting here means either outputs itself is a single non-structured value
  # or it is a flat list of single non-structured values.
  return True


def _postprocess_flat_outputs(outputs):
  """Validates flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors and Operations extracted from outputs.
  """
  # The following code segment preserves legacy behavior. Previously we only
  # supported flat outputs, so for consistency it was nice to convert even a
  # single element into a tuple. But now that we support arbitrary output
  # structure, this is no longer necessary.
  # TODO(b/121383831): Migrate all legacy use cases and delete this special
  # case.
  # If the computation returns `None`, make it an empty tuple.
  if outputs is None:
    outputs = tuple()
  # If the computation only returned one value, make it a tuple.
  if not isinstance(outputs, collections.Sequence):
    outputs = (outputs,)

  # Append `no_op` here so that the return value of this function always
  # contains at least one op that can trigger the XlaLaunch node.
  outputs += (control_flow_ops.no_op(),)
  try:
    outputs = [
        o if isinstance(o, ops.Operation) else ops.convert_to_tensor(o)
        for o in outputs
    ]
  except Exception as e:
    raise ValueError(
        'XLA computation function return values must all either be Operations'
        ' or convertible to Tensors. Got error: "%s"' % str(e))

  # Separate the returned Operations and Tensors.
  output_operations = [o for o in outputs if isinstance(o, ops.Operation)]
  output_tensors = [o for o in outputs if not isinstance(o, ops.Operation)]

  if outputs != output_tensors + output_operations:
    raise ValueError(
        'XLA computation function must return zero or more Tensor values '
        'followed by zero or more Operations.')

  new_output_tensors = []
  for t in output_tensors:
    with ops.device(t.device if t.device else ''):
      new_output_tensors.append(array_ops.identity(t))

  return new_output_tensors, output_operations


def _postprocess_non_flat_outputs(outputs):
  """Validates non-flat outputs and adds back device assignments.

  Args:
    outputs: Output from `computation` inside `xla.compile`.

  Returns:
    Tensors extracted from outputs and an empty list, because Operations are
    not allowed in non-flat outputs.
  """
  # Convert all non-Operation outputs to Tensors.
  new_output_tensors = []
  for o in nest.flatten(outputs):
    if isinstance(o, ops.Operation):
      raise ValueError(
          'xla.compile does not support Operation as return value in non-flat '
          'output structure. You can set returned Operations as control '
          'dependencies of returned Tensors so Operations are triggered when '
          'Tensors are evaluated. Operation found: "%s"' % o.name)

    try:
      o = ops.convert_to_tensor(o)
    except Exception as e:
      raise ValueError(
          'XLA computation function return values must all either be '
          'Operations or convertible to Tensors. Got error: "%s"' % str(e))

    # Makes sure even pass-through inputs/outputs are touched in compile
    # context by creating an Identity node inside compile context.
    with ops.device(o.device if o.device else ''):
      new_output_tensors.append(array_ops.identity(o))

  return new_output_tensors, []


@contextlib.contextmanager
def _disable_summary_context():
  """Enters a context where all summary ops are skipped.

  Summaries are not yet supported in xla.compile(), so this context manager
  skips creating summary ops. This is a temporary workaround until XLA
  supports summary ops.

  Yields:
    None.
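
  Example (a sketch mirroring how `_compile_internal` in this module uses it;
  `computation` and `computation_inputs` are stand-ins):

    with _disable_summary_context():
      outputs = computation(*computation_inputs)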
  """
  original_skip_summary_func = summary_op_util.skip_summary
  summary_op_util.skip_summary = lambda: True

  try:
    yield
  finally:
    summary_op_util.skip_summary = original_skip_summary_func


class _CapturedObject(object):
  """A placeholder to capture an object."""

  def __init__(self):
    self._object = None

  def capture(self, o):
    if self._object is not None:
      raise RuntimeError(
          'InternalError: _CapturedObject can capture only once. Please file '
          'a bug.')

    self._object = o

  def get(self):
    return self._object


def _get_scaffold(captured_scaffold_fn):
  """Retrieves the Scaffold from `captured_scaffold_fn`."""
  scaffold_fn = captured_scaffold_fn.get()

  if not scaffold_fn:
    return None

  scaffold = scaffold_fn()
  if scaffold is None:
    raise ValueError(
        'TPUEstimatorSpec.scaffold_fn returns None, which is not allowed')

  return scaffold


def check_function_argument_count(func, input_arity, infeed_queue):
  """Validates the number of input arguments to an XLA function.

  Args:
    func: the Python function that will be called to generate the body of an
      XLA computation graph.
    input_arity: the number of explicit arguments supplied by the caller.
    infeed_queue: if not None, the infeed queue that will supply
      additional arguments to the function.

  Returns:
    None if the function can be called with the supplied number of
      arguments, or an error string if it cannot.
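
  For example (a sketch; `fn` below is a hypothetical function and no infeed
  queue is supplied):

    def fn(a, b, c=1):
      ...

    check_function_argument_count(fn, 2, None)  # None: the call is valid.
    check_function_argument_count(fn, 1, None)  # 'at least 2 arguments'
    check_function_argument_count(fn, 4, None)  # 'at most 3 arguments'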
  """

  def format_error(complaint, quantity):
    return '%s %d argument%s' % (complaint, quantity,
                                 '' if quantity == 1 else 's')

  num_args_supplied = input_arity
  if infeed_queue is not None:
    num_args_supplied += infeed_queue.number_of_tuple_elements
  arg_spec = tf_inspect.getargspec(func)
  num_func_args = len(arg_spec.args)
  if arg_spec.defaults is None:
    num_func_defaults = 0
  else:
    num_func_defaults = len(arg_spec.defaults)
  min_func_args = num_func_args - num_func_defaults
  if num_args_supplied < min_func_args:
    # The number of supplied arguments is not enough to call the function.
    if num_func_defaults == 0 and arg_spec.varargs is None:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at least', min_func_args)
  if arg_spec.varargs is None and num_args_supplied > num_func_args:
    # The number of supplied arguments is too many to call the function.
    if num_func_defaults == 0:
      return format_error('exactly', num_func_args)
    else:
      return format_error('at most', num_func_args)
  # Reaching here means either
  # 1) There are varargs, and func can accept any number of arguments greater
  # than the minimum.
  # 2) The number of supplied arguments falls within the acceptable argument
  # count range for func.
  return None