# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental library that exposes XLA operations directly in TensorFlow.

It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics of
HLO operators as closely as possible.

Note: Most of the operators defined in this module are used by the jax2tf
converter (see go/jax2tf for details) and are used in SavedModels produced
by jax2tf. Hence, we need to maintain backwards compatibility for these
operators. Please reach out to the JAX team if you want to make changes.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# gather/scatter
# collapse

# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max), so we capture the builtin versions here.
# pylint: disable=redefined-builtin
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

constant = constant_op.constant

# Unary operators.

# For most arithmetic operators there is a TensorFlow operator
# that exactly corresponds to each XLA operator. Rather than defining
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.


def _unary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature."""

  def unary_op_wrapper(x, name=None):
    return fn(x, name=name)

  return unary_op_wrapper


abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for
# numbers halfway between two integers.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tanh = _unary_op(math_ops.tanh)

# Bessel
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)

# Binary operators

# The main difference between TensorFlow and XLA binary ops is the broadcasting
# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
# requires an explicit specification of which dimensions to broadcast if the
# arguments have different ranks.


def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting."""

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Inner wrapper function."""
    broadcast_dims = broadcast_dims or []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper

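# Example (illustrative sketch, not part of the original module): with
# XLA-style broadcasting, the lower-rank operand states which output dimensions
# it maps to via `broadcast_dims`. Assuming hypothetical tensors `matrix` of
# shape [2, 3] and `vec` of shape [3]:
#
#   add(matrix, vec, broadcast_dims=[1])  # maps `vec` onto dimension 1
#
# whereas NumPy-style broadcasting (e.g. tf.add) would infer the alignment from
# trailing dimensions.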

# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}


def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type."""
  assert y.dtype == x.dtype
  dtype = x.dtype
  signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
  if signed:
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
    x = math_ops.cast(x, unsigned_dtype)
    y = math_ops.cast(y, unsigned_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if signed:
    output = math_ops.cast(output, dtype)
  return output


def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type."""
  assert y.dtype == x.dtype
  dtype = x.dtype
  unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
  if unsigned:
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
    x = math_ops.cast(x, signed_dtype)
    y = math_ops.cast(y, signed_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if unsigned:
    output = math_ops.cast(output, dtype)
  return output


add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)
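# Example (illustrative, hypothetical values): for signed inputs the two right
# shifts differ in how the sign bit is treated.
#
#   shift_right_logical(constant(-1, dtypes.int8), constant(1, dtypes.int8))
#   # -> 127, since 0xff is shifted as an unsigned value.
#   shift_right_arithmetic(constant(-1, dtypes.int8), constant(1, dtypes.int8))
#   # -> -1, since the sign bit is replicated.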

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)


def _binary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature."""

  def binary_op_wrapper(x, y, name=None):
    return fn(x, y, name=name)

  return binary_op_wrapper


transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
  x = ops.convert_to_tensor(x)
  shape = array_ops.concat([constant_op.constant(dims),
                            array_ops.shape(x)],
                           axis=0)
  return array_ops.broadcast_to(x, shape, name=name)

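# Example (illustrative sketch): `broadcast` prepends `dims` to the operand's
# shape, mirroring XLA's Broadcast. For a hypothetical tensor `t` of shape
# [3, 4], broadcast(t, [2]) has shape [2, 3, 4], with `t` replicated along the
# new leading dimension.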

def clamp(a, x, b, name=None):
  return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat


def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimension
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: an `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)

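# Example (illustrative sketch; the proto import and field values below are
# assumptions for illustration, not part of this module): a 2-D NHWC
# convolution with an HWIO kernel can be described with a
# ConvolutionDimensionNumbers proto and passed to `conv`:
#
#   from tensorflow.compiler.xla import xla_data_pb2
#
#   dnums = xla_data_pb2.ConvolutionDimensionNumbers()
#   dnums.input_batch_dimension = 0
#   dnums.input_feature_dimension = 3
#   dnums.input_spatial_dimensions.extend([1, 2])
#   dnums.kernel_input_feature_dimension = 2
#   dnums.kernel_output_feature_dimension = 3
#   dnums.kernel_spatial_dimensions.extend([0, 1])
#   dnums.output_batch_dimension = 0
#   dnums.output_feature_dimension = 3
#   dnums.output_spatial_dimensions.extend([1, 2])
#
#   out = conv(lhs, rhs,
#              window_strides=[1, 1],
#              padding=[[0, 0], [0, 0]],
#              lhs_dilation=[1, 1],
#              rhs_dilation=[1, 1],
#              dimension_numbers=dnums)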

convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)


def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  """Wraps the XLA DotGeneral operator."""
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)

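# Example (illustrative sketch; the proto import is an assumption): multiplying
# a [m, k] tensor by a [k, n] tensor contracts dimension 1 of `lhs` with
# dimension 0 of `rhs`:
#
#   from tensorflow.compiler.xla import xla_data_pb2
#
#   dnums = xla_data_pb2.DotDimensionNumbers()
#   dnums.lhs_contracting_dimensions.append(1)
#   dnums.rhs_contracting_dimensions.append(0)
#   product = dot_general(lhs, rhs, dimension_numbers=dnums)
#
# Batch dimensions, if any, go in lhs_batch_dimensions/rhs_batch_dimensions.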

def self_adjoint_eig(a, lower, max_iter, epsilon):
  return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)


def svd(a, max_iter, epsilon, precision_config=None):
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
  mu = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  minval = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval, maxval, dtype=minval.dtype, name=name)

def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
    https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
      should be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    A tuple with a new state and generated data of the given shape.
  """
  alg_int = stateless_random_ops.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(alg_int, initial_state, shape,
                                           dtype=dtype)

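# Example (illustrative sketch; the tf import, shapes and dtypes are
# assumptions): THREEFRY takes a u64[2] state and returns a (new_state, bits)
# pair:
#
#   import tensorflow as tf
#
#   state = constant([0, 0], dtype=dtypes.uint64)
#   new_state, bits = rng_bit_generator(
#       tf.random.Algorithm.THREEFRY, state, shape=[4, 8], dtype=dtypes.uint32)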

recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

ops.no_gradient("XlaVariadicReduce")


def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors for the input, as a list of integers.
      Optional; if omitted, defaults to dilations of 1.
    window_dilations: dilation factors for the window, as a list of integers.
      Optional; if omitted, defaults to dilations of 1.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  window_strides = window_strides or [1] * len(window_dimensions)
  base_dilations = base_dilations or [1] * len(window_dimensions)
  window_dilations = window_dilations or [1] * len(window_dimensions)
  padding = padding or [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)

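# Example (illustrative sketch; `operand`, `init` and `reducer` are
# hypothetical): with `operand` of shape [4, 6], a scalar `init`, and a
# `reducer` computation that combines two scalars (e.g. returns their maximum),
#
#   reduce_window(operand, init, reducer,
#                 window_dimensions=[2, 3], window_strides=[2, 3])
#
# reduces non-overlapping 2x3 windows and yields a [2, 2] result; strides,
# dilations and padding default to all-ones and (0, 0) pairs respectively.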

replica_id = gen_xla_ops.xla_replica_id

# Sets a static bound for the given input value as a hint to the XLA compiler
# and returns the same value.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3) # Tells XLA the constraint that p <= 3.
#   return t[:p]            # XLA knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound


# Makes a static dimension into an XLA bounded dynamic dimension. The current
# static dimension size becomes the bound, and the second operand becomes the
# dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([1, 2, 3, 4, 5])
#   # Tells XLA that the valid size of the array is 3.
#   dim = 0
#   p = xla.set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6) # XLA knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size


# Inverse of xla_set_dynamic_dimension_size. Makes an XLA bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size


def reshape(x, new_sizes, dimensions=None, name=None):
  if dimensions is not None:
    x = array_ops.transpose(x, dimensions)
  x = array_ops.reshape(x, new_sizes, name=name)
  return x

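# Example (illustrative): `dimensions` first permutes the operand, mirroring
# XLA's Reshape with a dimensions argument. For a hypothetical tensor `t` of
# shape [2, 3], reshape(t, [6], dimensions=[1, 0]) transposes to [3, 2] before
# flattening, i.e. it reads the elements of `t` in column-major order.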

def select(condition, x, y, name=None):
  return array_ops.where(condition, x, y, name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send


def slice(x, start_dims, limit_dims, strides):
  spec = [
      _slice(start, limit, stride)
      for (start, limit, stride) in zip(start_dims, limit_dims, strides)
  ]
  return x[tuple(spec)]

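# Example (illustrative): as in XLA's Slice, each dimension gets a
# (start, limit, stride) triple. For a hypothetical tensor `t` of shape [5, 6],
# slice(t, [1, 0], [4, 6], [1, 2]) is equivalent to t[1:4, 0:6:2] and has
# shape [3, 3].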

sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  sharding_attr = op.get_attr("sharding")
  grad_sharding = gen_xla_ops.xla_sharding(grad, sharding=sharding_attr)
  # pylint: disable=protected-access
  grad_sharding.op._set_attr("_XlaSharding",
                             attr_value_pb2.AttrValue(s=sharding_attr))
  return [grad_sharding]


spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  s2f = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad,
      manual_sharding=op.get_attr("manual_sharding"),
      full_shape=op.inputs[0].shape.as_list())
  return [s2f]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  f2s = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad, manual_sharding=op.get_attr("manual_sharding"))
  return [f2s]


sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize


def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=dimension_numbers.SerializeToString(),
      indices_are_sorted=indices_are_sorted,
      name=name)

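# Example (illustrative sketch; the proto import and field values are
# assumptions): gathering rows of a [5, 4] operand with a [3] vector of row
# indices, similar to tf.gather along axis 0:
#
#   from tensorflow.compiler.xla import xla_data_pb2
#
#   dnums = xla_data_pb2.GatherDimensionNumbers()
#   dnums.offset_dims.append(1)
#   dnums.collapsed_slice_dims.append(0)
#   dnums.start_index_map.append(0)
#   dnums.index_vector_dim = 1
#   rows = gather(operand, start_indices, dnums, slice_sizes=[1, 4])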

def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=dimension_numbers.SerializeToString(),
      indices_are_sorted=indices_are_sorted,
      name=name)