# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental library that exposes XLA operations directly in TensorFlow.

It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides TensorFlow operators that mirror the semantics
of HLO operators as closely as possible.

Note: There is no promise of backward or forward compatibility for operators
defined in this module. This is primarily because the underlying HLO operators
do not promise backward or forward compatibility.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops

# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
#   infeed/outfeed (available via tf.contrib.tpu)
#   collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
#   conditional
#   gather/scatter
#   collapse

# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max), so we capture the builtin versions here.
# pylint: disable=redefined-builtin
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

constant = constant_op.constant

# Unary operators.

# For most arithmetic operators there is a TensorFlow operator
# that exactly corresponds to each XLA operator. Rather than defining
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.


def _unary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature."""

  def unary_op_wrapper(x, name=None):
    return fn(x, name=name)

  return unary_op_wrapper

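# Example (illustrative): the wrapper above only narrows the signature; the
# wrapped operator is otherwise the TensorFlow operator, unchanged. For
# instance, `exp` below is math_ops.exp restricted to the (x, name=None)
# calling convention:
#   y = exp(constant([0.0, 1.0]))  # [1.0, 2.718...]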
abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, this rounds to even instead of zero for
# numbers halfway between two integers.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tanh = _unary_op(math_ops.tanh)

# Bessel
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)

# Binary operators

# The main difference between TensorFlow and XLA binary ops is the broadcasting
# semantics. TensorFlow uses NumPy-style broadcasting semantics, whereas XLA
# requires an explicit specification of which dimensions to broadcast if the
# arguments have different ranks.


def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting."""

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Inner wrapper function."""
    broadcast_dims = broadcast_dims or []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct
    # shapes at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper

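# Example (illustrative): XLA-style explicit broadcasting. For x of shape
# [2, 3] and y of shape [3],
#   add(x, y, broadcast_dims=[1])
# maps y's single dimension onto dimension 1 of x, so the result r satisfies
# r[i, j] = x[i, j] + y[j]. TensorFlow's implicit NumPy-style broadcasting
# would also handle this particular case; broadcast_dims makes the dimension
# mapping explicit, which XLA requires whenever the operand ranks differ.
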
# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}


def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type."""
  assert y.dtype == x.dtype
  dtype = x.dtype
  signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
  if signed:
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
    x = math_ops.cast(x, unsigned_dtype)
    y = math_ops.cast(y, unsigned_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if signed:
    output = math_ops.cast(output, dtype)
  return output


def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type."""
  assert y.dtype == x.dtype
  dtype = x.dtype
  unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
  if unsigned:
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
    x = math_ops.cast(x, signed_dtype)
    y = math_ops.cast(y, signed_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if unsigned:
    output = math_ops.cast(output, dtype)
  return output


add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)


def _binary_op(fn):
  """Wrapper that restricts `fn` to have the correct signature."""

  def binary_op_wrapper(x, y, name=None):
    return fn(x, y, name=name)

  return binary_op_wrapper


transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
  x = ops.convert_to_tensor(x)
  shape = array_ops.concat([constant_op.constant(dims),
                            array_ops.shape(x)],
                           axis=0)
  return array_ops.broadcast_to(x, shape, name=name)


def clamp(a, x, b, name=None):
  return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat

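# Example (illustrative): `broadcast` prepends `dims` to the operand's shape,
# so a [2, 3] tensor broadcast with dims=[4, 5] has shape [4, 5, 2, 3]:
#   y = broadcast(array_ops.ones([2, 3]), [4, 5])  # y.shape == [4, 5, 2, 3]
# `clamp(a, x, b)` computes min(max(a, x), b), i.e. it clamps x elementwise to
# the interval [a, b], with XLA-style broadcasting between the operands.

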
def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         name=None):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimension
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: an `xla.PrecisionConfig` proto.
    name: an optional name for the operator

  Returns:
    A tensor representing the output of the convolution.
  """
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)


convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)


def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=precision_config_proto,
      name=name)

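# Example (a minimal sketch; it assumes the XLA protos are importable as
# tensorflow.compiler.xla.xla_data_pb2, their usual home in the TensorFlow
# source tree): a batch matmul [B, M, K] x [B, K, N] -> [B, M, N] expressed
# via dot_general:
#   from tensorflow.compiler.xla import xla_data_pb2
#   dnums = xla_data_pb2.DotDimensionNumbers()
#   dnums.lhs_contracting_dimensions.append(2)  # contract K of lhs
#   dnums.rhs_contracting_dimensions.append(1)  # with K of rhs
#   dnums.lhs_batch_dimensions.append(0)        # B is a batch dimension
#   dnums.rhs_batch_dimensions.append(0)
#   out = dot_general(lhs, rhs, dimension_numbers=dnums)

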
def self_adjoint_eig(a, lower, max_iter, epsilon):
  return gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)


def svd(a, max_iter, epsilon, precision_config=None):
  precision_config_proto = ""
  if precision_config:
    precision_config_proto = precision_config.SerializeToString()
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, precision_config_proto)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then
# remove the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
  mu = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  minval = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval, maxval, dtype=minval.dtype, name=name)


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce

ops.no_gradient("XlaVariadicReduce")


def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilations to apply to the input (base) area, as a list of
      integers. Optional; if omitted, defaults to dilations of 1.
    window_dilations: dilations to apply to the window, as a list of integers.
      Optional; if omitted, defaults to dilations of 1.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  window_strides = window_strides or [1] * len(window_dimensions)
  base_dilations = base_dilations or [1] * len(window_dimensions)
  window_dilations = window_dilations or [1] * len(window_dimensions)
  padding = padding or [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)

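# Example (a minimal sketch, mirroring how reducers are typically built with
# function.Defun; the exact mechanism for supplying `reducer` may differ
# between TensorFlow versions):
#   from tensorflow.python.framework import function
#
#   @function.Defun(dtypes.float32, dtypes.float32)
#   def max_reducer(x, y):
#     return math_ops.maximum(x, y)
#
#   # 2x2 max pooling with stride 2 over a [4, 4] float32 operand.
#   out = reduce_window(operand, constant(float("-inf")), max_reducer,
#                       window_dimensions=[2, 2], window_strides=[2, 2])

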
replica_id = gen_xla_ops.xla_replica_id

# Sets a static bound for the given input value as a hint to the XLA compiler;
# returns the same value.
# Usage:
#   def f(t, p):
#     p = xla.set_bound(p, 3)  # Tells XLA the constraint that p <= 3.
#     return t[:p]             # XLA knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound


# Makes a static dimension into an XLA bounded dynamic dimension. The current
# static dimension size will become the bound, and the second operand becomes
# the dynamic size of the dimension.
#
# This should mostly be used for testing.
#
#   def f():
#     array = tf.convert_to_tensor([[1, 2, 3, 4, 5]])
#     # Tells XLA the valid size of the array is 3.
#     dim = 0
#     p = xla_set_dynamic_dimension_size(array, dim, 3)
#     # XLA knows only the first 3 elements are valid.
#     assert(reduce_sum(p) == 6)
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size


def reshape(x, new_sizes, dimensions=None, name=None):
  if dimensions is not None:
    x = array_ops.transpose(x, dimensions)
  x = array_ops.reshape(x, new_sizes, name=name)
  return x


def select(condition, x, y, name=None):
  return array_ops.where(condition, x, y, name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send


def slice(x, start_dims, limit_dims, strides):
  spec = [
      _slice(start, limit, stride)
      for (start, limit, stride) in zip(start_dims, limit_dims, strides)
  ]
  return x[tuple(spec)]


sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  sharding_attr = op.get_attr("sharding")
  grad_sharding = gen_xla_ops.xla_sharding(grad, sharding=sharding_attr)
  # pylint: disable=protected-access
  grad_sharding.op._set_attr("_XlaSharding",
                             attr_value_pb2.AttrValue(s=sharding_attr))
  return [grad_sharding]


spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape


@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  s2f = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad,
      manual_sharding=op.get_attr("manual_sharding"),
      full_shape=op.inputs[0].shape.as_list())
  return [s2f]


@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  f2s = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad, manual_sharding=op.get_attr("manual_sharding"))
  return [f2s]


sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize


def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=dimension_numbers.SerializeToString(),
      indices_are_sorted=indices_are_sorted,
      name=name)


def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=dimension_numbers.SerializeToString(),
      indices_are_sorted=indices_are_sorted,
      name=name)

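# Example (a minimal sketch; it assumes the GatherDimensionNumbers proto is
# importable as tensorflow.compiler.xla.xla_data_pb2): gathering whole rows of
# an [N, D] operand by row index, akin to tf.gather(operand, indices):
#   from tensorflow.compiler.xla import xla_data_pb2
#   dnums = xla_data_pb2.GatherDimensionNumbers()
#   dnums.offset_dims.append(1)           # the row contents stay in dim 1
#   dnums.collapsed_slice_dims.append(0)  # the size-1 slice dim collapses
#   dnums.start_index_map.append(0)       # indices select along dim 0
#   dnums.index_vector_dim = 1
#   out = gather(operand, indices[:, None], dnums, slice_sizes=[1, D])
# `slice` (above) is the simpler static cousin:
#   slice(t, start_dims=[0, 1], limit_dims=[2, 3], strides=[1, 1])
# is equivalent to t[0:2, 1:3].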