# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental library that exposes XLA operations directly in TensorFlow.

It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics of
HLO operators as closely as possible.

Note: Most of the operators defined in this module are used by the jax2tf
converter (see go/jax2tf for details) and are used in SavedModels produced
by jax2tf. Hence, we need to maintain backwards compatibility for these
operators. Please reach out to the JAX team if you want to make changes.
"""

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import stateless_random_ops
from tensorflow.python.ops.numpy_ops import np_utils

# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# collapse

# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max), so we capture the builtin versions here. Code later in this
# module that needs the Python builtins must use these underscore aliases.
# pylint: disable=redefined-builtin
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

# Alias of `constant_op.constant`, exposed as `xla.constant`.
constant = constant_op.constant

# Unary operators.

# For most arithmetic operators there is a TensorFlow operator
# that exactly corresponds to each XLA operator. Rather than defining
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.
def _unary_op(op):
  """Adapts a TensorFlow unary op to the XLA-style `(x, name=None)` signature.

  Args:
    op: a callable taking an operand plus a `name` keyword argument.

  Returns:
    A function accepting exactly `x` and an optional `name`, both of which are
    forwarded to `op`.
  """

  def unary_op_wrapper(x, name=None):
    result = op(x, name=name)
    return result

  return unary_op_wrapper


# Elementwise arithmetic / transcendental operators.
abs = _unary_op(math_ops.abs)
ceil = _unary_op(math_ops.ceil)
cos = _unary_op(math_ops.cos)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
ndtri = _unary_op(math_ops.ndtri)
neg = _unary_op(math_ops.neg)
# TODO(phawkins): math_ops.round rounds values lying exactly halfway between two
# integers to the nearest even integer, whereas xla::Round rounds such ties
# away from zero.
round = _unary_op(math_ops.round)
sign = _unary_op(math_ops.sign)
sin = _unary_op(math_ops.sin)
tanh = _unary_op(math_ops.tanh)
# TODO(phawkins): implement clz.

# Complex-number components.
conj = _unary_op(math_ops.conj)
imag = _unary_op(math_ops.imag)
real = _unary_op(math_ops.real)

# Predicates and logical operators.
is_finite = _unary_op(math_ops.is_finite)
logical_not = _unary_op(math_ops.logical_not)

# Bessel
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)

# Binary operators

# The main difference between TensorFlow and XLA binary ops is the broadcasting
# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
# requires an explicit specification of which dimensions to broadcast if the
# arguments have different ranks.
def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a callable `fn(x, y, name=None)` implementing the elementwise op.

  Returns:
    A function `(x, y, broadcast_dims=None, name=None)`. `broadcast_dims` lists
    the dimensions to broadcast along when the operands have different ranks
    (XLA's explicit-broadcast convention). It is converted to an int64 tensor
    and handed to the XlaBroadcastHelper op. `None` means "no explicit
    dimensions".
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Inner wrapper function."""
    # Compare against None instead of relying on truthiness. `dims or []`
    # raises for multi-element numpy arrays and symbolic tensors. It also
    # silently discards single-element zero-valued inputs such as
    # `np.array([0])`.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper


# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}


def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type.

  Signed inputs are cast to the unsigned type of the same width
  (`_SIGNED_TO_UNSIGNED_TABLE`) before `bitwise_ops.right_shift`, and the
  result is cast back to the original dtype. Unsigned inputs are shifted
  directly.

  Args:
    x: integer tensor to shift.
    y: shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    A tensor with the same dtype as `x`.
  """
  # Internal invariant: callers pass operands of matching dtype.
  assert y.dtype == x.dtype
  dtype = x.dtype
  signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
  if signed:
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
    x = math_ops.cast(x, unsigned_dtype)
    y = math_ops.cast(y, unsigned_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if signed:
    output = math_ops.cast(output, dtype)
  return output


def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type.

  Unsigned inputs are cast to the signed type of the same width
  (`_UNSIGNED_TO_SIGNED_TABLE`) before `bitwise_ops.right_shift`, and the
  result is cast back to the original dtype. Signed inputs are shifted
  directly.

  Args:
    x: integer tensor to shift.
    y: shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift op.

  Returns:
    A tensor with the same dtype as `x`.
  """
  # Internal invariant: callers pass operands of matching dtype.
  assert y.dtype == x.dtype
  dtype = x.dtype
  unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
  if unsigned:
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
    x = math_ops.cast(x, signed_dtype)
    y = math_ops.cast(y, signed_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if unsigned:
    output = math_ops.cast(output, dtype)
  return output


add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
# NOTE(review): math_ops.div uses Python-2 division semantics, which floor-divide
# integer inputs, whereas XLA's Div truncates toward zero. Confirm whether
# callers depend on this before changing it (jax2tf backwards compatibility).
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)


def _binary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature.

  Args:
    fn: a callable `fn(x, y, name=None)`.

  Returns:
    A function `(x, y, name=None)` that forwards to `fn` without any
    XLA-style broadcasting.
  """

  def binary_op_wrapper(x, y, name=None):
    return fn(x, y, name=name)

  return binary_op_wrapper
transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
  """Wraps the XLA Broadcast operator by prepending new leading dimensions.

  Args:
    x: the operand; converted to a tensor.
    dims: sizes of the new leading dimensions (e.g. a list of ints). May be
      empty, in which case the result has the same shape as `x`.
    name: an optional name for the operator.

  Returns:
    A tensor of shape `dims + shape(x)`.
  """
  x = ops.convert_to_tensor(x)
  # `constant_op.constant([])` infers float32, which cannot be concatenated
  # with the int32 `shape(x)` below. Hinting int32 keeps empty `dims` (and
  # plain Python ints) integral, and also accepts an existing tensor.
  dims = ops.convert_to_tensor(dims, dtype_hint=dtypes.int32)
  shape = array_ops.concat([dims, array_ops.shape(x)], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)


def clamp(a, x, b, name=None):
  """Clamps `x` elementwise to `[a, b]` (XLA `Clamp(min, operand, max)`).

  Computed as `min(max(a, x), b)` using this module's XLA-style broadcasting
  `max`/`min` wrappers, not the Python builtins.

  Args:
    a: lower bound.
    x: operand to clamp.
    b: upper bound.
    name: an optional name, applied to both the max and min ops.

  Returns:
    The clamped tensor.
  """
  return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat


def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         preferred_element_type=None,
         name=None,
         use_v2=False,
         batch_group_count=1):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  The XlaConvV2 op is used whenever `preferred_element_type` is given, the
  operand dtypes differ, `batch_group_count > 1`, or `use_v2` is set;
  otherwise the original XlaConv op is emitted.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimension
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. If None, the result type of
      `lhs.dtype` and `rhs.dtype` is used.
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  needs_v2 = (
      preferred_element_type or (lhs.dtype != rhs.dtype) or
      batch_group_count > 1)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_conv_v2(
        lhs,
        rhs,
        window_strides=window_strides,
        padding=padding,
        lhs_dilation=lhs_dilation,
        rhs_dilation=rhs_dilation,
        feature_group_count=feature_group_count,
        batch_group_count=batch_group_count,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)


convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  """Wraps XLA Dot via `math_ops.tensordot(lhs, rhs, axes=1)`.

  Contracts the last axis of `lhs` with the first axis of `rhs`.
  """
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)


DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig


def dot_general(lhs,
                rhs,
                dimension_numbers,
                precision_config=None,
                preferred_element_type=None,
                name=None,
                use_v2=False):
  """Wraps the XLA DotGeneral operator.

  The XlaDotV2 op is used whenever `preferred_element_type` is given, the
  operand dtypes differ, or `use_v2` is set; otherwise XlaDot is emitted.

  Args:
    lhs: the left-hand operand.
    rhs: the right-hand operand.
    dimension_numbers: a `DotDimensionNumbers` proto (serialized with
      `SerializeToString()`).
    precision_config: an optional `PrecisionConfig` proto. When omitted, an
      empty serialized config is passed.
    preferred_element_type: the result `dtype`. If None, the result type of
      `lhs.dtype` and `rhs.dtype` is used.
    name: an optional name for the operator.
    use_v2: request the XlaDotV2 op even if not necessary.

  Returns:
    A tensor representing the output of the dot product.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
  if needs_v2 or use_v2:
    return gen_xla_ops.xla_dot_v2(
        lhs,
        rhs,
        dimension_numbers=dimension_numbers.SerializeToString(),
        precision_config=precision_config_proto,
        preferred_element_type=preferred_element_type,
        name=name)
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)


def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Forwards all arguments positionally to the XlaSelfAdjointEig op.

  NOTE(review): `lower` presumably selects which triangle of `a` is read;
  confirm against the op definition.
  """
  return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)


def svd(a, max_iter, epsilon, precision_config=None):
  """Wraps the XlaSvd op.

  Args:
    a: the input tensor.
    max_iter: forwarded to the op.
    epsilon: forwarded to the op.
    precision_config: an optional `PrecisionConfig` proto. When omitted, an
      empty serialized config is passed.

  Returns:
    The outputs of the XlaSvd op.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
  """Samples a tensor of shape `dims` from Normal(`mu`, `sigma`).

  The result dtype is the dtype of `mu` after conversion to a tensor.
  """
  mu = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  """Samples a tensor of shape `dims` uniformly from `[minval, maxval)`.

  The result dtype is the dtype of `minval` after conversion to a tensor.
  """
  minval = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval, maxval, dtype=minval.dtype, name=name)


def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of
      tf.random.Algorithm.{PHILOX, THREEFRY, AUTO_SELECT}.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it
      should be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    a tuple with a new state and generated data of the given shape.
  """
  alg_int = stateless_random_ops.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(alg_int, initial_state, shape,
                                           dtype=dtype)


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2

# NOTE(review): `variadic_reduce` is bound to the V2 op, but this registers "no
# gradient" only for the V1 op name "XlaVariadicReduce". Confirm whether
# "XlaVariadicReduceV2" needs the same treatment elsewhere.
ops.no_gradient("XlaVariadicReduce")


def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Any omitted (or otherwise falsy, e.g. empty) optional list argument is
  replaced by its default, sized to `len(window_dimensions)`.

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors applied to `operand`, as a list of
      integers. Optional; if omitted, defaults to 1 in every dimension.
    window_dilations: dilation factors applied to the window, as a list of
      integers. Optional; if omitted, defaults to 1 in every dimension.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  window_strides = window_strides or [1] * len(window_dimensions)
  base_dilations = base_dilations or [1] * len(window_dimensions)
  window_dilations = window_dilations or [1] * len(window_dimensions)
  padding = padding or [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)


replica_id = gen_xla_ops.xla_replica_id

# Set a static bound for the given input value as a hint to Xla compiler,
# returns the same value.
# Usage:
# def f(t, p):
#   p = xla.set_bound(p, 3) # Tells xla the constraint that p <= 3.
#   return t[:p]            # xla knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound


# Make a static dimension into a xla bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes the
# dynamic size of the dimension.
#
# This should mostly be used for testing.
#
# def f():
#   array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
#   # Tells xla the valid size of dimension 1 (static size 5) is 3.
#   dim = 1
#   p = xla.set_dynamic_dimension_size(array, dim, 3)
#   assert(reduce_sum(p) == 6) # xla knows only the first 3 elements are valid.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size


# Inverse of xla_set_dynamic_dimension_size. Make an xla bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size


def reshape(x, new_sizes, dimensions=None, name=None):
  """Wraps XLA Reshape.

  If `dimensions` is given, `x` is first transposed by that permutation and
  then reshaped to `new_sizes`; otherwise it is reshaped directly.
  """
  source = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(source, new_sizes, name=name)


def select(condition, x, y, name=None):
  """Elementwise select via `array_ops.where`: `x` where `condition`, else `y`.

  NOTE(review): this uses the TF1 `where`; confirm its shape rules match XLA
  Select for the operands callers pass.
  """
  return array_ops.where(condition, x, y, name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send


def slice(x, start_dims, limit_dims, strides):
  """Static strided slice (XLA Slice).

  Dimension `i` keeps indices `[start_dims[i], limit_dims[i])` with step
  `strides[i]`. Pairing stops at the shortest of the three sequences, and any
  dimensions after that are left unsliced.
  """
  # `_slice` is the builtin `slice`, captured before this function shadowed it.
  return x[tuple(
      _slice(*bounds) for bounds in zip(start_dims, limit_dims, strides))]


sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  Re-applies the forward op's `sharding` and `unspecified_dims` to the
  incoming gradient. The sharding is also copied onto the new op's
  `_XlaSharding` attribute.
  """
  sharding_attr = op.get_attr("sharding")
  unspecified_dims = op.get_attr("unspecified_dims")
  sharded_grad = gen_xla_ops.xla_sharding(
      grad, sharding=sharding_attr, unspecified_dims=unspecified_dims)
  # pylint: disable=protected-access
  sharded_grad.op._set_attr("_XlaSharding",
                            attr_value_pb2.AttrValue(s=sharding_attr))
  return [sharded_grad]


spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient of XlaSpmdFullToShardShape: maps the grad back to full shape.

  The target full shape is the static shape of the forward op's input.
  """
  manual_sharding = op.get_attr("manual_sharding")
  full_shape = op.inputs[0].shape.as_list()
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  return [
      gen_xla_ops.xla_spmd_shard_to_full_shape(
          grad,
          manual_sharding=manual_sharding,
          full_shape=full_shape,
          dim=dim,
          unspecified_dims=unspecified_dims)
  ]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient of XlaSpmdShardToFullShape: maps the grad back to shard shape."""
  manual_sharding = op.get_attr("manual_sharding")
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  return [
      gen_xla_ops.xla_spmd_full_to_shard_shape(
          grad,
          manual_sharding=manual_sharding,
          dim=dim,
          unspecified_dims=unspecified_dims)
  ]


sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call


def call_module(args, *, module, Tout, Sout, dim_args_spec=()):
  """Wraps the XlaCallModule op; all keyword arguments are forwarded as-is.

  NOTE(review): `Tout`/`Sout` presumably describe output dtypes and shapes;
  confirm against the op definition.
  """
  return gen_xla_ops.xla_call_module(
      args, module=module, dim_args_spec=dim_args_spec, Tout=Tout, Sout=Sout)


def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Wraps the XlaGather op.

  `dimension_numbers` is a proto (presumably `GatherDimensionNumbers`) and is
  passed serialized via `SerializeToString()`.
  """
  serialized_dims = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dims,
      indices_are_sorted=indices_are_sorted,
      name=name)


def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Wraps the XlaScatter op.

  `dimension_numbers` is a proto (presumably `ScatterDimensionNumbers`) and is
  passed serialized via `SerializeToString()`.
  """
  serialized_dims = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dims,
      indices_are_sorted=indices_are_sorted,
      name=name)


def optimization_barrier(*args):
  """Passes all positional `args`, as one tuple, through XlaOptimizationBarrier."""
  return gen_xla_ops.xla_optimization_barrier(args)


def reduce_precision(operand, exponent_bits, mantissa_bits):
  """Wraps XlaReducePrecision with the given exponent and mantissa widths."""
  return gen_xla_ops.xla_reduce_precision(operand, exponent_bits, mantissa_bits)