/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// XLA-specific Shape Ops.

#include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.h"

namespace tensorflow {
namespace {

class ShapeOp : public XlaOpKernel {
 public:
  explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
    OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
    ctx->SetConstantOutput(0, shape_constant);
  }

 private:
  DataType out_dtype_;
};

REGISTER_XLA_OP(Name("Shape").CompilationOnly().IsMetadataOp(), ShapeOp);

class ShapeNOp : public XlaOpKernel {
 public:
  explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      const TensorShape input_shape = ctx->InputShape(i);
      Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
      OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
      ctx->SetConstantOutput(i, shape_constant);
    }
  }

  bool IsExpensive() override { return false; }

 private:
  DataType out_dtype_;
};
REGISTER_XLA_OP(Name("ShapeN").CompilationOnly().IsMetadataOp(), ShapeNOp);

class RankOp : public XlaOpKernel {
 public:
  explicit RankOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    const int rank = input_shape.dims();
    Tensor rank_constant(DT_INT32, TensorShape({}));
    rank_constant.scalar<int32>()() = rank;

    ctx->SetConstantOutput(0, rank_constant);
  }
};

REGISTER_XLA_OP(Name("Rank").CompilationOnly().IsMetadataOp(), RankOp);
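// Computes the number of elements of the input. Unlike the ops above, the
// per-dimension sizes are read via xla::GetDimensionSize, so the result is
// emitted as a regular XLA value rather than a compile-time constant.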
class SizeOp : public XlaOpKernel {
 public:
  explicit SizeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    OP_REQUIRES(ctx,
                FastBoundsCheck(input_shape.num_elements(),
                                std::numeric_limits<int32>::max()),
                errors::InvalidArgument("Size does not work for tensors > "
                                        "int32 max."));
    Tensor size_constant(DT_INT32, TensorShape({}));
    const int rank = input_shape.dims();
    xla::XlaBuilder* builder = ctx->builder();
    auto size = xla::One(builder, xla::U32);
    for (int64 i = 0; i < rank; ++i) {
      size = xla::Mul(size, xla::GetDimensionSize(ctx->Input(0), i));
    }
    size = xla::ConvertElementType(size, ctx->output_xla_type(0));
    ctx->SetOutput(0, size);
  }
};

REGISTER_XLA_OP(Name("Size").CompilationOnly().IsMetadataOp(), SizeOp);

class ExpandDimsOp : public XlaOpKernel {
 public:
  explicit ExpandDimsOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape("input");
    const TensorShape dim_shape = ctx->InputShape("dim");

    std::vector<int64> dims;
    OP_REQUIRES_OK(ctx, ctx->ConstantInputReshapedToIntVector("dim", &dims));
    OP_REQUIRES(ctx, dims.size() == 1,
                errors::InvalidArgument(absl::StrCat(
                    "dim input to ExpandDims must be a scalar; got ",
                    dim_shape.DebugString())));
    int dim = dims[0];

    OP_REQUIRES(ctx,
                (dim >= -1 - input_shape.dims() && dim <= input_shape.dims()),
                errors::InvalidArgument("Tried to expand dim index ", dim,
                                        " for tensor with ", input_shape.dims(),
                                        " dimensions."));

    auto existing_dims = input_shape.dim_sizes();
    // Safe - # elements in tensor dims bounded.
    const int existing_dims_size = static_cast<int>(existing_dims.size());
    std::vector<int64> new_shape(existing_dims_size);
    for (size_t i = 0; i < new_shape.size(); ++i) {
      new_shape[i] = existing_dims[i];
    }

    // We emulate numpy's interpretation of the dim axis when
    // -input.dims() <= dim <= input.dims().
    if (dim < 0) {
      dim += existing_dims.size() + 1;
    }

    // Clamp to the end if needed.
    dim = std::min<int32>(dim, existing_dims_size);
    new_shape.emplace(new_shape.begin() + dim, 1);

    ctx->SetOutput(0, xla::Reshape(ctx->Input("input"), new_shape));
  }
};
REGISTER_XLA_OP(Name("ExpandDims").CompileTimeConstantInput("dim"),
                ExpandDimsOp);
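// Removes dimensions of size 1 from the input shape. If the squeeze_dims
// attribute is non-empty, only those dimensions are removed (and each of them
// must have size 1); otherwise all size-1 dimensions are dropped.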
class SqueezeOp : public XlaOpKernel {
 public:
  explicit SqueezeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    std::vector<int32> squeeze_dims;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
    squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
  }

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);
    auto existing_dims = input_shape.dim_sizes();
    int existing_dims_size = input_shape.dims();
    std::vector<int64> new_shape;

    std::unordered_set<int32> wrapped_squeeze_dims;
    wrapped_squeeze_dims.reserve(squeeze_dims_.size());
    // Validate squeeze dims against the input.
    for (int32 dim : squeeze_dims_) {
      OP_REQUIRES(
          ctx, (dim >= -input_shape.dims() && dim < input_shape.dims()),
          errors::InvalidArgument("Tried to squeeze dim index ", dim,
                                  " for tensor with ", input_shape.dims(),
                                  " dimensions."));
      // If dim is < 0, we wrap around (-1 means the last element).
      if (dim < 0) {
        dim = existing_dims_size + dim;
      }

      wrapped_squeeze_dims.insert(dim);
    }

    for (int i = 0; i < existing_dims_size; ++i) {
      auto existing_dim = existing_dims[i];

      // If squeeze_set is non-empty, only squeeze those dimensions.
      if (!wrapped_squeeze_dims.empty()) {
        if (wrapped_squeeze_dims.count(i) > 0) {
          OP_REQUIRES(ctx, existing_dim == 1,
                      errors::InvalidArgument(
                          "Tried to explicitly squeeze dimension ", i,
                          " but dimension was not 1: ", existing_dim));
        } else {
          // This dimension is not being squeezed.
          new_shape.push_back(existing_dim);
        }
      } else {
        // Copy over all non-1-length dimensions.
        if (existing_dim != 1) {
          new_shape.push_back(existing_dim);
        }
      }
    }

    ctx->SetOutput(0, xla::Reshape(ctx->Input(0), new_shape));
  }

 private:
  std::unordered_set<int32> squeeze_dims_;
};

REGISTER_XLA_OP(Name("Squeeze"), SqueezeOp);

class ZerosLikeOp : public XlaOpKernel {
 public:
  explicit ZerosLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    auto zero = XlaHelpers::Zero(ctx->builder(), input_type(0));
    ctx->SetOutput(0, xla::Broadcast(zero, input_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("ZerosLike"), ZerosLikeOp);

class OnesLikeOp : public XlaOpKernel {
 public:
  explicit OnesLikeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    const TensorShape input_shape = ctx->InputShape(0);

    auto one = XlaHelpers::One(ctx->builder(), input_type(0));
    ctx->SetOutput(0, xla::Broadcast(one, input_shape.dim_sizes()));
  }
};

REGISTER_XLA_OP(Name("OnesLike"), OnesLikeOp);

}  // namespace
}  // namespace tensorflow