1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Experimental library that exposes XLA operations directly in TensorFlow. 16 17It is sometimes useful to be able to build HLO programs directly from 18TensorFlow. This file provides Tensorflow operators that mirror the semantics of 19HLO operators as closely as possible. 20 21Note: There is no promise of backward or forward compatibility for operators 22defined in this module. This is primarily because the underlying HLO operators 23do not promise backward or forward compatibility. 24""" 25 26from __future__ import absolute_import 27from __future__ import division 28from __future__ import print_function 29 30from tensorflow.compiler.tf2xla.ops import gen_xla_ops 31from tensorflow.python.framework import constant_op 32from tensorflow.python.framework import dtypes 33from tensorflow.python.framework import ops 34from tensorflow.python.ops import array_ops 35from tensorflow.python.ops import bitwise_ops 36from tensorflow.python.ops import gen_math_ops 37from tensorflow.python.ops import math_ops 38from tensorflow.python.ops import random_ops 39 40# TODO(phawkins): provide wrappers for all XLA operators. 
# Currently the missing ops include:
#   infeed/outfeed (available via tf.contrib.tpu)
#   collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
#   conditional
#   collapse

# Operators in this module deliberately reuse builtin names so that they keep
# their XLA spellings (e.g. `xla.max`). The builtins are captured here, before
# they are shadowed, for internal use.
# pylint: disable=redefined-builtin
_max = max
_min = min
_slice = slice  # pylint: disable=invalid-name

constant = constant_op.constant

# Unary operators.

# Most XLA arithmetic operators have an exactly corresponding TensorFlow
# operator, so that operator is reused instead of defining an XLA-specific one.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.


def _unary_op(fn):
  """Restricts `fn` to the `(x, name=None)` signature shared by unary ops.

  Args:
    fn: a TensorFlow operator that accepts a tensor and a `name` keyword.

  Returns:
    A function `unary_op_wrapper(x, name=None)` that forwards to `fn`.
  """

  def unary_op_wrapper(x, name=None):
    # Only `x` and `name` are forwarded, which hides any extra optional
    # parameters `fn` may have from callers of the XLA-style operator.
    return fn(x, name=name)

  return unary_op_wrapper


abs = _unary_op(math_ops.abs)
ceil = _unary_op(math_ops.ceil)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
ndtri = _unary_op(math_ops.ndtri)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, which rounds numbers halfway between two
# integers away from zero, this rounds them to the nearest even integer.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tanh = _unary_op(math_ops.tanh)

# Bessel
bessel_i0e = _unary_op(math_ops.bessel_i0e)
bessel_i1e = _unary_op(math_ops.bessel_i1e)

# Binary operators

# The main difference between TensorFlow and XLA binary ops is the broadcasting
# semantics. TensorFlow uses NumPy-style broadcasting semantics, whereas XLA
# requires an explicit specification of which dimensions to broadcast if the
# arguments have different ranks.


def _broadcasting_binary_op(fn):
  """Wraps a binary TensorFlow operator and performs XLA-style broadcasting.

  Args:
    fn: a binary TensorFlow operator, called as `fn(x, y, name=name)`.

  Returns:
    A function `broadcasting_binary_op_wrapper(x, y, broadcast_dims=None,
    name=None)` that runs `x` and `y` through the XlaBroadcastHelper op before
    applying `fn` to the results.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Applies `fn` to `x` and `y` after XLA-style broadcasting.

    Args:
      x: the left operand.
      y: the right operand.
      broadcast_dims: the explicit broadcast dimensions XLA needs when `x` and
        `y` have different ranks, as a sequence of integers (converted to an
        int64 tensor). `None` means no explicit broadcast dimensions.
      name: an optional name for the operator.

    Returns:
      The result of `fn` on the broadcast operands.
    """
    # Compare against None instead of writing `broadcast_dims or []`: truth-
    # testing the argument silently discards a one-element NumPy array or
    # eager tensor holding 0 (both are falsy), and raises for larger arrays
    # and for graph-mode tensors.
    if broadcast_dims is None:
      broadcast_dims = []
    broadcast_dims = ops.convert_to_tensor(broadcast_dims, dtypes.int64)
    # Rather than relying on having static shape information in the TensorFlow
    # graph, we use an XlaBroadcastHelper op that can compute the correct shapes
    # at JIT compilation time.
    x, y = gen_xla_ops.xla_broadcast_helper(x, y, broadcast_dims)
    return fn(x, y, name=name)

  return broadcasting_binary_op_wrapper


# Map from TF signed types to TF unsigned types.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}

# Map from TF unsigned types to TF signed types.
_UNSIGNED_TO_SIGNED_TABLE = {
    dtypes.uint8: dtypes.int8,
    dtypes.uint16: dtypes.int16,
    dtypes.uint32: dtypes.int32,
    dtypes.uint64: dtypes.int64,
}


def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type.

  Signed operands are cast to the unsigned type of the same width (per
  `_SIGNED_TO_UNSIGNED_TABLE`), shifted with `bitwise_ops.right_shift`, and the
  result is cast back to the original dtype. Other dtypes are shifted as-is.

  Args:
    x: the integer tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift operator.

  Returns:
    A tensor with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  dtype = x.dtype
  signed = dtype in _SIGNED_TO_UNSIGNED_TABLE
  if signed:
    # `right_shift` is arithmetic on signed types; reinterpreting as unsigned
    # makes it a logical shift.
    unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE[dtype]
    x = math_ops.cast(x, unsigned_dtype)
    y = math_ops.cast(y, unsigned_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if signed:
    output = math_ops.cast(output, dtype)
  return output


def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type.

  Unsigned operands are cast to the signed type of the same width (per
  `_UNSIGNED_TO_SIGNED_TABLE`), shifted with `bitwise_ops.right_shift`, and the
  result is cast back to the original dtype. Other dtypes are shifted as-is.

  Args:
    x: the integer tensor to shift.
    y: the shift amounts; must have the same dtype as `x`.
    name: an optional name for the shift operator.

  Returns:
    A tensor with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  dtype = x.dtype
  unsigned = dtype in _UNSIGNED_TO_SIGNED_TABLE
  if unsigned:
    # `right_shift` is logical on unsigned types; reinterpreting as signed
    # makes it an arithmetic shift.
    signed_dtype = _UNSIGNED_TO_SIGNED_TABLE[dtype]
    x = math_ops.cast(x, signed_dtype)
    y = math_ops.cast(y, signed_dtype)
  output = bitwise_ops.right_shift(x, y, name=name)
  if unsigned:
    output = math_ops.cast(output, dtype)
  return output


add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
# NOTE(review): for integer inputs `math_ops.div` performs floor division
# (Python 2 semantics). Confirm this matches XLA Div's rounding when the
# operands have opposite signs.
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)

igamma = _broadcasting_binary_op(math_ops.igamma)
igammac = _broadcasting_binary_op(math_ops.igammac)


def _binary_op(fn):
  """Wrapper that restricts `fn` to the `(x, y, name=None)` signature."""

  def binary_op_wrapper(x, y, name=None):
    return fn(x, y, name=name)

  return binary_op_wrapper


transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)

bitcast_convert_type = array_ops.bitcast


def broadcast(x, dims, name=None):
  """Prepends the dimensions `dims` to the shape of `x` by broadcasting.

  Args:
    x: the tensor to broadcast (converted with `ops.convert_to_tensor`).
    dims: a list of integer sizes to prepend to the shape of `x`. May be empty,
      in which case the output has the same shape as `x`.
    name: an optional name for the operator.

  Returns:
    A tensor whose shape is `dims` followed by the shape of `x`.
  """
  x = ops.convert_to_tensor(x)
  x_shape = array_ops.shape(x)
  # Pin the dtype to that of the shape tensor: without it an empty `dims` list
  # becomes a float32 constant and the concat below fails.
  leading_dims = constant_op.constant(dims, dtype=x_shape.dtype)
  shape = array_ops.concat([leading_dims, x_shape], axis=0)
  return array_ops.broadcast_to(x, shape, name=name)


def clamp(a, x, b, name=None):
  """Clamps `x` elementwise between `a` (lower) and `b` (upper).

  The argument order follows XLA's Clamp(min, operand, max). Computes
  `min(max(a, x), b)` with this module's broadcasting `max`/`min` operators
  (not the Python builtins, which are shadowed in this module).

  Args:
    a: the lower bound.
    x: the tensor to clamp.
    b: the upper bound.
    name: an optional name, passed to both the inner and the outer operator.

  Returns:
    The clamped tensor.
  """
  return min(max(a, x, name=name), b, name=name)


concatenate = array_ops.concat


def _serialize_precision_config(precision_config):
  """Returns the serialized `precision_config` proto, or "" when it is unset."""
  if not precision_config:
    return ""
  return precision_config.SerializeToString()


def conv(lhs,
         rhs,
         window_strides,
         padding,
         lhs_dilation,
         rhs_dilation,
         dimension_numbers,
         feature_group_count=1,
         precision_config=None,
         name=None):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    name: an optional name for the operator

  Returns:
    A tensor representing the output of the convolution.
  """
  return gen_xla_ops.xla_conv(
      lhs,
      rhs,
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=_serialize_precision_config(precision_config),
      name=name)


convert_element_type = math_ops.cast


def dot(lhs, rhs, name=None):
  """Contracts the last dimension of `lhs` with the first dimension of `rhs`.

  Implemented as `math_ops.tensordot(lhs, rhs, axes=1)`.
  """
  return math_ops.tensordot(lhs, rhs, axes=1, name=name)


def dot_general(lhs, rhs, dimension_numbers, precision_config=None, name=None):
  """Wraps the XLA DotGeneral operator.

  Args:
    lhs: the left operand.
    rhs: the right operand.
    dimension_numbers: a proto message (presumably `xla.DotDimensionNumbers`)
      that is serialized with `SerializeToString()`.
    precision_config: an optional `xla.PrecisionConfig` proto.
    name: an optional name for the operator.

  Returns:
    The output tensor of `gen_xla_ops.xla_dot`.
  """
  return gen_xla_ops.xla_dot(
      lhs,
      rhs,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=_serialize_precision_config(precision_config),
      name=name)


def self_adjoint_eig(a, lower, max_iter, epsilon, name=None):
  """Wraps `gen_xla_ops.xla_self_adjoint_eig`; `name` is optional."""
  return gen_xla_ops.xla_self_adjoint_eig(
      a, lower, max_iter, epsilon, name=name)


def svd(a, max_iter, epsilon, precision_config=None, name=None):
  """Wraps `gen_xla_ops.xla_svd`.

  Args:
    a: the input tensor.
    max_iter: passed through to the op.
    epsilon: passed through to the op.
    precision_config: an optional `xla.PrecisionConfig` proto.
    name: an optional name for the operator.

  Returns:
    The outputs of `gen_xla_ops.xla_svd`.
  """
  return gen_xla_ops.xla_svd(
      a, max_iter, epsilon, _serialize_precision_config(precision_config),
      name=name)


dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum

# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad


def random_normal(mu, sigma, dims, name=None):
  """Samples a tensor of shape `dims` from a normal distribution.

  The mean is `mu`, the standard deviation is `sigma`, and the output dtype is
  taken from `mu`.
  """
  mu = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mu, stddev=sigma, dtype=mu.dtype, name=name)


def random_uniform(minval, maxval, dims, name=None):
  """Samples a tensor of shape `dims` uniformly between `minval` and `maxval`.

  The output dtype is taken from `minval`.
  """
  minval = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, minval, maxval, dtype=minval.dtype, name=name)


recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce


def reduce_window(operand,
                  init,
                  reducer,
                  window_dimensions,
                  window_strides=None,
                  base_dilations=None,
                  window_dilations=None,
                  padding=None,
                  name=None):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors for the base (operand), as a list of
      integers. Optional; if omitted, defaults to 1 in every dimension.
    window_dilations: dilation factors for the window, as a list of integers.
      Optional; if omitted, defaults to 1 in every dimension.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  window_strides = window_strides or [1] * len(window_dimensions)
  base_dilations = base_dilations or [1] * len(window_dimensions)
  window_dilations = window_dilations or [1] * len(window_dimensions)
  padding = padding or [(0, 0)] * len(window_dimensions)
  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name)


replica_id = gen_xla_ops.xla_replica_id


def reshape(x, new_sizes, dimensions=None, name=None):
  """Reshapes `x` to `new_sizes`, optionally transposing it first.

  Args:
    x: the tensor to reshape.
    new_sizes: the output shape.
    dimensions: an optional permutation applied with `array_ops.transpose`
      before the reshape.
    name: an optional name for the reshape operator.

  Returns:
    The reshaped tensor.
  """
  if dimensions is not None:
    x = array_ops.transpose(x, dimensions)
  x = array_ops.reshape(x, new_sizes, name=name)
  return x


def select(condition, x, y, name=None):
  """Selects elementwise between `x` and `y` via `array_ops.where`."""
  return array_ops.where(condition, x, y, name)


select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send


def slice(x, start_dims, limit_dims, strides):
  """Statically slices `x`, mirroring XLA Slice.

  Dimension `i` is sliced as `start_dims[i]:limit_dims[i]:strides[i]`. The
  three lists are zipped together, so only as many leading dimensions as the
  shortest list are sliced.

  Args:
    x: the tensor to slice.
    start_dims: per-dimension start indices.
    limit_dims: per-dimension limit indices (exclusive, as in Python slicing).
    strides: per-dimension strides.

  Returns:
    `x` indexed with the resulting tuple of Python `slice` objects.
  """
  spec = [
      _slice(start, limit, stride)
      for (start, limit, stride) in zip(start_dims, limit_dims, strides)
  ]
  return x[tuple(spec)]


sharding = gen_xla_ops.xla_sharding


@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Passes the incoming gradient through the XlaSharding op unchanged."""
  del op  # Unused
  return [grad]


sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize


def gather(operand, start_indices, dimension_numbers, slice_sizes,
           indices_are_sorted=False, name=None):
  """Wraps the XLA Gather operator.

  Args:
    operand: the tensor to gather from.
    start_indices: the starting indices of the gathered slices.
    dimension_numbers: a proto message (presumably `xla.GatherDimensionNumbers`)
      that is serialized with `SerializeToString()`.
    slice_sizes: the size of each gathered slice.
    indices_are_sorted: passed through to the op.
    name: an optional name for the operator.

  Returns:
    The output tensor of `gen_xla_ops.xla_gather`.
  """
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=dimension_numbers.SerializeToString(),
      indices_are_sorted=indices_are_sorted,
      name=name)


def scatter(operand, scatter_indices, updates, update_computation,
            dimension_numbers, indices_are_sorted=False, name=None):
  """Wraps the XLA Scatter operator.

  Args:
    operand: the tensor to scatter into.
    scatter_indices: the indices at which `updates` are scattered.
    updates: the values to scatter.
    update_computation: passed through unchanged to `gen_xla_ops.xla_scatter`.
    dimension_numbers: a proto message (presumably
      `xla.ScatterDimensionNumbers`) that is serialized with
      `SerializeToString()`.
    indices_are_sorted: passed through to the op.
    name: an optional name for the operator.

  Returns:
    The output tensor of `gen_xla_ops.xla_scatter`.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name)