• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Experimental library that exposes XLA operations directly in TensorFlow.
16
17It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics of
19HLO operators as closely as possible.
20
21Note: There is no promise of backward or forward compatibility for operators
22defined in this module. This is primarily because the underlying HLO operators
23do not promise backward or forward compatibility.
24"""
25
26from __future__ import absolute_import
27from __future__ import division
28from __future__ import print_function
29
30from tensorflow.compiler.tf2xla.ops import gen_xla_ops
31from tensorflow.python.framework import constant_op
32from tensorflow.python.framework import dtypes
33from tensorflow.python.framework import ops
34from tensorflow.python.ops import array_ops
35from tensorflow.python.ops import bitwise_ops
36from tensorflow.python.ops import gen_math_ops
37from tensorflow.python.ops import math_ops
38from tensorflow.python.ops import random_ops
39
# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# collapse
47
# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max), so we capture the builtin versions here.
# pylint: disable=redefined-builtin
# These private aliases must be bound before the module-level `max`, `min` and
# `slice` definitions further down rebind the builtin names.
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

# Re-export of TensorFlow's constant constructor under the name `constant`.
constant = constant_op.constant
56
57# Unary operators.
58
59# For most arithmetic operators there is a TensorFlow operator
60# that exactly corresponds to each XLA operator. Rather than defining
61# XLA-specific variants, we reuse the corresponding TensorFlow operator.
62# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
63# wrap every HLO operator, because that would allow us to be confident that the
64# semantics match.
65
66
67def _unary_op(fn):
68  """Wrapper that restricts `fn` to have the correct signature."""
69
70  def unary_op_wrapper(x, name=None):
71    return fn(x, name=name)
72
73  return unary_op_wrapper
74
75
# Each name below is the TensorFlow operator of the same (or equivalent) name,
# restricted by `_unary_op` to the signature `(x, name=None)`.
abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, this rounds to even instead of away from
# zero for numbers halfway between two integers.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tanh = _unary_op(math_ops.tanh)

# Bessel
bessel_i0e = _unary_op(math_ops.bessel_i0e)
bessel_i1e = _unary_op(math_ops.bessel_i1e)
107
108# Binary operators
109
110# The main difference between TensorFlow and XLA binary ops is the broadcasting
111# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
112# requires an explicit specification of which dimensions to broadcast if the
113# arguments have different ranks.
114
115
def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a callable that is invoked as `fn(x, y, name=name)` on the operands
      returned by the XlaBroadcastHelper op.

  Returns:
    A function with the signature `(x, y, broadcast_dims=None, name=None)`.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Broadcasts `x` and `y` XLA-style, then applies the wrapped operator.

    Args:
      x: the left-hand operand.
      y: the right-hand operand.
      broadcast_dims: the dimensions to broadcast along; may be a Python
        sequence, a NumPy array or a tensor. `None` is treated as an empty
        list.
      name: an optional name, forwarded to the wrapped operator.

    Returns:
      The result of the wrapped operator on the broadcast operands.
    """
    # Compare against None explicitly. `broadcast_dims or []` would evaluate
    # the truth value of the argument, which raises for multi-element NumPy
    # arrays and for graph tensors. Empty sequences still convert to the same
    # empty int64 tensor as before.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper
130
131
# Signed TF integer types paired with the unsigned type of the same width.
_SIGNED_TO_UNSIGNED_TABLE = dict((
    (dtypes.int8, dtypes.uint8),
    (dtypes.int16, dtypes.uint16),
    (dtypes.int32, dtypes.uint32),
    (dtypes.int64, dtypes.uint64),
))

# The reverse pairing (unsigned -> signed), derived from the table above.
_UNSIGNED_TO_SIGNED_TABLE = {
    unsigned: signed for signed, unsigned in _SIGNED_TO_UNSIGNED_TABLE.items()
}
147
148
def _shift_right_logical_helper(x, y, name=None):
  """Right-shifts `x` by `y` logically, whatever the integer input type.

  Operands whose dtype is a key of `_SIGNED_TO_UNSIGNED_TABLE` are cast to the
  mapped unsigned dtype for the shift and the result is cast back.

  Args:
    x: the value to shift.
    y: the shift amount; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    The shifted value, with the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE.get(original_dtype)
  if unsigned_dtype is None:
    # Not a signed integer type from the table: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)
  unsigned_x = math_ops.cast(x, unsigned_dtype)
  unsigned_y = math_ops.cast(y, unsigned_dtype)
  shifted = bitwise_ops.right_shift(unsigned_x, unsigned_y, name=name)
  return math_ops.cast(shifted, original_dtype)
162
163
def _shift_right_arithmetic_helper(x, y, name=None):
  """Right-shifts `x` by `y` arithmetically, whatever the integer input type.

  Operands whose dtype is a key of `_UNSIGNED_TO_SIGNED_TABLE` are cast to the
  mapped signed dtype for the shift and the result is cast back.

  Args:
    x: the value to shift.
    y: the shift amount; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    The shifted value, with the dtype of `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  signed_dtype = _UNSIGNED_TO_SIGNED_TABLE.get(original_dtype)
  if signed_dtype is None:
    # Not an unsigned integer type from the table: shift the operands as is.
    return bitwise_ops.right_shift(x, y, name=name)
  signed_x = math_ops.cast(x, signed_dtype)
  signed_y = math_ops.cast(y, signed_dtype)
  shifted = bitwise_ops.right_shift(signed_x, signed_y, name=name)
  return math_ops.cast(shifted, original_dtype)
177
178
# Every name below takes `(x, y, broadcast_dims=None, name=None)`; see
# `_broadcasting_binary_op`.

# Arithmetic.
add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
# Note: `max` and `min` rebind the builtin names from here on; the builtins
# remain reachable as `_max` and `_min`.
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
# Logical.
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
# Comparisons.
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
# Shifts. The right shifts go through module-private helpers that cast the
# operands between signed and unsigned types as needed.
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

# Incomplete gamma functions.
igamma = _broadcasting_binary_op(math_ops.igamma)
igammac = _broadcasting_binary_op(math_ops.igammac)
204
205
206def _binary_op(fn):
207  """Wrapper that restricts `fn` to have the correct signature."""
208
209  def binary_op_wrapper(x, y, name=None):
210    return fn(x, y, name=name)
211
212  return binary_op_wrapper
213
214
# `_binary_op` pins both wrappers to the signature `(x, y, name=None)`.
transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

# Direct alias: the TensorFlow function is exposed without a wrapper.
bitcast_convert_type = array_ops.bitcast
219
220
def broadcast(x, dims, name=None):
  """Broadcasts `x` to a shape with `dims` prepended to its own shape.

  Args:
    x: the value to broadcast; converted to a tensor.
    dims: the sizes of the new leading dimensions, as a sequence of integers.
      May be empty, in which case the target shape equals the shape of `x`.
    name: an optional name for the broadcast op.

  Returns:
    A tensor of shape `dims + shape(x)`.
  """
  x = ops.convert_to_tensor(x)
  # Pin the dtype to int32 so it always matches `array_ops.shape(x)`. Without
  # it an empty `dims` is inferred as float32 and the concat below fails.
  leading_dims = constant_op.constant(dims, dtype=dtypes.int32)
  shape = array_ops.concat([leading_dims, array_ops.shape(x)], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)
227
228
def clamp(a, x, b, name=None):
  """Computes `min(max(a, x), b)` with this module's `max` and `min` wrappers.

  Args:
    a: the lower bound.
    x: the value to clamp.
    b: the upper bound.
    name: an optional name, passed to both the `max` and the `min` op.

  Returns:
    The clamped value.
  """
  # `max` and `min` are the module-level XLA wrappers, not the Python builtins.
  lower_bounded = max(a, x, name=name)
  return min(lower_bounded, b, name=name)
231
232
# Direct alias: the TensorFlow concat function is exposed without a wrapper.
concatenate = array_ops.concat
234
235
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         name=None):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimension
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto; passed to the op
      in serialized form.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto; passed to the op in
      serialized form, or as an empty string when omitted.
    name: an optional name for the operator

  Returns:
    A tensor representing the output of the convolution.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=serialized_dimension_numbers,
      precision_config=serialized_precision,
      name=name)
281
282
# Direct alias: the TensorFlow cast function is exposed without a wrapper.
convert_element_type = math_ops.cast
284
285
def dot(lhs, rhs, name=None):
  """Computes `math_ops.tensordot(lhs, rhs, axes=1)`.

  Args:
    lhs: the left-hand operand.
    rhs: the right-hand operand.
    name: an optional name for the operator.

  Returns:
    The tensordot result.
  """
  product = math_ops.tensordot(lhs, rhs, axes=1, name=name)
  return product
288
289
def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
  """Wraps the XlaDot op.

  Args:
    lhs: the left-hand operand.
    rhs: the right-hand operand.
    dimension_numbers: a proto describing the dot dimensions; passed to the op
      in serialized form.
    precision_config: an optional proto; passed to the op in serialized form,
      or as an empty string when omitted.
    name: an optional name for the operator.

  Returns:
    The output of the XlaDot op.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=serialized_dimension_numbers,
      precision_config=serialized_precision,
      name=name)
300
301
def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Wraps the XlaSelfAdjointEig op; all arguments are forwarded positionally.

  Args:
    a: the input tensor.
    lower: forwarded to the op unchanged (presumably selects which triangle of
      `a` is read; confirm against the op definition).
    max_iter: forwarded to the op unchanged.
    epsilon: forwarded to the op unchanged.

  Returns:
    Whatever the XlaSelfAdjointEig op returns.
  """
  eig_outputs = gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
  return eig_outputs
304
305
def svd(a, max_iter, epsilon, precision_config=None):
  """Wraps the XlaSvd op.

  Args:
    a: the input tensor.
    max_iter: forwarded to the op unchanged.
    epsilon: forwarded to the op unchanged.
    precision_config: an optional proto; passed to the op in serialized form,
      or as an empty string when omitted.

  Returns:
    Whatever the XlaSvd op returns.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else "")
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, serialized_precision)
311
312
# Direct aliases of the generated XLA op wrappers.
dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad
320
321
def random_normal(mu, sigma, dims, name=None):
  """Draws normally distributed values of shape `dims`.

  Args:
    mu: the mean; converted to a tensor, whose dtype becomes the output dtype.
    sigma: the standard deviation.
    dims: the shape of the output.
    name: an optional name for the operator.

  Returns:
    The result of `random_ops.random_normal`.
  """
  mean = ops.convert_to_tensor(mu)
  output_dtype = mean.dtype
  return random_ops.random_normal(
      dims, mean=mean, stddev=sigma, dtype=output_dtype, name=name)
326
327
def random_uniform(minval, maxval, dims, name=None):
  """Draws uniformly distributed values of shape `dims`.

  Args:
    minval: the lower bound; converted to a tensor, whose dtype becomes the
      output dtype.
    maxval: the upper bound.
    dims: the shape of the output.
    name: an optional name for the operator.

  Returns:
    The result of `random_ops.random_uniform`.
  """
  lower = ops.convert_to_tensor(minval)
  output_dtype = lower.dtype
  return random_ops.random_uniform(
      dims, lower, maxval, dtype=output_dtype, name=name)
332
333
# Direct aliases of the generated XLA op wrappers. Note that `reduce` shadows
# the Python 2 builtin of the same name within this module.
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
336
337
def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: forwarded as the op's `base_dilations`. Optional; if
      omitted, defaults to 1 for every window dimension.
    window_dilations: forwarded as the op's `window_dilations`. Optional; if
      omitted, defaults to 1 for every window dimension.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  # Defaults are built lazily: `len(window_dimensions)` is only evaluated for
  # arguments that were omitted (or are otherwise falsy).
  if not window_strides:
    window_strides = [1] * len(window_dimensions)
  if not base_dilations:
    base_dilations = [1] * len(window_dimensions)
  if not window_dilations:
    window_dilations = [1] * len(window_dimensions)
  if not padding:
    padding = [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      computation=reducer,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      name=name)
381
382
# Direct alias of the generated XLA op wrapper.
replica_id = gen_xla_ops.xla_replica_id
384
385
def reshape(x, new_sizes, dimensions=None, name=None):
  """Reshapes `x` to `new_sizes`, optionally transposing it first.

  Args:
    x: the tensor to reshape.
    new_sizes: the target shape.
    dimensions: an optional permutation; when not None, `x` is transposed by it
      before the reshape.
    name: an optional name for the reshape op.

  Returns:
    The reshaped tensor.
  """
  operand = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(operand, new_sizes, name=name)
391
392
def select(condition, x, y, name=None):
  """Forwards to `array_ops.where(condition, x, y)`.

  Args:
    condition: the predicate tensor.
    x: the first value operand passed to `where`.
    y: the second value operand passed to `where`.
    name: an optional name for the operator.

  Returns:
    The result of `array_ops.where`.
  """
  selected = array_ops.where(condition, x, y, name)
  return selected
395
396
# Direct aliases of the generated XLA op wrappers.
select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
399
400
def slice(x, start_dims, limit_dims, strides):
  """Slices `x` with one `start:limit:stride` range per dimension.

  Args:
    x: the value to slice; indexed with a tuple of Python slice objects.
    start_dims: the start index for each dimension.
    limit_dims: the limit (exclusive end) index for each dimension.
    strides: the stride for each dimension.

  Returns:
    The result of indexing `x` with the per-dimension slices.
  """
  # `_slice` is the Python builtin, captured before this function shadowed it.
  per_dimension = zip(start_dims, limit_dims, strides)
  index = tuple(_slice(*bounds) for bounds in per_dimension)
  return x[index]
407
408
# Direct alias of the generated XLA op wrapper. The gradient for the underlying
# "XlaSharding" op is registered in this module.
sharding = gen_xla_ops.xla_sharding
410
411
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for the "XlaSharding" op: passes `grad` through unchanged.

  Args:
    op: the forward op; not used.
    grad: the incoming gradient.

  Returns:
    A single-element list containing `grad`.
  """
  del op  # The forward op is not needed to compute this gradient.
  input_grads = [grad]
  return input_grads
416
417
# Direct aliases of the generated XLA op wrappers.
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
422
423
def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Wraps the XlaGather op.

  Args:
    operand: the tensor to gather from.
    start_indices: the indices at which the gathered slices start.
    dimension_numbers: a proto describing the gather dimensions; passed to the
      op in serialized form.
    slice_sizes: the size of each gathered slice.
    indices_are_sorted: forwarded to the op unchanged.
    name: an optional name for the operator.

  Returns:
    The output of the XlaGather op.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
433
434
def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Wraps the XlaScatter op.

  Args:
    operand: the tensor to scatter into.
    scatter_indices: the indices at which the updates are applied.
    updates: the values to scatter.
    update_computation: the function passed to the op as `update_computation`.
    dimension_numbers: a proto describing the scatter dimensions; passed to the
      op in serialized form.
    indices_are_sorted: forwarded to the op unchanged.
    name: an optional name for the operator.

  Returns:
    The output of the XlaScatter op.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)
445