1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Shape Ops. 17 18 #include "absl/strings/str_format.h" 19 #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/kernels/tensor_list_utils.h" 21 #include "tensorflow/compiler/tf2xla/shape_util.h" 22 #include "tensorflow/compiler/tf2xla/type_util.h" 23 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 24 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 25 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 26 #include "tensorflow/compiler/xla/client/lib/constants.h" 27 #include "tensorflow/compiler/xla/client/xla_builder.h" 28 #include "tensorflow/compiler/xla/literal.h" 29 #include "tensorflow/compiler/xla/shape_util.h" 30 #include "tensorflow/core/framework/bounds_check.h" 31 #include "tensorflow/core/framework/kernel_def_builder.h" 32 #include "tensorflow/core/framework/op_kernel.h" 33 #include "tensorflow/core/framework/tensor_shape.h" 34 35 namespace tensorflow { 36 namespace { 37 38 class ShapeOp : public XlaOpKernel { 39 public: ShapeOp(OpKernelConstruction * ctx)40 explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 41 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); 42 } 43 Compile(XlaOpKernelContext * ctx)44 void Compile(XlaOpKernelContext* ctx) override { 45 const TensorShape input_shape = ctx->InputShape(0); 46 std::vector<xla::XlaOp> operands; 47 const int rank = input_shape.dims(); 48 if (rank != 0) { 49 for (int64_t i = 0; i < rank; ++i) { 50 operands.push_back(xla::Broadcast( 51 xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i), 52 ctx->output_xla_type(0)), 53 {1})); 54 } 55 56 ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), operands, 0)); 57 } else { 58 // Rank 0 won't have dynamic size dimension, use constant output. 59 Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); 60 OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant)); 61 ctx->SetConstantOutput(0, shape_constant); 62 } 63 } 64 65 private: 66 DataType out_dtype_; 67 }; 68 69 REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp); 70 71 class XlaSetBoundOp : public XlaOpKernel { 72 public: XlaSetBoundOp(OpKernelConstruction * context)73 explicit XlaSetBoundOp(OpKernelConstruction* context) 74 : XlaOpKernel(context) {} 75 Compile(XlaOpKernelContext * ctx)76 void Compile(XlaOpKernelContext* ctx) override { 77 const TensorShape input_shape = ctx->InputShape("input"); 78 const TensorShape bound_shape = ctx->InputShape("bound"); 79 80 OP_REQUIRES( 81 ctx, 82 ctx->InputType("bound") == DT_INT32 && 83 ctx->InputType("input") == DT_INT32, 84 errors::InvalidArgument( 85 "XlaSetBound can only set bound for int32 scalar value: got", 86 input_shape.DebugString())); 87 88 OP_REQUIRES( 89 ctx, input_shape.dims() == 0, 90 errors::InvalidArgument("XlaSetBound should only be used to set a " 91 "bound to the an int32 scalar value: got", 92 input_shape.DebugString())); 93 94 OP_REQUIRES( 95 ctx, bound_shape.dims() == 0, 96 errors::InvalidArgument("XlaSetBound should only be used to set a " 97 "bound to the an int32 scalar value: got", 98 bound_shape.DebugString())); 99 int64_t bound; 100 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("bound", &bound)); 101 xla::Literal bound_literal = xla::LiteralUtil::CreateR0<int32>(bound); 102 xla::XlaOp result = 103 xla::CustomCall(ctx->builder(), "SetBound", {ctx->Input("input")}, 104 ctx->InputXlaShape("input").ValueOrDie(), "", false, {}, 105 &bound_literal); 106 ctx->SetOutput(0, result); 107 } 108 }; 109 110 REGISTER_XLA_OP(Name("XlaSetBound").CompileTimeConstantInput("bound"), 111 XlaSetBoundOp); 112 113 class XlaSetDynamicDimensionSizeOp : public XlaOpKernel { 114 public: XlaSetDynamicDimensionSizeOp(OpKernelConstruction * context)115 explicit XlaSetDynamicDimensionSizeOp(OpKernelConstruction* context) 116 : XlaOpKernel(context) {} 117 Compile(XlaOpKernelContext * ctx)118 void Compile(XlaOpKernelContext* ctx) override { 119 const TensorShape dim_index_shape = ctx->InputShape("dim_index"); 120 const TensorShape size_shape = ctx->InputShape("size"); 121 122 OP_REQUIRES(ctx, 123 ctx->InputType("dim_index") == DT_INT32 && 124 ctx->InputType("size") == DT_INT32, 125 errors::InvalidArgument("dim_index and size has to be int32 for" 126 "XlaSetDynamicDimensionSizeOp")); 127 128 OP_REQUIRES( 129 ctx, dim_index_shape.dims() == 0 && size_shape.dims() == 0, 130 errors::InvalidArgument("XlaSetDynamicDimensionSizeOp's dim_index and " 131 "size has to be int32 scalar value")); 132 int64_t dim_index; 133 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("dim_index", &dim_index)); 134 135 xla::XlaOp result = 136 xla::SetDimensionSize(ctx->Input(0), ctx->Input("size"), dim_index); 137 ctx->SetOutput(0, result); 138 } 139 }; 140 141 REGISTER_XLA_OP( 142 Name("XlaSetDynamicDimensionSize").CompileTimeConstantInput("dim_index"), 143 XlaSetDynamicDimensionSizeOp); 144 145 class XlaRemoveDynamicDimensionSizeOp : public XlaOpKernel { 146 public: XlaRemoveDynamicDimensionSizeOp(OpKernelConstruction * context)147 explicit XlaRemoveDynamicDimensionSizeOp(OpKernelConstruction* context) 148 : XlaOpKernel(context) {} 149 Compile(XlaOpKernelContext * ctx)150 void Compile(XlaOpKernelContext* ctx) override { 151 const TensorShape dim_index_shape = ctx->InputShape("dim_index"); 152 153 OP_REQUIRES(ctx, ctx->InputType("dim_index") == DT_INT32, 154 errors::InvalidArgument("dim_index has to be int32 for" 155 "XlaRemoveDynamicDimensionSizeOp")); 156 157 OP_REQUIRES( 158 ctx, dim_index_shape.dims() == 0, 159 errors::InvalidArgument("XlaRemoveDynamicDimensionSizeOp's dim_index " 160 "has to be int32 scalar value")); 161 int64_t dim_index; 162 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("dim_index", &dim_index)); 163 164 xla::XlaOp result = xla::RemoveDynamicDimension(ctx->Input(0), dim_index); 165 ctx->SetOutput(0, result); 166 } 167 }; 168 169 REGISTER_XLA_OP( 170 Name("XlaRemoveDynamicDimensionSize").CompileTimeConstantInput("dim_index"), 171 XlaRemoveDynamicDimensionSizeOp); 172 173 class ShapeNOp : public XlaOpKernel { 174 public: ShapeNOp(OpKernelConstruction * ctx)175 explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 176 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); 177 } 178 Compile(XlaOpKernelContext * ctx)179 void Compile(XlaOpKernelContext* ctx) override { 180 for (int i = 0; i < ctx->num_inputs(); ++i) { 181 const TensorShape input_shape = ctx->InputShape(i); 182 std::vector<xla::XlaOp> operands; 183 184 const int rank = input_shape.dims(); 185 if (rank != 0) { 186 // Each dimension can be dynamic, so use GetDimensionSize to get the 187 // runtime dimension. 188 for (int64_t dim = 0; dim < rank; ++dim) { 189 operands.push_back(xla::Broadcast( 190 xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(i), dim), 191 ctx->output_xla_type(i)), 192 {1})); 193 } 194 195 ctx->SetOutput(i, xla::ConcatInDim(ctx->builder(), operands, 0)); 196 } else { 197 // Rank 0 won't have dynamic size dimension, use constant output. 198 Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()})); 199 OP_REQUIRES_OK(ctx, 200 TensorShapeToConstant(input_shape, &shape_constant)); 201 ctx->SetConstantOutput(i, shape_constant); 202 } 203 } 204 } 205 IsExpensive()206 bool IsExpensive() override { return false; } 207 208 private: 209 DataType out_dtype_; 210 }; 211 REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp); 212 213 class RankOp : public XlaOpKernel { 214 public: RankOp(OpKernelConstruction * ctx)215 explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 216 Compile(XlaOpKernelContext * ctx)217 void Compile(XlaOpKernelContext* ctx) override { 218 const TensorShape input_shape = ctx->InputShape(0); 219 const int rank = input_shape.dims(); 220 Tensor rank_constant(DT_INT32, TensorShape({})); 221 rank_constant.scalar<int32>()() = rank; 222 223 ctx->SetConstantOutput(0, rank_constant); 224 } 225 }; 226 227 REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp); 228 229 class SizeOp : public XlaOpKernel { 230 public: SizeOp(OpKernelConstruction * ctx)231 explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 232 Compile(XlaOpKernelContext * ctx)233 void Compile(XlaOpKernelContext* ctx) override { 234 const TensorShape input_shape = ctx->InputShape(0); 235 OP_REQUIRES(ctx, 236 FastBoundsCheck(input_shape.num_elements(), 237 std::numeric_limits<int32>::max()), 238 errors::InvalidArgument("Size does not work for tensors > " 239 "int32 max.")); 240 Tensor size_constant(DT_INT32, TensorShape({})); 241 const int rank = input_shape.dims(); 242 xla::XlaBuilder* builder = ctx->builder(); 243 auto size = xla::One(builder, xla::U32); 244 for (int64_t i = 0; i < rank; ++i) { 245 size = xla::Mul( 246 size, xla::ConvertElementType(xla::GetDimensionSize(ctx->Input(0), i), 247 xla::U32)); 248 } 249 size = xla::ConvertElementType(size, ctx->output_xla_type(0)); 250 ctx->SetOutput(0, size); 251 } 252 }; 253 254 REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp); 255 256 class ExpandDimsOp : public XlaOpKernel { 257 public: ExpandDimsOp(OpKernelConstruction * ctx)258 explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 259 Compile(XlaOpKernelContext * ctx)260 void Compile(XlaOpKernelContext* ctx) override { 261 const TensorShape input_shape = ctx->InputShape("input"); 262 const TensorShape dim_shape = ctx->InputShape("dim"); 263 264 std::vector<int64_t> dims; 265 OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims)); 266 OP_REQUIRES(ctx, dims.size() == 1, 267 errors::InvalidArgument(absl::StrCat( 268 "dim input to ExpandDims must be a scalar; got ", 269 dim_shape.DebugString()))); 270 int dim = dims[0]; 271 272 OP_REQUIRES(ctx, 273 (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()), 274 errors::InvalidArgument("Tried to expand dim index ", dim, 275 " for tensor with ", input_shape.dims(), 276 " dimensions.")); 277 278 auto existing_dims = input_shape.dim_sizes(); 279 // Safe - # elements in tensor dims bounded. 280 const int existing_dims_size = static_cast<int>(existing_dims.size()); 281 std::vector<int64_t> new_shape(existing_dims_size); 282 for (size_t i = 0; i < new_shape.size(); ++i) { 283 new_shape[i] = existing_dims[i]; 284 } 285 286 // We emulate numpy's interpretation of the dim axis when 287 // -input.dims() >= dim <= input.dims(). 288 if (dim < 0) { 289 dim += existing_dims.size() + 1; 290 } 291 292 // Clamp to the end if needed. 293 dim = std::min<int32>(dim, existing_dims_size); 294 new_shape.emplace(new_shape.begin() + dim, 1); 295 296 ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape)); 297 } 298 }; 299 REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"), 300 ExpandDimsOp); 301 302 class SqueezeOp : public XlaOpKernel { 303 public: SqueezeOp(OpKernelConstruction * ctx)304 explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 305 std::vector<int32> squeeze_dims; 306 OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims)); 307 squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end()); 308 } 309 Compile(XlaOpKernelContext * ctx)310 void Compile(XlaOpKernelContext* ctx) override { 311 StatusOr<xla::Shape> input_shape = ctx->builder()->GetShape(ctx->Input(0)); 312 OP_REQUIRES_OK(ctx, input_shape.status()); 313 xla::Shape shape = input_shape.ValueOrDie(); 314 int64_t rank = shape.rank(); 315 316 std::unordered_set<int32> wrapped_squeeze_dims; 317 wrapped_squeeze_dims.reserve(squeeze_dims_.size()); 318 std::vector<int64_t> new_shape; 319 // Validate squeeze dims against the input. 320 for (int32_t dim : squeeze_dims_) { 321 OP_REQUIRES( 322 ctx, (dim >= -rank && dim < rank), 323 errors::InvalidArgument("Tried to squeeze dim index ", dim, 324 " for tensor with ", rank, " dimensions.")); 325 // If dim is < 0, we wrap around (-1 means the last element). 326 if (dim < 0) { 327 dim = rank + dim; 328 } 329 330 wrapped_squeeze_dims.insert(dim); 331 } 332 333 for (int i = 0; i < rank; ++i) { 334 auto existing_dim = shape.dimensions(i); 335 336 // If squeeze_set is non-empty, only squeeze those dimensions. 337 if (!wrapped_squeeze_dims.empty()) { 338 if (wrapped_squeeze_dims.count(i) > 0) { 339 OP_REQUIRES(ctx, existing_dim == 1, 340 errors::InvalidArgument( 341 "Tried to explicitly squeeze dimension ", i, 342 " but dimension was not 1: ", existing_dim)); 343 } else { 344 // This dimension is not being squeezed. 345 new_shape.push_back(existing_dim); 346 } 347 } else { 348 OP_REQUIRES( 349 ctx, !shape.is_dynamic_dimension(i), 350 errors::InvalidArgument("Squeeze op does not support bounded " 351 "dynamic dimensions. Input shape: ", 352 shape.DebugString())); 353 // Copy over all non-1-length dimensions. 354 if (existing_dim != 1) { 355 new_shape.push_back(existing_dim); 356 } 357 } 358 } 359 360 ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape)); 361 } 362 363 private: 364 std::unordered_set<int32> squeeze_dims_; 365 }; 366 367 REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp); 368 369 class ZerosLikeOp : public XlaOpKernel { 370 public: ZerosLikeOp(OpKernelConstruction * ctx)371 explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 372 Compile(XlaOpKernelContext * ctx)373 void Compile(XlaOpKernelContext* ctx) override { 374 if (IsTensorListInput(ctx, 0)) { 375 // Input is a TensorList. 376 377 // Check the TensorList input is initialized. 378 xla::XlaOp list = ctx->Input(0); 379 bool is_initialized; 380 OP_REQUIRES_OK(ctx, IsTensorListInitialized(list, &is_initialized)); 381 OP_REQUIRES( 382 ctx, is_initialized, 383 errors::InvalidArgument( 384 "TensorList input for ZerosLike op is an uninitialized list")); 385 386 auto list_shape_or = ctx->builder()->GetShape(list); 387 OP_REQUIRES_OK(ctx, list_shape_or.status()); 388 const xla::Shape& list_shape = list_shape_or.ValueOrDie(); 389 std::vector<std::vector<xla::XlaOp>> list_dynamic_dims; 390 list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1); 391 for (int i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { 392 // Set dynamic dimension size to 0 for initialization value. 393 std::vector<xla::XlaOp> dynamic_dims; 394 const xla::Shape& shape = list_shape.tuple_shapes(i); 395 auto sub_element = xla::GetTupleElement(list, i); 396 for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) { 397 dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim)); 398 } 399 list_dynamic_dims.push_back(dynamic_dims); 400 } 401 xla::XlaOp new_list; 402 OP_REQUIRES_OK( 403 ctx, CreateZerosTensorListWithShape(ctx->builder(), list_shape, 404 list_dynamic_dims, &new_list)); 405 406 xla::XlaOp push_index; 407 OP_REQUIRES_OK(ctx, GetTensorListPushIndex(list, &push_index)); 408 409 xla::XlaOp result; 410 OP_REQUIRES_OK(ctx, 411 SetTensorListPushIndex(new_list, push_index, &result)); 412 ctx->SetTensorListOutput(0, result); 413 } else { 414 auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0)); 415 xla::XlaOp input = ctx->Input(0); 416 auto input_shape = ctx->InputXlaShape(0).ValueOrDie(); 417 auto result = xla::Broadcast(zero, input_shape.dimensions()); 418 419 // Setting up dynamic dimensions of the broadcast. 420 for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) { 421 if (input_shape.is_dynamic_dimension(i)) { 422 xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i); 423 result = xla::SetDimensionSize(result, input_dynamic_dim, i); 424 } 425 } 426 427 ctx->SetOutput(0, result); 428 } 429 } 430 }; 431 432 REGISTER_XLA_OP(Name("ZerosLike").AllowVariantTypes(), ZerosLikeOp); 433 434 class OnesLikeOp : public XlaOpKernel { 435 public: OnesLikeOp(OpKernelConstruction * ctx)436 explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 437 Compile(XlaOpKernelContext * ctx)438 void Compile(XlaOpKernelContext* ctx) override { 439 const TensorShape input_shape = ctx->InputShape(0); 440 441 auto one = XlaHelpers::One(ctx->builder(), input_type(0)); 442 ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes())); 443 } 444 }; 445 446 REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp); 447 448 } // namespace 449 } // namespace tensorflow 450