/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {
namespace {

constexpr absl::string_view kNumSplitsAttrName = "num_splits";
constexpr absl::string_view kNumConcatsAttrName = "num_concats";

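// Reads the `num_splits` (for split ops) or `num_concats` (for concat ops),
// `N`, and `paddings` attributes, and validates that they are mutually
// consistent: every partition count must be positive, `N` must equal the
// product of the partition counts, and `paddings` (if present) must be
// non-negative and have the same length as the partition attribute. An empty
// `paddings` attribute is expanded to all zeros.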
template <bool Split>
Status GetAndValidateAttributes(OpKernelConstruction* ctx,
                                std::vector<int64_t>& num_partitions,
                                int& num_slices, std::vector<int64_t>& paddings,
                                bool& has_paddings) {
  absl::string_view num_partitions_attr_name =
      Split ? kNumSplitsAttrName : kNumConcatsAttrName;
  TF_RETURN_IF_ERROR(ctx->GetAttr(num_partitions_attr_name, &num_partitions));

  int num_dims_to_split = 0;
  for (int i = 0, e = num_partitions.size(); i < e; ++i) {
    const auto& split = num_partitions[i];
    if (split <= 0) {
      return errors::InvalidArgument("'", num_partitions_attr_name,
                                     "' at index ", i,
                                     " must be positive, but got ", split, ".");
    }
    if (split > 1) {
      ++num_dims_to_split;
    }
    num_slices *= split;
  }

  int n;
  TF_RETURN_IF_ERROR(ctx->GetAttr("N", &n));
  if (n != num_slices) {
    return errors::InvalidArgument(
        "'N' must match number of slices ", num_slices, " from '",
        num_partitions_attr_name, "', but got ", n, ".");
  }

  TF_RETURN_IF_ERROR(ctx->GetAttr("paddings", &paddings));
  const int expected_rank = num_partitions.size();
  if (!paddings.empty()) {
    if (paddings.size() != expected_rank) {
      return errors::InvalidArgument(
          "'paddings' length must match '", num_partitions_attr_name,
          "' length ", expected_rank, ", but got ", paddings.size(), ".");
    }

    for (int dim = 0; dim < expected_rank; ++dim) {
      if (paddings[dim] < 0) {
        return errors::InvalidArgument(
            "'paddings' must be all non-negative, but got ", paddings[dim],
            " at index ", dim, ".");
      }
      if (paddings[dim] > 0) {
        has_paddings = true;
      }
    }
  } else {
    paddings.assign(expected_rank, 0);
  }

  return OkStatus();
}

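// Converts a flat slice index into per-dimension start offsets, assuming
// slices are enumerated in row-major order over the partition grid. For
// example, with num_partitions = {2, 3} and slice_shape = {4, 5}, index 4
// maps to slice indices {4, 5} (second row, second column of the grid).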
std::vector<int64_t> GetSliceIndices(absl::Span<const int64> num_partitions,
                                     absl::Span<const int64> slice_shape,
                                     const int index) {
  DCHECK_EQ(num_partitions.size(), slice_shape.size());

  std::vector<int64_t> slice_indices(num_partitions.size());

  if (num_partitions.empty()) {
    return slice_indices;
  }

  auto divisor = [&](const int dim) {
    int divisor = 1;
    for (int i = num_partitions.size() - 1; i > dim; --i) {
      divisor *= num_partitions[i];
    }
    return divisor;
  };

  for (int dim = num_partitions.size() - 1; dim > 0; --dim) {
    slice_indices[dim] =
        ((index / divisor(dim)) % num_partitions[dim]) * slice_shape[dim];
  }
  slice_indices[0] = (index / divisor(0)) * slice_shape[0];

  return slice_indices;
}

constexpr absl::string_view kTensorName = "'input' tensor";
constexpr absl::string_view kResourceName = "'resource' variable tensor";

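// Shared implementation for XlaSplitND and ReadVariableXlaSplitND. Validates
// the input shape against 'num_splits' and 'paddings', then emits one XLA
// slice per output, zero-padding slices that extend past the input bounds.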
template <bool Resource>
class XlaSplitNDBaseOp : public XlaOpKernel {
 public:
  explicit XlaSplitNDBaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx,
                   GetAndValidateAttributes<true>(ctx, num_splits_, num_slices_,
                                                  paddings_, has_paddings_));
  }

 protected:
  Status CompileInternal(XlaOpKernelContext* ctx, const xla::XlaOp input,
                         const TensorShape& input_shape,
                         const DataType input_dtype) {
    xla::PrimitiveType type;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(input_dtype, &type));

    absl::string_view input_name = Resource ? kResourceName : kTensorName;
    const int rank = input_shape.dims();

    if (rank != num_splits_.size()) {
      return errors::InvalidArgument(
          input_name, " rank must be the same as 'num_splits' length ",
          num_splits_.size(), ", but got rank ", rank, ".");
    }

    for (int dim = 0; dim < rank; ++dim) {
      if ((input_shape.dim_size(dim) + paddings_[dim]) % num_splits_[dim] !=
          0) {
        return errors::InvalidArgument(
            input_name, " shape dimension ", dim, " (",
            input_shape.dim_size(dim), ") with padding ", paddings_[dim],
            " must be evenly divisible by 'num_splits' ", num_splits_[dim],
            ".");
      }
    }

    if (num_slices_ == 1 && has_paddings_) {
      xla::PaddingConfig padding_config;
      for (int dim = 0; dim < rank; ++dim) {
        auto* padding_dim = padding_config.add_dimensions();
        padding_dim->set_edge_padding_low(0);
        padding_dim->set_edge_padding_high(paddings_[dim]);
        padding_dim->set_interior_padding(0);
      }
      ctx->SetOutput(
          /*index=*/0,
          xla::Pad(input,
                   xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0),
                   padding_config));
      return OkStatus();
    } else if (num_slices_ == 1) {
      ctx->SetOutput(/*index=*/0, input);
      return OkStatus();
    }

    // Slice shape with optional padding.
    std::vector<int64_t> slice_shape(rank);
    for (int dim = 0; dim < rank; ++dim) {
      slice_shape[dim] =
          (input_shape.dim_size(dim) + paddings_[dim]) / num_splits_[dim];
    }

    const std::vector<int64_t> slice_strides(rank, 1);

    for (int i = 0; i < num_slices_; ++i) {
      int num_complete_pad_dims = 0;
      int num_partial_pad_dims = 0;
      std::vector<int64_t> slice_start_indices =
          GetSliceIndices(num_splits_, slice_shape, i);
      std::vector<int64_t> slice_limit_indices(slice_shape.size());
      xla::PaddingConfig slice_padding_config;
      for (int dim = 0; dim < rank; ++dim) {
        auto* padding_dim = slice_padding_config.add_dimensions();
        padding_dim->set_edge_padding_low(0);
        padding_dim->set_edge_padding_high(0);
        padding_dim->set_interior_padding(0);
      }

      // Calculate the padding needed for each slice instead of padding the
      // whole input up front and slicing afterwards, to reduce temporary
      // memory allocation.
      for (int dim = 0; dim < rank; ++dim) {
        const int64 dim_size = input_shape.dim_size(dim);
        if (slice_start_indices[dim] >= dim_size) {
          // Complete padding.
          slice_start_indices[dim] = dim_size;
          slice_limit_indices[dim] = dim_size;
          slice_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
              slice_shape[dim]);
          ++num_complete_pad_dims;
        } else if (slice_start_indices[dim] + slice_shape[dim] > dim_size) {
          // Partial padding.
          slice_limit_indices[dim] = dim_size;
          slice_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
              slice_start_indices[dim] + slice_shape[dim] - dim_size);
          ++num_partial_pad_dims;
        } else {
          slice_limit_indices[dim] =
              slice_start_indices[dim] + slice_shape[dim];
        }
      }

      if (num_complete_pad_dims == rank) {
        ctx->SetOutput(i, xla::Broadcast(xla::ConstantR0WithType(
                                             ctx->builder(), type, /*value=*/0),
                                         slice_shape));
      } else if (num_complete_pad_dims > 0 || num_partial_pad_dims > 0) {
        ctx->SetOutput(
            i,
            xla::Pad(xla::Slice(input, slice_start_indices, slice_limit_indices,
                                slice_strides),
                     xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0),
                     slice_padding_config));
      } else {
        ctx->SetOutput(i, xla::Slice(input, slice_start_indices,
                                     slice_limit_indices, slice_strides));
      }
    }
    return OkStatus();
  }

 private:
  std::vector<int64_t> num_splits_;
  int num_slices_ = 1;
  std::vector<int64_t> paddings_;
  bool has_paddings_ = false;
};

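// Splits the 'input' tensor into `N` slices according to 'num_splits' and
// 'paddings'.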
class XlaSplitNDOp : public XlaSplitNDBaseOp<false> {
 public:
  explicit XlaSplitNDOp(OpKernelConstruction* ctx)
      : XlaSplitNDBaseOp<false>(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    OP_REQUIRES_OK(ctx,
                   this->CompileInternal(ctx, ctx->Input(0), ctx->InputShape(0),
                                         ctx->input_type(0)));
  }
};

REGISTER_XLA_OP(Name("XlaSplitND"), XlaSplitNDOp);

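// Same as XlaSplitND, but reads its input from a 'resource' variable and
// checks that the variable's dtype matches the 'T' attribute.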
class ReadVariableXlaSplitNDOp : public XlaSplitNDBaseOp<true> {
 public:
  explicit ReadVariableXlaSplitNDOp(OpKernelConstruction* ctx)
      : XlaSplitNDBaseOp<true>(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

  void Compile(XlaOpKernelContext* ctx) override {
    DataType variable_input_dtype;
    TensorShape variable_input_shape;
    OP_REQUIRES_OK(
        ctx, ctx->GetVariableTypeAndShape(/*index=*/0, &variable_input_dtype,
                                          &variable_input_shape));
    OP_REQUIRES(
        ctx, variable_input_dtype == dtype_,
        errors::InvalidArgument("'T' must match 'resource' variable dtype ",
                                DataTypeString(variable_input_dtype),
                                ", but got ", DataTypeString(dtype_)));

    xla::XlaOp handle;
    OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(/*index=*/0, dtype_,
                                               /*shape=*/nullptr, &handle));

    OP_REQUIRES_OK(
        ctx, this->CompileInternal(ctx, handle, variable_input_shape, dtype_));
  }

 private:
  DataType dtype_;
};

REGISTER_XLA_OP(Name("ReadVariableXlaSplitND"), ReadVariableXlaSplitNDOp);

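// Shared implementation for XlaConcatND and AssignVariableXlaConcatND.
// Validates that all `N` input slices have the same shape, computes the
// padding-trimmed output shape, and stitches the slices together with
// dynamic-update-slices, skipping regions that fall entirely inside the
// padding.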
class XlaConcatNDBaseOp : public XlaOpKernel {
 public:
  explicit XlaConcatNDBaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(
        ctx, GetAndValidateAttributes<false>(ctx, num_concats_, num_slices_,
                                             paddings_, has_paddings_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
  }

 protected:
  StatusOr<xla::XlaOp> CompileInternal(XlaOpKernelContext* ctx) {
    xla::PrimitiveType type;
    TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype_, &type));

    std::vector<xla::XlaOp> input_handles;
    std::vector<TensorShape> input_shapes;
    std::vector<int64_t> output_shape;
    TF_RETURN_IF_ERROR(GetInputsAndOutputShape(ctx, input_handles, input_shapes,
                                               output_shape));

    const int rank = output_shape.size();

    if (num_slices_ == 1 && has_paddings_) {
      return xla::Slice(input_handles[0],
                        /*start_indices=*/std::vector<int64_t>(rank, 0),
                        /*limit_indices=*/output_shape,
                        /*strides=*/std::vector<int64_t>(rank, 1));
    } else if (num_slices_ == 1) {
      return input_handles[0];
    }

    auto slice_shape = input_shapes[0].dim_sizes();
    xla::XlaOp output = xla::Broadcast(
        xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0),
        output_shape);
    const std::vector<int64_t> input_slice_start_indices(rank, 0);
    const std::vector<int64_t> slice_strides(rank, 1);

    for (int i = 0; i < num_slices_; ++i) {
      std::vector<int64_t> slice_start_indices =
          GetSliceIndices(num_concats_, slice_shape, i);

      int num_complete_pad_dims = 0;
      int num_partial_pad_dims = 0;
      std::vector<int64_t> slice_limit_indices(rank);

      // Calculate paddings necessary to strip from slice.
      for (int dim = 0; dim < rank; ++dim) {
        const int64_t dim_size = output_shape[dim];
        if (slice_start_indices[dim] >= dim_size) {
          // Complete padding.
          slice_start_indices[dim] = dim_size;
          slice_limit_indices[dim] = dim_size;
          ++num_complete_pad_dims;
        } else if (slice_start_indices[dim] + slice_shape[dim] > dim_size) {
          // Partial padding.
          slice_limit_indices[dim] = dim_size;
          ++num_partial_pad_dims;
        } else {
          slice_limit_indices[dim] =
              slice_start_indices[dim] + slice_shape[dim];
        }
      }

      if (num_complete_pad_dims == rank) {
        continue;
      }

      xla::XlaOp input_slice = input_handles[i];
      if (num_complete_pad_dims > 0 || num_partial_pad_dims > 0) {
        std::vector<int64_t> input_slice_limit_indices(rank);
        for (int dim = 0; dim < rank; ++dim) {
          input_slice_limit_indices[dim] =
              slice_limit_indices[dim] - slice_start_indices[dim];
        }
        input_slice = xla::Slice(input_slice, input_slice_start_indices,
                                 input_slice_limit_indices, slice_strides);
      }

      std::vector<xla::XlaOp> update_slice_start_indices;
      update_slice_start_indices.reserve(rank);
      for (int64 start_index : slice_start_indices) {
        update_slice_start_indices.push_back(
            xla::ConstantR0<int32>(ctx->builder(), start_index));
      }
      output = xla::DynamicUpdateSlice(output, input_slice,
                                       update_slice_start_indices);
    }

    return output;
  }

  DataType dtype_;

 private:
  Status GetInputsAndOutputShape(XlaOpKernelContext* ctx,
                                 std::vector<xla::XlaOp>& input_handles,
                                 std::vector<TensorShape>& input_shapes,
                                 std::vector<int64_t>& output_shape) {
    TF_RETURN_IF_ERROR(ctx->InputList("inputs", &input_handles, &input_shapes));

    const TensorShape& slice_shape = input_shapes[0];
    if (slice_shape.dims() != num_concats_.size()) {
      return errors::InvalidArgument(
          "'inputs' rank must be the same as 'num_concats' length ",
          num_concats_.size(), ", but got rank ", slice_shape.dims(), ".");
    }
    for (int i = 1; i < num_slices_; ++i) {
      const TensorShape& slice_shape_i = input_shapes[i];
      if (slice_shape != slice_shape_i) {
        return errors::InvalidArgument(
            "'inputs' must all have the same expected shape ", slice_shape,
            ", but got ", slice_shape_i, " at index ", i, ".");
      }
    }

    const int rank = input_shapes[0].dims();
    for (int dim = 0; dim < rank; ++dim) {
      const int max_dim_size = slice_shape.dim_size(dim) * num_concats_[dim];
      if (paddings_[dim] > max_dim_size) {
        return errors::InvalidArgument(
            "'paddings' must not exceed expected output shape dimension ",
            max_dim_size, " at index ", dim, ", but got ", paddings_[dim], ".");
      }
      output_shape.push_back(max_dim_size - paddings_[dim]);
    }

    return OkStatus();
  }

  std::vector<int64_t> num_concats_;
  int num_slices_ = 1;
  std::vector<int64_t> paddings_;
  bool has_paddings_ = false;
};

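// Concatenates `N` input slices into a single output tensor according to
// 'num_concats', stripping 'paddings' from the result.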
class XlaConcatNDOp : public XlaConcatNDBaseOp {
 public:
  explicit XlaConcatNDOp(OpKernelConstruction* ctx) : XlaConcatNDBaseOp(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto output_or = this->CompileInternal(ctx);
    OP_REQUIRES_OK(ctx, output_or.status());
    ctx->SetOutput(/*index=*/0, output_or.value());
  }
};

REGISTER_XLA_OP(Name("XlaConcatND"), XlaConcatNDOp);

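// Same as XlaConcatND, but assigns the concatenated result to a 'resource'
// variable instead of producing an output tensor.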
class AssignVariableXlaConcatNDOp : public XlaConcatNDBaseOp {
 public:
  explicit AssignVariableXlaConcatNDOp(OpKernelConstruction* ctx)
      : XlaConcatNDBaseOp(ctx) {}

  void Compile(XlaOpKernelContext* ctx) override {
    auto output_or = this->CompileInternal(ctx);
    OP_REQUIRES_OK(ctx, output_or.status());
    OP_REQUIRES_OK(ctx,
                   ctx->AssignVariable("resource", dtype_, output_or.value()));
  }
};

REGISTER_XLA_OP(Name("AssignVariableXlaConcatND"), AssignVariableXlaConcatNDOp);

}  // namespace
}  // namespace tensorflow