1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 // XLA-specific Concat Ops. 17 18 #include <limits> 19 #include <vector> 20 21 #include "tensorflow/compiler/tf2xla/type_util.h" 22 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 23 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 25 #include "tensorflow/compiler/xla/client/xla_builder.h" 26 #include "tensorflow/compiler/xla/literal_util.h" 27 #include "tensorflow/core/framework/bounds_check.h" 28 #include "tensorflow/core/framework/op_kernel.h" 29 #include "tensorflow/core/framework/register_types.h" 30 #include "tensorflow/core/framework/tensor.h" 31 #include "tensorflow/core/framework/tensor_shape.h" 32 #include "tensorflow/core/framework/tensor_types.h" 33 #include "tensorflow/core/framework/types.h" 34 #include "tensorflow/core/lib/core/status.h" 35 #include "tensorflow/core/platform/types.h" 36 37 namespace tensorflow { 38 namespace { 39 40 // -------------------------------------------------------------------------- 41 class ConcatBaseOp : public XlaOpKernel { 42 public: ConcatBaseOp(OpKernelConstruction * c,int axis_index)43 ConcatBaseOp(OpKernelConstruction* c, int axis_index) 44 : XlaOpKernel(c), axis_index_(axis_index) {} 45 Compile(XlaOpKernelContext * ctx)46 void Compile(XlaOpKernelContext* ctx) override { 47 const TensorShape concat_dim_tensor_shape = ctx->InputShape(axis_index_); 48 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_tensor_shape), 49 errors::InvalidArgument( 50 "Concat dim tensor should be a scalar, but got shape ", 51 concat_dim_tensor_shape.DebugString())); 52 int64_t concat_dim; 53 OP_REQUIRES_OK(ctx, 54 ctx->ConstantInputAsIntScalar(axis_index_, &concat_dim)); 55 56 std::vector<xla::XlaOp> values; 57 std::vector<TensorShape> shapes; 58 OP_REQUIRES_OK(ctx, ctx->InputList("values", &values, &shapes)); 59 const int N = values.size(); 60 const int input_dims = shapes[0].dims(); 61 const TensorShape& input_shape = shapes[0]; 62 63 int32_t axis = concat_dim < 0 ? concat_dim + input_dims : concat_dim; 64 OP_REQUIRES(ctx, 0 <= axis && axis < input_dims, 65 errors::InvalidArgument( 66 "ConcatOp : Expected concatenating dimensions in the range " 67 "[", 68 -input_dims, ", ", input_dims, "), but got ", concat_dim)); 69 70 // Make a vector holding the XlaOp for each of the inputs that has non-zero 71 // elements. 72 std::vector<xla::XlaOp> input_data; 73 int output_concat_dim = 0; 74 for (int i = 0; i < N; ++i) { 75 xla::XlaOp handle = values[i]; 76 const TensorShape& in_shape = shapes[i]; 77 OP_REQUIRES( 78 ctx, in_shape.dims() == input_dims, 79 errors::InvalidArgument( 80 "ConcatOp : Ranks of all input tensors should match: shape[0] = ", 81 input_shape.DebugString(), " vs. shape[", i, 82 "] = ", in_shape.DebugString())); 83 if (in_shape.dims() == 0) { 84 // Inputs that come in as scalars must be reshaped to 1-vectors. 85 input_data.push_back(xla::Reshape(handle, {1})); 86 } else { 87 input_data.push_back(handle); 88 } 89 output_concat_dim += in_shape.dims() > 0 ? in_shape.dim_size(axis) : 1; 90 } 91 92 VLOG(1) << "Concat dim " << concat_dim << " equivalent to " << axis; 93 ctx->SetOutput(0, xla::ConcatInDim(ctx->builder(), input_data, axis)); 94 } 95 96 private: 97 int axis_index_; 98 }; 99 100 class ConcatOp : public ConcatBaseOp { 101 public: ConcatOp(OpKernelConstruction * c)102 explicit ConcatOp(OpKernelConstruction* c) 103 : ConcatBaseOp(c, /* axis_index */ 0) {} 104 }; 105 106 // ConcatV2 operation is the same as Concat except 'concat_dim' 107 // is the last input instead of the first and renamed to 'axis'. 108 class ConcatV2Op : public ConcatBaseOp { 109 public: ConcatV2Op(OpKernelConstruction * c)110 explicit ConcatV2Op(OpKernelConstruction* c) 111 : ConcatBaseOp(c, /* axis_index */ c->num_inputs() - 1) {} 112 }; 113 114 REGISTER_XLA_OP(Name("Concat").CompileTimeConstantInput("concat_dim"), 115 ConcatOp); 116 REGISTER_XLA_OP(Name("ConcatV2") 117 .TypeConstraint("Tidx", DT_INT32) 118 .CompileTimeConstantInput("axis"), 119 ConcatV2Op); 120 121 class ConcatOffsetOp : public XlaOpKernel { 122 public: ConcatOffsetOp(OpKernelConstruction * ctx)123 explicit ConcatOffsetOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} 124 Compile(XlaOpKernelContext * ctx)125 void Compile(XlaOpKernelContext* ctx) override { 126 const TensorShape concat_dim_shape = ctx->InputShape(0); 127 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(concat_dim_shape), 128 errors::InvalidArgument( 129 "Concat dim tensor should be a scalar, but got shape ", 130 concat_dim_shape.DebugString())); 131 for (int i = 1; i < ctx->num_inputs(); ++i) { 132 OP_REQUIRES(ctx, TensorShapeUtils::IsVector(ctx->InputShape(i)), 133 errors::InvalidArgument("input ", i, 134 " should be a vector, but got shape ", 135 ctx->InputShape(i).DebugString())); 136 } 137 // Suppose a Concat() op needs to Concatenate N tensors, each of 138 // which has the same number of dimensions. Their shapes match 139 // except the concat dimension. 140 // 141 // E.g., say, we want to concatenate 3 tensors in the 2nd 142 // dimension, and their shapes are: 143 // 144 // [2, 2, 5, 7] 145 // [2, 3, 5, 7] 146 // [2, 4, 5, 7] 147 // 148 // Here, N=3, cdim=1, dims=4. The concatenated tensor has shape 149 // [2,9,5,7]. We will compute the cumulative sum along the 2nd 150 // dimension to figure out each input's offset in the concatenated 151 // output: 152 // [0, 0, 0, 0] 153 // [0, 2, 0, 0] 154 // [0, 5, 0, 0] 155 const int32_t N = ctx->num_inputs() - 1; 156 const TensorShape inp0_shape = ctx->InputShape(1); 157 std::vector<int64_t> inp0_dims; 158 OP_REQUIRES_OK(ctx, 159 ctx->ConstantInputAsIntVector( 160 1, &inp0_dims, xla::ValueInferenceMode::kUpperBound)); 161 const int64_t inp0_rank = inp0_shape.num_elements(); 162 163 int64_t cdim; 164 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &cdim)); 165 166 VLOG(1) << "ConcatOffset " << cdim << "," << inp0_rank; 167 int32_t axis = cdim < 0 ? cdim + inp0_rank : cdim; 168 OP_REQUIRES(ctx, FastBoundsCheck(axis, inp0_rank), 169 errors::InvalidArgument("Concat dim is out of range: ", axis, 170 " vs. ", inp0_rank)); 171 int32_t offset = 0; 172 for (int i = 0; i < N; ++i) { 173 const TensorShape inp_shape = ctx->InputShape(1 + i); 174 OP_REQUIRES(ctx, inp0_rank == inp_shape.num_elements(), 175 errors::InvalidArgument("input ", i, " should contain ", 176 inp0_rank, " elements, but got ", 177 inp_shape.num_elements())); 178 std::vector<int64_t> inp_dims; 179 OP_REQUIRES_OK( 180 ctx, ctx->ConstantInputAsIntVector( 181 1 + i, &inp_dims, xla::ValueInferenceMode::kUpperBound)); 182 183 Tensor out_constant(DT_INT32, TensorShape({inp0_rank})); 184 auto out_vec = out_constant.vec<int32>(); 185 for (int64_t j = 0; j < inp0_rank; ++j) { 186 if (j == axis) { 187 out_vec(j) = offset; 188 offset += inp_dims[j]; 189 } else { 190 const int32_t inp0_element = inp0_dims[j]; 191 const int32_t inp_element = inp_dims[j]; 192 OP_REQUIRES(ctx, inp0_element == inp_element, 193 errors::InvalidArgument( 194 "All dimensions except ", axis, " must match. Input ", 195 i, " has shape [", absl::StrJoin(inp_dims, " "), 196 "] and doesn't match input 0 with shape [", 197 absl::StrJoin(inp0_dims, " "), "].")); 198 out_vec(j) = 0; 199 } 200 } 201 202 ctx->SetConstantOutput(i, out_constant); 203 } 204 } 205 }; 206 207 REGISTER_XLA_OP(Name("ConcatOffset") 208 .CompileTimeConstantInput("concat_dim") 209 .CompileTimeConstantInput("shape"), 210 ConcatOffsetOp); 211 212 } // namespace 213 } // namespace tensorflow 214