• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experimental library that exposes XLA operations directly in TensorFlow.
16
17It is sometimes useful to be able to build HLO programs directly from
18TensorFlow. This file provides Tensorflow operators that mirror the semantics of
19HLO operators as closely as possible.
20
21Note: Most of the operators defined in this module are used by the jax2tf
22converter (see go/jax2tf for details) and are used in SavedModel produced
23by jax2tf. Hence, we need to maintain backwards compatibility for these
24operators. Please reach out to the JAX team if you want to make changes.
25"""
26
27from tensorflow.compiler.tf2xla.ops import gen_xla_ops
28from tensorflow.compiler.xla import xla_data_pb2
29from tensorflow.core.framework import attr_value_pb2
30from tensorflow.python.framework import constant_op
31from tensorflow.python.framework import dtypes
32from tensorflow.python.framework import ops
33from tensorflow.python.ops import array_ops
34from tensorflow.python.ops import bitwise_ops
35from tensorflow.python.ops import gen_math_ops
36from tensorflow.python.ops import gen_random_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import random_ops
39from tensorflow.python.ops import special_math_ops
40from tensorflow.python.ops import stateless_random_ops
41from tensorflow.python.ops.numpy_ops import np_utils
42
# TODO(phawkins): provide wrappers for all XLA operators. The ops that are
# still missing include:
# - infeed/outfeed (available via tf.contrib.tpu)
# - collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# - conditional
# - collapse
50
# Public names in this module follow XLA's operator names (so callers can write
# things like xla.max), which shadows several Python builtins. Keep private
# handles to the builtin versions before those names are rebound below.
# pylint: disable=redefined-builtin
_max, _min = max, min
_slice = slice  # pylint: disable=invalid-name

# Re-exported so that constants can be built through this module.
constant = constant_op.constant
59
60# Unary operators.
61
62# For most arithmetic operators there is a TensorFlow operator
63# that exactly corresponds to each XLA operator. Rather than defining
64# XLA-specific variants, we reuse the corresponding TensorFlow operator.
65# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
66# wrap every HLO operator, because that would allow us to be confident that the
67# semantics match.
68
69
def _unary_op(fn):
  """Returns `fn` narrowed to the `(x, name=None)` unary-operator signature.

  Args:
    fn: a callable that accepts an operand and a `name` keyword argument.

  Returns:
    A function taking only `x` and an optional `name`, forwarding both to `fn`.
  """

  def unary_op_wrapper(x, name=None):
    # Only the operand and the op name are forwarded; any further parameters
    # that `fn` accepts are left at their defaults.
    result = fn(x, name=name)
    return result

  return unary_op_wrapper
77
78
# Elementwise unary operators, in alphabetical order. Each one reuses the
# corresponding TensorFlow operator, restricted to the `(x, name=None)`
# signature by `_unary_op`.
abs = _unary_op(math_ops.abs)
ceil = _unary_op(math_ops.ceil)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
ndtri = _unary_op(math_ops.ndtri)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for
# numbers halfway between two integers.
round = _unary_op(math_ops.round)
sign = _unary_op(math_ops.sign)
sin = _unary_op(math_ops.sin)
tanh = _unary_op(math_ops.tanh)

# Exponentially scaled modified Bessel functions.
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)
110
111# Binary operators
112
113# The main difference between TensorFlow and XLA binary ops is the broadcasting
114# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
115# requires an explicit specification of which dimensions to broadcast if the
116# arguments have different ranks.
117
118
def _broadcasting_binary_op(fn):
  """Wraps a binary Tensorflow operator and performs XLA-style broadcasting.

  Args:
    fn: a callable invoked as `fn(x, y, name=name)` on the broadcast operands.

  Returns:
    A function `(x, y, broadcast_dims=None, name=None)` that routes `x` and `y`
    through `XlaBroadcastHelper` and then applies `fn` to the two results.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Inner wrapper function."""
    # Compare against None instead of relying on truthiness: `broadcast_dims or
    # []` raises for array-like values whose truth value is ambiguous (e.g. a
    # NumPy array with more than one element). Empty sequences still convert
    # to an empty int64 tensor, exactly as before.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper
133
134
# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types. This is the exact inverse of
# the signed-to-unsigned table, so it is derived from it.
_UNSIGNED_TO_SIGNED_TABLE = {
    unsigned: signed for signed, unsigned in _SIGNED_TO_UNSIGNED_TABLE.items()
}
150
151
def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type.

  Operands whose dtype appears in `_SIGNED_TO_UNSIGNED_TABLE` are cast to the
  matching unsigned dtype, shifted, and the result is cast back to the original
  dtype. Any other dtype is shifted directly.

  Args:
    x: the value to shift.
    y: the shift amount; must have the same dtype as `x`.
    name: optional name passed to the shift op.

  Returns:
    The shifted value, with the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE.get(original_dtype)
  if unsigned_dtype is None:
    # Not a signed integer type listed in the table: shift as is.
    return bitwise_ops.right_shift(x, y, name=name)
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, unsigned_dtype),
      math_ops.cast(y, unsigned_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)
165
166
def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type.

  Operands whose dtype appears in `_UNSIGNED_TO_SIGNED_TABLE` are cast to the
  matching signed dtype, shifted, and the result is cast back to the original
  dtype. Any other dtype is shifted directly.

  Args:
    x: the value to shift.
    y: the shift amount; must have the same dtype as `x`.
    name: optional name passed to the shift op.

  Returns:
    The shifted value, with the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  signed_dtype = _UNSIGNED_TO_SIGNED_TABLE.get(original_dtype)
  if signed_dtype is None:
    # Not an unsigned integer type listed in the table: shift as is.
    return bitwise_ops.right_shift(x, y, name=name)
  shifted = bitwise_ops.right_shift(
      math_ops.cast(x, signed_dtype),
      math_ops.cast(y, signed_dtype),
      name=name)
  return math_ops.cast(shifted, original_dtype)
180
181
# Binary operators with XLA-style broadcasting. Every entry takes
# `(x, y, broadcast_dims=None, name=None)`.

# Arithmetic.
add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
pow = _broadcasting_binary_op(math_ops.pow)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)

# Logical.
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)

# Comparisons.
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)

# Bit shifts.
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

# Special functions.
igamma = _broadcasting_binary_op(math_ops.igamma)
igammac = _broadcasting_binary_op(math_ops.igammac)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)
211
212
def _binary_op(fn):
  """Returns `fn` narrowed to the `(x, y, name=None)` binary signature.

  Unlike `_broadcasting_binary_op`, no broadcasting is applied.

  Args:
    fn: a callable that accepts two operands and a `name` keyword argument.

  Returns:
    A function taking `x`, `y` and an optional `name`, forwarding all three to
    `fn`.
  """

  def binary_op_wrapper(x, y, name=None):
    # Any further parameters that `fn` accepts are left at their defaults.
    result = fn(x, y, name=name)
    return result

  return binary_op_wrapper
220
221
# Two-argument array operators, restricted to `(x, y, name=None)`.
transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

# Exposed under the XLA operator name; this is `array_ops.bitcast` itself.
bitcast_convert_type = array_ops.bitcast
226
227
def broadcast(x, dims, name=None):
  """Broadcasts `x` by prepending the dimensions `dims` to its shape.

  Args:
    x: the value to broadcast; converted to a tensor.
    dims: sizes of the new leading dimensions. May be empty, in which case the
      target shape is just the shape of `x`.
    name: optional name for the broadcast op.

  Returns:
    A tensor whose shape is `dims` followed by the shape of `x`.
  """
  x = ops.convert_to_tensor(x)
  # Pin the dtype to int32: without it an empty `dims` is inferred as float32,
  # which cannot be concatenated with the int32 result of `array_ops.shape`.
  leading_dims = constant_op.constant(dims, dtype=dtypes.int32)
  shape = array_ops.concat([leading_dims, array_ops.shape(x)], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)
234
235
def clamp(a, x, b, name=None):
  """Clamps `x` to the range `[a, b]` as `min(max(a, x), b)`.

  Uses this module's broadcasting `max` and `min`; `name` is passed to both.

  Args:
    a: the lower bound.
    x: the value to clamp.
    b: the upper bound.
    name: optional name passed to both underlying ops.

  Returns:
    The clamped value.
  """
  lower_bounded = max(a, x, name=name)
  return min(lower_bounded, b, name=name)
238
239
# Exposed under the XLA operator name; this is `array_ops.concat` itself.
concatenate = array_ops.concat
241
242
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False,
         batch_group_count=1):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  The XlaConvV2 op is emitted when `preferred_element_type` is given, when
  `lhs` and `rhs` have different dtypes, when `batch_group_count > 1`, or when
  `use_v2` is set. Otherwise the original XlaConv op is emitted.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. If None and the V2 op is used,
      it defaults to `np_utils.result_type(lhs.dtype, rhs.dtype)`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # The XlaConv call below takes neither `preferred_element_type` nor
  # `batch_group_count`, so any of these conditions requires XlaConvV2. The
  # explicit `is not None` matches the defaulting check below, so a supplied
  # type always selects V2 regardless of how it evaluates as a boolean.
  needs_v2 = (
      preferred_element_type is not None or lhs.dtype != rhs.dtype or
      batch_group_count > 1)
  if needs_v2 or use_v2:
    # The result type is only consumed by the V2 op, so derive it here.
    if preferred_element_type is None:
      preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)
313
314
# Exposed under the XLA operator name; this is `math_ops.cast` itself.
convert_element_type = math_ops.cast
316
317
def dot(lhs, rhs, name=None):
  """Computes `math_ops.tensordot(lhs, rhs, axes=1)`.

  Args:
    lhs: the left operand.
    rhs: the right operand.
    name: optional name for the operator.

  Returns:
    The result of the tensordot.
  """
  product = math_ops.tensordot(lhs, rhs, axes=1, name=name)
  return product
320
321
# Proto message classes from xla_data_pb2, re-exported for callers that build
# the `dimension_numbers` / `precision_config` arguments used in this module.
DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig
324
325
def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  """Wraps the XlaDot / XlaDotV2 operators.

  The XlaDotV2 op is emitted when `preferred_element_type` is given, when `lhs`
  and `rhs` have different dtypes, or when `use_v2` is set. Otherwise the
  original XlaDot op is emitted.

  Args:
    lhs: the left operand.
    rhs: the right operand.
    dimension_numbers: a proto exposing `SerializeToString()`; presumably a
      `DotDimensionNumbers` message.
    precision_config: an optional proto exposing `SerializeToString()`;
      presumably a `PrecisionConfig` message.
    preferred_element_type: the result `dtype`. If None and the V2 op is used,
      it defaults to `np_utils.result_type(lhs.dtype, rhs.dtype)`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaDotV2 op even if not necessary.

  Returns:
    The output of the emitted op.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # The XlaDot call below takes no `preferred_element_type`. The explicit
  # `is not None` matches the defaulting check below, so a supplied type always
  # selects V2 regardless of how it evaluates as a boolean.
  needs_v2 = preferred_element_type is not None or lhs.dtype != rhs.dtype
  if needs_v2 or use_v2:
    # The result type is only consumed by the V2 op, so derive it here.
    if preferred_element_type is None:
      preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)
353
354
def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Thin wrapper over `gen_xla_ops.xla_self_adjoint_eig`.

  All four arguments are forwarded positionally and unchanged, in the order
  `(a, lower, max_iter, epsilon)`.

  Returns:
    Whatever the generated op returns.
  """
  result = gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
  return result
357
358
def svd(a, max_iter, epsilon, precision_config=None):
  """Thin wrapper over `gen_xla_ops.xla_svd`.

  Args:
    a: forwarded unchanged as the first argument of the op.
    max_iter: forwarded unchanged.
    epsilon: forwarded unchanged.
    precision_config: optional proto; when truthy it is passed in serialized
      form, otherwise an empty string is passed.

  Returns:
    Whatever the generated op returns.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, serialized_precision)
364
365
# Direct aliases of the generated XLA ops.
dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad
373
374
def random_normal(mu, sigma, dims, name=None):
  """Samples from a normal distribution via `random_ops.random_normal`.

  Args:
    mu: the mean; converted to a tensor, whose dtype is used for the output.
    sigma: the standard deviation, passed through unchanged.
    dims: the shape of the output.
    name: optional name for the operator.

  Returns:
    The tensor produced by `random_ops.random_normal`.
  """
  mean = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mean, stddev=sigma, dtype=mean.dtype, name=name)
379
380
def random_uniform(minval, maxval, dims, name=None):
  """Samples from a uniform distribution via `random_ops.random_uniform`.

  Args:
    minval: the lower bound; converted to a tensor, whose dtype is used for
      the output.
    maxval: the upper bound, passed through unchanged.
    dims: the shape of the output.
    name: optional name for the operator.

  Returns:
    The tensor produced by `random_ops.random_uniform`.
  """
  lower = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, lower, maxval, dtype=lower.dtype, name=name)
385
386
def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
    https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}. It is translated to
      its integer form with `stateless_random_ops.convert_alg_to_int`.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
      should be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    a tuple with a new state and generated data of the given shape.
  """
  return gen_xla_ops.xla_rng_bit_generator(
      stateless_random_ops.convert_alg_to_int(algorithm),
      initial_state,
      shape,
      dtype=dtype)
407
408
# Direct aliases of the generated XLA ops.
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

# NOTE(review): `variadic_reduce` is bound to the V2 op, while this registration
# names "XlaVariadicReduce" - confirm whether the V2 op needs one as well.
ops.no_gradient("XlaVariadicReduce")
414
415
def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors passed to the op as `base_dilations`, one
      per window dimension. Optional; if omitted, defaults to 1 everywhere.
    window_dilations: dilation factors passed to the op as `window_dilations`,
      one per window dimension. Optional; if omitted, defaults to 1 everywhere.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  # Every optional argument that is omitted (or otherwise falsy) is replaced by
  # a neutral default with one entry per window dimension.
  num_window_dims = len(window_dimensions)
  if not window_strides:
    window_strides = [1] * num_window_dims
  if not base_dilations:
    base_dilations = [1] * num_window_dims
  if not window_dilations:
    window_dilations = [1] * num_window_dims
  if not padding:
    padding = [(0, 0)] * num_window_dims
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)
459
460
# Direct alias of the generated XLA op.
replica_id = gen_xla_ops.xla_replica_id
462
# Attaches a static bound to the given input value as a hint to the XLA
# compiler, and returns the same value.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3)  # Tells XLA the constraint that p <= 3.
#   return t[:p]             # XLA knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound
470
471
# Turns a static dimension into an XLA bounded dynamic dimension. The current
# static size of that dimension becomes the bound, and the size argument (3 in
# the example below) becomes the dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([1, 2, 3, 4, 5])
#   # Tells XLA the valid size of the array is 3.
#   dim = 0
#   p = xla.set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6)  # XLA knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size
485
486
# Inverse of set_dynamic_dimension_size: turns an XLA bounded dynamic dimension
# back into a static dimension. The bound of the size of dimension `dim_index`
# becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size
491
492
def reshape(x, new_sizes, dimensions=None, name=None):
  """Reshapes `x` to `new_sizes`, optionally transposing it first.

  Args:
    x: the tensor to reshape.
    new_sizes: the target shape.
    dimensions: optional permutation; when given, `x` is transposed with it
      before the reshape.
    name: optional name, applied to the reshape op only.

  Returns:
    The reshaped tensor.
  """
  permuted = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(permuted, new_sizes, name=name)
498
499
def select(condition, x, y, name=None):
  """Selects between `x` and `y` according to `condition`.

  Delegates to `array_ops.where(condition, x, y, name)`.

  Args:
    condition: the predicate passed to `array_ops.where`.
    x: the second argument of `array_ops.where`.
    y: the third argument of `array_ops.where`.
    name: optional name for the operator.

  Returns:
    The result of `array_ops.where`.
  """
  selected = array_ops.where(condition, x, y, name)
  return selected
502
503
# Direct aliases of the generated XLA ops.
select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
506
507
def slice(x, start_dims, limit_dims, strides):
  """Takes a strided slice of `x`.

  Dimension `i` is indexed with `start_dims[i]:limit_dims[i]:strides[i]`. The
  three sequences are zipped, so only as many dimensions as the shortest one
  are sliced.

  Args:
    x: the value to slice; must support indexing with a tuple of slices.
    start_dims: per-dimension start indices.
    limit_dims: per-dimension end indices (exclusive, as in Python slicing).
    strides: per-dimension strides.

  Returns:
    `x` indexed with the resulting tuple of slices.
  """
  # `_slice` is the builtin `slice`; the bare name refers to this function.
  index = tuple(
      _slice(*bounds) for bounds in zip(start_dims, limit_dims, strides))
  return x[index]
514
515
# Direct alias of the generated XLA op.
sharding = gen_xla_ops.xla_sharding
517
518
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  The incoming gradient is wrapped in an XlaSharding op that carries the same
  `sharding` and `unspecified_dims` attributes as the forward op.

  Args:
    op: the forward XlaSharding operation.
    grad: the gradient with respect to the output of `op`.

  Returns:
    A one-element list holding the sharded gradient.
  """
  forward_sharding = op.get_attr("sharding")
  forward_unspecified_dims = op.get_attr("unspecified_dims")
  sharded_grad = gen_xla_ops.xla_sharding(
      grad,
      sharding=forward_sharding,
      unspecified_dims=forward_unspecified_dims)
  # Also record the sharding in the `_XlaSharding` attribute of the new op.
  # pylint: disable=protected-access
  sharded_grad.op._set_attr("_XlaSharding",
                            attr_value_pb2.AttrValue(s=forward_sharding))
  return [sharded_grad]
531
532
# Direct aliases of the generated XLA ops; each is used as the gradient of the
# other.
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape
535
536
@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient for XlaSpmdFullToShardShape.

  Applies XlaSpmdShardToFullShape to `grad` with the forward op's
  `manual_sharding`, `dim` and `unspecified_dims` attributes. The full shape is
  the static shape of the forward op's first input.

  Args:
    op: the forward XlaSpmdFullToShardShape operation.
    grad: the gradient with respect to the output of `op`.

  Returns:
    A one-element list holding the gradient with respect to the input.
  """
  manual_sharding = op.get_attr("manual_sharding")
  full_shape = op.inputs[0].shape.as_list()
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  full_grad = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad,
      manual_sharding=manual_sharding,
      full_shape=full_shape,
      dim=dim,
      unspecified_dims=unspecified_dims)
  return [full_grad]
546
547
@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient for XlaSpmdShardToFullShape.

  Applies XlaSpmdFullToShardShape to `grad` with the forward op's
  `manual_sharding`, `dim` and `unspecified_dims` attributes.

  Args:
    op: the forward XlaSpmdShardToFullShape operation.
    grad: the gradient with respect to the output of `op`.

  Returns:
    A one-element list holding the gradient with respect to the input.
  """
  manual_sharding = op.get_attr("manual_sharding")
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  shard_grad = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad,
      manual_sharding=manual_sharding,
      dim=dim,
      unspecified_dims=unspecified_dims)
  return [shard_grad]
556
557
# Direct aliases of the generated XLA ops.
# Sorting.
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
# Other operators.
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call
564
565
def call_module(args, *, module, Tout, Sout, dim_args_spec=()):
  """Thin wrapper over `gen_xla_ops.xla_call_module`.

  Every argument except `args` is keyword-only and is forwarded unchanged under
  the same name.

  Args:
    args: forwarded as the positional argument of the op.
    module: forwarded as `module`.
    Tout: forwarded as `Tout`.
    Sout: forwarded as `Sout`.
    dim_args_spec: forwarded as `dim_args_spec`; defaults to an empty tuple.

  Returns:
    Whatever the generated op returns.
  """
  op_attrs = dict(
      module=module, dim_args_spec=dim_args_spec, Tout=Tout, Sout=Sout)
  return gen_xla_ops.xla_call_module(args, **op_attrs)
569
570
def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Thin wrapper over `gen_xla_ops.xla_gather`.

  Args:
    operand: forwarded as the first positional argument of the op.
    start_indices: forwarded as the second positional argument of the op.
    dimension_numbers: a proto exposing `SerializeToString()`; it is passed to
      the op in serialized form.
    slice_sizes: forwarded as `slice_sizes`.
    indices_are_sorted: forwarded as `indices_are_sorted`.
    name: an optional name for the operator.

  Returns:
    Whatever the generated op returns.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
580
581
def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Thin wrapper over `gen_xla_ops.xla_scatter`.

  Args:
    operand: forwarded as the first positional argument of the op.
    scatter_indices: forwarded as the second positional argument of the op.
    updates: forwarded as the third positional argument of the op.
    update_computation: forwarded as `update_computation`.
    dimension_numbers: a proto exposing `SerializeToString()`; it is passed to
      the op in serialized form.
    indices_are_sorted: forwarded as `indices_are_sorted`.
    name: an optional name for the operator.

  Returns:
    Whatever the generated op returns.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
592
593
def optimization_barrier(*args):
  """Thin wrapper over `gen_xla_ops.xla_optimization_barrier`.

  Args:
    *args: the values to pass through the barrier; they are handed to the op
      together, as one tuple.

  Returns:
    Whatever the generated op returns.
  """
  barrier_outputs = gen_xla_ops.xla_optimization_barrier(args)
  return barrier_outputs
596
597
def reduce_precision(operand, exponent_bits, mantissa_bits):
  """Thin wrapper over `gen_xla_ops.xla_reduce_precision`.

  All three arguments are forwarded positionally and unchanged, in the order
  `(operand, exponent_bits, mantissa_bits)`.

  Returns:
    Whatever the generated op returns.
  """
  reduced = gen_xla_ops.xla_reduce_precision(
      operand, exponent_bits, mantissa_bits)
  return reduced
600