# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental library that exposes XLA operations directly in TensorFlow.

It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics of
HLO operators as closely as possible.

Note: Most of the operators defined in this module are used by the jax2tf
converter (see go/jax2tf for details) and are used in SavedModel produced
by jax2tf. Hence, we need to maintain backwards compatibility for these
operators. Please reach out to the JAX team if you want to make changes.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# gather/scatter
# collapse

# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max), so we capture the builtin versions here. The captures must
# happen before the module-level names `max`, `min` and `slice` are rebound
# further down in this file.
# pylint: disable=redefined-builtin
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

# Alias of `constant_op.constant`, exposed under this module's namespace.
constant = constant_op.constant

# Unary operators.

# For most arithmetic operators there is a TensorFlow operator
# that exactly corresponds to each XLA operator. Rather than defining
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.
71 72 73def _unary_op(fn): 74 """Wrapper that restricts `fn` to have the correct signature.""" 75 76 def unary_op_wrapper(x, name=None): 77 return fn(x, name=name) 78 79 return unary_op_wrapper 80 81 82abs = _unary_op(math_ops.abs) 83# TODO(phawkins): implement clz. 84conj = _unary_op(math_ops.conj) 85cos = _unary_op(math_ops.cos) 86ceil = _unary_op(math_ops.ceil) 87digamma = _unary_op(math_ops.digamma) 88erf = _unary_op(math_ops.erf) 89erfc = _unary_op(math_ops.erfc) 90erfinv = _unary_op(math_ops.erfinv) 91ndtri = _unary_op(math_ops.ndtri) 92exp = _unary_op(math_ops.exp) 93expm1 = _unary_op(math_ops.expm1) 94floor = _unary_op(math_ops.floor) 95imag = _unary_op(math_ops.imag) 96is_finite = _unary_op(math_ops.is_finite) 97lgamma = _unary_op(math_ops.lgamma) 98log = _unary_op(math_ops.log) 99log1p = _unary_op(math_ops.log1p) 100logical_not = _unary_op(math_ops.logical_not) 101neg = _unary_op(math_ops.neg) 102real = _unary_op(math_ops.real) 103# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for 104# numbers halfway between two integers. 105round = _unary_op(math_ops.round) 106sin = _unary_op(math_ops.sin) 107sign = _unary_op(math_ops.sign) 108tanh = _unary_op(math_ops.tanh) 109 110# Bessel 111bessel_i0e = _unary_op(special_math_ops.bessel_i0e) 112bessel_i1e = _unary_op(special_math_ops.bessel_i1e) 113 114# Binary operators 115 116# The main difference between TensorFlow and XLA binary ops is the broadcasting 117# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA 118# requires an explicit specification of which dimensions to broadcast if the 119# arguments have different ranks. 


def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a callable accepting two operands and a `name` keyword argument.

  Returns:
    A function `(x, y, broadcast_dims=None, name=None)` that first passes the
    operands through the XlaBroadcastHelper op and then applies `fn` to the
    results.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Broadcasts `x` and `y` XLA-style, then applies the wrapped operator."""
    # Compare against None instead of relying on truthiness: a NumPy array
    # such as np.array([0]) is falsy and would silently be replaced by [],
    # while longer arrays and graph tensors cannot be used as booleans at all.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper


# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}


def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type.

  Args:
    x: the tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift operator.

  Returns:
    The shifted tensor, with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  dtype = x.dtype
  signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
  if signed:
    # Shift in the matching unsigned type, then cast the result back below.
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
    x = math_ops.cast(x, unsigned_dtype)
    y = math_ops.cast(y, unsigned_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if signed:
    output = math_ops.cast(output, dtype)
  return output


def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type.

  Args:
    x: the tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift operator.

  Returns:
    The shifted tensor, with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  dtype = x.dtype
  unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
  if unsigned:
    # Shift in the matching signed type, then cast the result back below.
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
    x = math_ops.cast(x, signed_dtype)
    y = math_ops.cast(y, signed_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if unsigned:
    output = math_ops.cast(output, dtype)
  return output


add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)


def _binary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature.

  Unlike `_broadcasting_binary_op`, no broadcasting helper is inserted; the
  two operands are forwarded to `fn` unchanged.

  Args:
    fn: a callable accepting two operands and a `name` keyword argument.

  Returns:
    A function `(x, y, name=None)` that forwards its arguments to `fn`.
  """

  def binary_op_wrapper(x, y, name=None):
    return fn(x, y, name=name)

  return binary_op_wrapper


transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
  """Broadcasts `x` by adding `dims` as new leading dimensions.

  Args:
    x: the value to broadcast; converted to a tensor.
    dims: sizes of the dimensions to add in front of the shape of `x`.
    name: an optional name for the operator.

  Returns:
    A tensor whose shape is `dims` followed by the shape of `x`.
  """
  x = ops.convert_to_tensor(x)
  # Pin the dtype: an empty `dims` would otherwise become a float32 constant,
  # which cannot be concatenated with the int32 shape of `x`.
  shape = array_ops.concat(
      [constant_op.constant(dims, dtype=dtypes.int32),
       array_ops.shape(x)],
      axis=0)
  return array_ops.broadcast_to(x, shape, name=name)


def clamp(a, x, b, name=None):
  """Clamps `x` to lie between the lower bound `a` and the upper bound `b`.

  Computed as `min(max(a, x), b)` using this module's broadcasting `max` and
  `min` wrappers (not the Python builtins).

  Args:
    a: the lower bound.
    x: the value to clamp.
    b: the upper bound.
    name: an optional name, passed to both the `max` and the `min` operator.

  Returns:
    The clamped tensor.
  """
  return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat


def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  The XlaConvV2 op is used when `preferred_element_type` is given, when the
  operand dtypes differ, or when `use_v2` is set; otherwise XlaConv is used.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # Use an explicit None test so that this agrees with the default handling
  # just below: any caller-supplied result type must select the V2 op, since
  # the V1 op has no `preferred_element_type` argument and would drop it.
  needs_v2 = (preferred_element_type is not None) or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)


convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  """Contracts the last axis of `lhs` with the first axis of `rhs`.

  Implemented as `tensordot` with `axes=1`.

  Args:
    lhs: the left operand.
    rhs: the right operand.
    name: an optional name for the operator.

  Returns:
    The result of the contraction.
  """
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)


def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  """Wraps the XlaDot / XlaDotV2 operators.

  The XlaDotV2 op is used when `preferred_element_type` is given, when the
  operand dtypes differ, or when `use_v2` is set; otherwise XlaDot is used.

  Args:
    lhs: the left operand tensor.
    rhs: the right operand tensor.
    dimension_numbers: a proto message describing the contraction; it is
      passed to the op in serialized form. (Presumably an
      `xla.DotDimensionNumbers` proto -- confirm against callers.)
    precision_config: an optional proto message, passed in serialized form;
      an empty string is passed when omitted.
    preferred_element_type: the result `dtype`. If omitted, the result type of
      the two operand dtypes is used.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaDotV2 op even if not necessary.

  Returns:
    A tensor representing the output of the dot operation.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  # Explicit None test, consistent with the default handling just below; the
  # V1 op has no `preferred_element_type` argument.
  needs_v2 = (preferred_element_type is not None) or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)


def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Wraps the XlaSelfAdjointEig op; all arguments are forwarded unchanged."""
  return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)


def svd(a, max_iter, epsilon, precision_config=None):
  """Wraps the XlaSvd op.

  Args:
    a: the input tensor.
    max_iter: forwarded to the op unchanged.
    epsilon: forwarded to the op unchanged.
    precision_config: an optional proto message, passed in serialized form; an
      empty string is passed when omitted.

  Returns:
    The outputs of the XlaSvd op.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
  """Draws normally distributed values of shape `dims`.

  Args:
    mu: the mean; converted to a tensor, whose dtype becomes the output dtype.
    sigma: the standard deviation.
    dims: the output shape.
    name: an optional name for the operator.

  Returns:
    A tensor of shape `dims` with the dtype of `mu`.
  """
  mu = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  """Draws uniformly distributed values of shape `dims`.

  Args:
    minval: the lower bound; converted to a tensor, whose dtype becomes the
      output dtype.
    maxval: the upper bound.
    dims: the output shape.
    name: an optional name for the operator.

  Returns:
    A tensor of shape `dims` with the dtype of `minval`.
  """
  minval = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval, maxval, dtype=minval.dtype, name=name)


def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
    https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
      should be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    a tuple with a new state and generated data of the given shape.
  """
  alg_int = stateless_random_ops.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(alg_int, initial_state, shape,
                                           dtype=dtype)


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

# NOTE(review): `variadic_reduce` is bound to the V2 wrapper above, while the
# registration below names the op type "XlaVariadicReduce" -- confirm whether
# the V2 op type needs its own registration.
ops.no_gradient("XlaVariadicReduce")


def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: base dilation factors, as a list of integers. Optional; if
      omitted, defaults to dilations of 1.
    window_dilations: window dilation factors, as a list of integers.
      Optional; if omitted, defaults to dilations of 1.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  # Every optional argument defaults to one entry per window dimension.
  window_strides = window_strides or [1] * len(window_dimensions)
  base_dilations = base_dilations or [1] * len(window_dimensions)
  window_dilations = window_dilations or [1] * len(window_dimensions)
  padding = padding or [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)


replica_id = gen_xla_ops.xla_replica_id

# Set a static bound for the given input value as a hint to Xla compiler,
# returns the same value.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3.
#   return t[:p]            # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound


# Make a static dimension into a xla bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes the
# dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
#   # Tells xla the valid size of the array is 3.
#   dim = 1  # The dimension of size 5 in the [1, 5] array above.
#   p = xla.set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6) # xla knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size


# Inverse of xla_set_dynamic_dimension_size. Make an xla bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size


def reshape(x, new_sizes, dimensions=None, name=None):
  """Reshapes `x` to `new_sizes`, optionally transposing it first.

  Args:
    x: the tensor to reshape.
    new_sizes: the target shape.
    dimensions: optional permutation applied to `x` before reshaping.
    name: an optional name for the reshape operator.

  Returns:
    The reshaped tensor.
  """
  permuted = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(permuted, new_sizes, name=name)


def select(condition, x, y, name=None):
  """Chooses between `x` and `y` according to `condition`.

  Thin wrapper around `array_ops.where`.
  """
  selected = array_ops.where(condition, x, y, name)
  return selected


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send


def slice(x, start_dims, limit_dims, strides):
  """Takes a strided slice of `x`.

  Args:
    x: the value to index.
    start_dims: per-dimension start indices.
    limit_dims: per-dimension limit indices.
    strides: per-dimension strides.

  Returns:
    `x` indexed with one `start:limit:stride` slice per dimension.
  """
  # `_slice` is the captured Python builtin; `map` stops at the shortest of
  # the three sequences, exactly like `zip`.
  index = tuple(map(_slice, start_dims, limit_dims, strides))
  return x[index]


sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding: applies the forward op's sharding to `grad`."""
  forward_sharding = op.get_attr("sharding")
  sharded_grad = gen_xla_ops.xla_sharding(grad, sharding=forward_sharding)
  attr_value = attr_value_pb2.AttrValue(s=forward_sharding)
  # pylint: disable=protected-access
  sharded_grad.op._set_attr("_XlaSharding", attr_value)
  return [sharded_grad]


spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient for XlaSpmdFullToShardShape: maps `grad` back to full shape."""
  manual_sharding = op.get_attr("manual_sharding")
  # The full shape is the static shape of the forward op's input.
  full_shape = op.inputs[0].shape.as_list()
  return [
      gen_xla_ops.xla_spmd_shard_to_full_shape(
          grad, manual_sharding=manual_sharding, full_shape=full_shape)
  ]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient for XlaSpmdShardToFullShape: maps `grad` to shard shape."""
  manual_sharding = op.get_attr("manual_sharding")
  return [
      gen_xla_ops.xla_spmd_full_to_shard_shape(
          grad, manual_sharding=manual_sharding)
  ]


sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize


def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Wraps the XlaGather op.

  Args:
    operand: the tensor to gather from.
    start_indices: the indices at which to gather.
    dimension_numbers: a proto message, passed to the op in serialized form.
    slice_sizes: forwarded to the op as `slice_sizes`.
    indices_are_sorted: forwarded to the op as `indices_are_sorted`.
    name: an optional name for the operator.

  Returns:
    The output of the XlaGather op.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)


def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Wraps the XlaScatter op.

  Args:
    operand: the tensor to scatter into.
    scatter_indices: the indices at which to scatter.
    updates: the values to scatter.
    update_computation: forwarded to the op as `update_computation`.
    dimension_numbers: a proto message, passed to the op in serialized form.
    indices_are_sorted: forwarded to the op as `indices_are_sorted`.
    name: an optional name for the operator.

  Returns:
    The output of the XlaScatter op.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)