1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <numeric> 17 18 #include "tensorflow/compiler/tf2xla/type_util.h" 19 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/core/framework/bounds_check.h" 24 #include "tensorflow/core/framework/kernel_def_builder.h" 25 26 namespace tensorflow { 27 namespace { 28 29 class SelectOp : public XlaOpKernel { 30 public: SelectOp(OpKernelConstruction * ctx)31 explicit SelectOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 32 Compile(XlaOpKernelContext * ctx)33 void Compile(XlaOpKernelContext* ctx) override { 34 const TensorShape cond_shape = ctx->InputShape(0); 35 const TensorShape then_shape = ctx->InputShape(1); 36 const TensorShape else_shape = ctx->InputShape(2); 37 38 OP_REQUIRES( 39 ctx, then_shape.IsSameSize(else_shape), 40 errors::InvalidArgument( 41 "'then' and 'else' must have the same size. but received: ", 42 then_shape.DebugString(), " vs. ", else_shape.DebugString())); 43 44 auto cond_handle = ctx->Input(0); 45 auto then_handle = ctx->Input(1); 46 auto else_handle = ctx->Input(2); 47 48 bool broadcasting = !cond_shape.IsSameSize(then_shape); 49 bool cond_is_scalar = TensorShapeUtils::IsScalar(cond_shape); 50 if (broadcasting && !cond_is_scalar) { 51 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(cond_shape), 52 errors::InvalidArgument( 53 "'cond' must be a scalar or a vector, but saw shape: ", 54 cond_shape.DebugString())); 55 OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(then_shape), 56 errors::InvalidArgument( 57 "'then' must be at least a vector, but saw shape: ", 58 then_shape.DebugString())); 59 OP_REQUIRES(ctx, then_shape.dim_size(0) == cond_shape.num_elements(), 60 errors::InvalidArgument("Number of batches of 'then' must " 61 "match size of 'cond', but saw: ", 62 then_shape.dim_size(0), " vs. ", 63 cond_shape.num_elements())); 64 65 // TODO(phawkins): broadcasting on the right seems pretty awkward in 66 // XLA. It seems we have to broadcast on the left and then Reshape 67 // to get the dimensions in the right order. 68 const auto dim_sizes = then_shape.dim_sizes(); 69 absl::Span<const int64> bdims = dim_sizes; 70 bdims.remove_prefix(1); 71 cond_handle = xla::Broadcast(cond_handle, bdims); 72 73 std::vector<int64> dim_order(then_shape.dims()); 74 dim_order[0] = then_shape.dims() - 1; 75 std::iota(dim_order.begin() + 1, dim_order.end(), 0); 76 cond_handle = xla::Transpose(cond_handle, dim_order); 77 } 78 ctx->SetOutput(0, xla::Select(cond_handle, then_handle, else_handle)); 79 } 80 81 private: 82 TF_DISALLOW_COPY_AND_ASSIGN(SelectOp); 83 }; 84 85 REGISTER_XLA_OP(Name("Select"), SelectOp); 86 87 } // namespace 88 } // namespace tensorflow 89