• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <limits>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
33 
34 namespace tensorflow {
35 namespace {
36 
37 constexpr absl::string_view kNumSplitsAttrName = "num_splits";
38 constexpr absl::string_view kNumConcatsAttrName = "num_concats";
39 
40 template <bool Split>
GetAndValidateAttributes(OpKernelConstruction * ctx,std::vector<int64_t> & num_partitions,int & num_slices,std::vector<int64_t> & paddings,bool & has_paddings)41 Status GetAndValidateAttributes(OpKernelConstruction* ctx,
42                                 std::vector<int64_t>& num_partitions,
43                                 int& num_slices, std::vector<int64_t>& paddings,
44                                 bool& has_paddings) {
45   absl::string_view num_partitions_attr_name =
46       Split ? kNumSplitsAttrName : kNumConcatsAttrName;
47   TF_RETURN_IF_ERROR(ctx->GetAttr(num_partitions_attr_name, &num_partitions));
48 
49   int num_dims_to_split = 0;
50   for (int i = 0, e = num_partitions.size(); i < e; ++i) {
51     const auto& split = num_partitions[i];
52     if (split <= 0) {
53       return errors::InvalidArgument("'", num_partitions_attr_name,
54                                      "' at index ", i,
55                                      " must be positive, but got ", split, ".");
56     }
57     if (split > 1) {
58       ++num_dims_to_split;
59     }
60     num_slices *= split;
61   }
62 
63   int n;
64   TF_RETURN_IF_ERROR(ctx->GetAttr("N", &n));
65   if (n != num_slices) {
66     return errors::InvalidArgument(
67         "'N' must match number of slices ", num_slices, " from '",
68         num_partitions_attr_name, "', but got ", n, ".");
69   }
70 
71   TF_RETURN_IF_ERROR(ctx->GetAttr("paddings", &paddings));
72   const int expected_rank = num_partitions.size();
73   if (!paddings.empty()) {
74     if (paddings.size() != expected_rank) {
75       return errors::InvalidArgument(
76           "'paddings' length must match '", num_partitions_attr_name,
77           "' length ", expected_rank, ", but got ", paddings.size(), ".");
78     }
79 
80     for (int dim = 0; dim < expected_rank; ++dim) {
81       if (paddings[dim] < 0) {
82         return errors::InvalidArgument(
83             "'padding' must be all non-negative, but got ", paddings[dim],
84             " at index ", dim, ".");
85       }
86       if (paddings[dim] > 0) {
87         has_paddings = true;
88       }
89     }
90   } else {
91     paddings.assign(expected_rank, 0);
92   }
93 
94   return OkStatus();
95 }
96 
GetSliceIndices(absl::Span<const int64> num_partitions,absl::Span<const int64> slice_shape,const int index)97 std::vector<int64_t> GetSliceIndices(absl::Span<const int64> num_partitions,
98                                      absl::Span<const int64> slice_shape,
99                                      const int index) {
100   DCHECK_EQ(num_partitions.size(), slice_shape.size());
101 
102   std::vector<int64_t> slice_indices(num_partitions.size());
103 
104   if (num_partitions.empty()) {
105     return slice_indices;
106   }
107 
108   auto divisor = [&](const int dim) {
109     int divisor = 1;
110     for (int i = num_partitions.size() - 1; i > dim; --i) {
111       divisor *= num_partitions[i];
112     }
113     return divisor;
114   };
115 
116   for (int dim = num_partitions.size() - 1; dim > 0; --dim) {
117     slice_indices[dim] =
118         ((index / divisor(dim)) % num_partitions[dim]) * slice_shape[dim];
119   }
120   slice_indices[0] = (index / divisor(0)) * slice_shape[0];
121 
122   return slice_indices;
123 }
124 
125 constexpr absl::string_view kTensorName = "'input' tensor";
126 constexpr absl::string_view kResourceName = "'resource' variable tensor";
127 
128 template <bool Resource>
129 class XlaSplitNDBaseOp : public XlaOpKernel {
130  public:
XlaSplitNDBaseOp(OpKernelConstruction * ctx)131   explicit XlaSplitNDBaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
132     OP_REQUIRES_OK(ctx,
133                    GetAndValidateAttributes<true>(ctx, num_splits_, num_slices_,
134                                                   paddings_, has_paddings_));
135   }
136 
137  protected:
CompileInternal(XlaOpKernelContext * ctx,const xla::XlaOp input,const TensorShape & input_shape,const DataType input_dtype)138   Status CompileInternal(XlaOpKernelContext* ctx, const xla::XlaOp input,
139                          const TensorShape& input_shape,
140                          const DataType input_dtype) {
141     xla::PrimitiveType type;
142     TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(input_dtype, &type));
143 
144     absl::string_view input_name = Resource ? kResourceName : kTensorName;
145     const int rank = input_shape.dims();
146 
147     if (rank != num_splits_.size()) {
148       return errors::InvalidArgument(
149           input_name, " rank must be the same as 'num_splits' length ",
150           num_splits_.size(), ", but got rank ", rank, ".");
151     }
152 
153     for (int dim = 0; dim < rank; ++dim) {
154       if ((input_shape.dim_size(dim) + paddings_[dim]) % num_splits_[dim] !=
155           0) {
156         return errors::InvalidArgument(
157             input_name, " shape dimension ", dim, " (",
158             input_shape.dim_size(dim), ") with padding ", paddings_[dim],
159             " must be evenly divisible by 'num_splits' ", num_splits_[dim],
160             ".");
161       }
162     }
163 
164     if (num_slices_ == 1 && has_paddings_) {
165       xla::PaddingConfig padding_config;
166       for (int dim = 0; dim < rank; ++dim) {
167         auto* padding_dim = padding_config.add_dimensions();
168         padding_dim->set_edge_padding_low(0);
169         padding_dim->set_edge_padding_high(paddings_[dim]);
170         padding_dim->set_interior_padding(0);
171       }
172       ctx->SetOutput(
173           /*index=*/0,
174           xla::Pad(input,
175                    xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0),
176                    padding_config));
177       return OkStatus();
178     } else if (num_slices_ == 1) {
179       ctx->SetOutput(/*index=*/0, input);
180       return OkStatus();
181     }
182 
183     // Slice shape with optional padding.
184     std::vector<int64_t> slice_shape(rank);
185     for (int dim = 0; dim < rank; ++dim) {
186       slice_shape[dim] =
187           (input_shape.dim_size(dim) + paddings_[dim]) / num_splits_[dim];
188     }
189 
190     const std::vector<int64_t> slice_strides(rank, 1);
191 
192     for (int i = 0; i < num_slices_; ++i) {
193       int num_complete_pad_dims = 0;
194       int num_partial_pad_dims = 0;
195       std::vector<int64_t> slice_start_indices =
196           GetSliceIndices(num_splits_, slice_shape, i);
197       std::vector<int64_t> slice_limit_indices(slice_shape.size());
198       xla::PaddingConfig slice_padding_config;
199       for (int dim = 0; dim < rank; ++dim) {
200         auto* padding_dim = slice_padding_config.add_dimensions();
201         padding_dim->set_edge_padding_low(0);
202         padding_dim->set_edge_padding_high(0);
203         padding_dim->set_interior_padding(0);
204       }
205 
206       // Calculate paddings necessary for slice instead of padding input and
207       // slicing subsequently to reduce temporary memory allocation.
208       for (int dim = 0; dim < rank; ++dim) {
209         const int64 dim_size = input_shape.dim_size(dim);
210         if (slice_start_indices[dim] >= dim_size) {
211           // Complete padding.
212           slice_start_indices[dim] = dim_size;
213           slice_limit_indices[dim] = dim_size;
214           slice_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
215               slice_shape[dim]);
216           ++num_complete_pad_dims;
217         } else if (slice_start_indices[dim] + slice_shape[dim] > dim_size) {
218           // Partial padding.
219           slice_limit_indices[dim] = dim_size;
220           slice_padding_config.mutable_dimensions(dim)->set_edge_padding_high(
221               slice_start_indices[dim] + slice_shape[dim] - dim_size);
222           ++num_partial_pad_dims;
223         } else {
224           slice_limit_indices[dim] =
225               slice_start_indices[dim] + slice_shape[dim];
226         }
227       }
228 
229       if (num_complete_pad_dims == rank) {
230         ctx->SetOutput(i, xla::Broadcast(xla::ConstantR0WithType(
231                                              ctx->builder(), type, /*value=*/0),
232                                          slice_shape));
233       } else if (num_complete_pad_dims > 0 || num_partial_pad_dims > 0) {
234         ctx->SetOutput(
235             i,
236             xla::Pad(xla::Slice(input, slice_start_indices, slice_limit_indices,
237                                 slice_strides),
238                      xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0),
239                      slice_padding_config));
240       } else {
241         ctx->SetOutput(i, xla::Slice(input, slice_start_indices,
242                                      slice_limit_indices, slice_strides));
243       }
244     }
245     return OkStatus();
246   }
247 
248  private:
249   std::vector<int64_t> num_splits_;
250   int num_slices_ = 1;
251   std::vector<int64_t> paddings_;
252   bool has_paddings_ = false;
253 };
254 
255 class XlaSplitNDOp : public XlaSplitNDBaseOp<false> {
256  public:
XlaSplitNDOp(OpKernelConstruction * ctx)257   explicit XlaSplitNDOp(OpKernelConstruction* ctx)
258       : XlaSplitNDBaseOp<false>(ctx) {}
259 
Compile(XlaOpKernelContext * ctx)260   void Compile(XlaOpKernelContext* ctx) override {
261     OP_REQUIRES_OK(ctx,
262                    this->CompileInternal(ctx, ctx->Input(0), ctx->InputShape(0),
263                                          ctx->input_type(0)));
264   }
265 };
266 
267 REGISTER_XLA_OP(Name("XlaSplitND"), XlaSplitNDOp);
268 
269 class ReadVariableXlaSplitNDOp : public XlaSplitNDBaseOp<true> {
270  public:
ReadVariableXlaSplitNDOp(OpKernelConstruction * ctx)271   explicit ReadVariableXlaSplitNDOp(OpKernelConstruction* ctx)
272       : XlaSplitNDBaseOp<true>(ctx) {
273     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
274   }
275 
Compile(XlaOpKernelContext * ctx)276   void Compile(XlaOpKernelContext* ctx) override {
277     DataType variable_input_dtype;
278     TensorShape variable_input_shape;
279     OP_REQUIRES_OK(
280         ctx, ctx->GetVariableTypeAndShape(/*index=*/0, &variable_input_dtype,
281                                           &variable_input_shape));
282     OP_REQUIRES(
283         ctx, variable_input_dtype == dtype_,
284         errors::InvalidArgument("'T' must match 'resource' variable dtype ",
285                                 DataTypeString(variable_input_dtype),
286                                 ", but got ", dtype_));
287 
288     xla::XlaOp handle;
289     OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(/*index=*/0, dtype_,
290                                                /*shape=*/nullptr, &handle));
291 
292     OP_REQUIRES_OK(
293         ctx, this->CompileInternal(ctx, handle, variable_input_shape, dtype_));
294   }
295 
296  private:
297   DataType dtype_;
298 };
299 
300 REGISTER_XLA_OP(Name("ReadVariableXlaSplitND"), ReadVariableXlaSplitNDOp);
301 
302 class XlaConcatNDBaseOp : public XlaOpKernel {
303  public:
XlaConcatNDBaseOp(OpKernelConstruction * ctx)304   explicit XlaConcatNDBaseOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
305     OP_REQUIRES_OK(
306         ctx, GetAndValidateAttributes<false>(ctx, num_concats_, num_slices_,
307                                              paddings_, has_paddings_));
308     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype_));
309   }
310 
311  protected:
CompileInternal(XlaOpKernelContext * ctx)312   StatusOr<xla::XlaOp> CompileInternal(XlaOpKernelContext* ctx) {
313     xla::PrimitiveType type;
314     TF_RETURN_IF_ERROR(DataTypeToPrimitiveType(dtype_, &type));
315 
316     std::vector<xla::XlaOp> input_handles;
317     std::vector<TensorShape> input_shapes;
318     std::vector<int64_t> output_shape;
319     TF_RETURN_IF_ERROR(GetInputsAndOutputShape(ctx, input_handles, input_shapes,
320                                                output_shape));
321 
322     const int rank = output_shape.size();
323 
324     if (num_slices_ == 1 && has_paddings_) {
325       return xla::Slice(input_handles[0],
326                         /*start_indices=*/std::vector<int64_t>(rank, 0),
327                         /*limit_indices=*/output_shape,
328                         /*strides=*/std::vector<int64_t>(rank, 1));
329     } else if (num_slices_ == 1) {
330       return input_handles[0];
331     }
332 
333     auto slice_shape = input_shapes[0].dim_sizes();
334     xla::XlaOp output = xla::Broadcast(
335         xla::ConstantR0WithType(ctx->builder(), type, /*value=*/0),
336         output_shape);
337     const std::vector<int64_t> input_slice_start_indices(rank, 0);
338     const std::vector<int64_t> slice_strides(rank, 1);
339 
340     for (int i = 0; i < num_slices_; ++i) {
341       std::vector<int64_t> slice_start_indices =
342           GetSliceIndices(num_concats_, slice_shape, i);
343 
344       int num_complete_pad_dims = 0;
345       int num_partial_pad_dims = 0;
346       std::vector<int64_t> slice_limit_indices(rank);
347 
348       // Calculate paddings necessary to strip from slice.
349       for (int dim = 0; dim < rank; ++dim) {
350         const int64_t dim_size = output_shape[dim];
351         if (slice_start_indices[dim] >= dim_size) {
352           // Complete padding.
353           slice_start_indices[dim] = dim_size;
354           slice_limit_indices[dim] = dim_size;
355           ++num_complete_pad_dims;
356         } else if (slice_start_indices[dim] + slice_shape[dim] > dim_size) {
357           // Partial padding.
358           slice_limit_indices[dim] = dim_size;
359           ++num_partial_pad_dims;
360         } else {
361           slice_limit_indices[dim] =
362               slice_start_indices[dim] + slice_shape[dim];
363         }
364       }
365 
366       if (num_complete_pad_dims == rank) {
367         continue;
368       }
369 
370       xla::XlaOp input_slice = input_handles[i];
371       if (num_complete_pad_dims > 0 || num_partial_pad_dims > 0) {
372         std::vector<int64_t> input_slice_limit_indices(rank);
373         for (int dim = 0; dim < rank; ++dim) {
374           input_slice_limit_indices[dim] =
375               slice_limit_indices[dim] - slice_start_indices[dim];
376         }
377         input_slice = xla::Slice(input_slice, input_slice_start_indices,
378                                  input_slice_limit_indices, slice_strides);
379       }
380 
381       std::vector<xla::XlaOp> update_slice_start_indices;
382       update_slice_start_indices.reserve(rank);
383       for (int64 start_index : slice_start_indices) {
384         update_slice_start_indices.push_back(
385             xla::ConstantR0<int32>(ctx->builder(), start_index));
386       }
387       output = xla::DynamicUpdateSlice(output, input_slice,
388                                        update_slice_start_indices);
389     }
390 
391     return output;
392   }
393 
394   DataType dtype_;
395 
396  private:
GetInputsAndOutputShape(XlaOpKernelContext * ctx,std::vector<xla::XlaOp> & input_handles,std::vector<TensorShape> & input_shapes,std::vector<int64_t> & output_shape)397   Status GetInputsAndOutputShape(XlaOpKernelContext* ctx,
398                                  std::vector<xla::XlaOp>& input_handles,
399                                  std::vector<TensorShape>& input_shapes,
400                                  std::vector<int64_t>& output_shape) {
401     TF_RETURN_IF_ERROR(ctx->InputList("inputs", &input_handles, &input_shapes));
402 
403     const TensorShape& slice_shape = input_shapes[0];
404     if (slice_shape.dims() != num_concats_.size()) {
405       return errors::InvalidArgument(
406           "'inputs' rank must be the same as 'num_concats' length ",
407           num_concats_.size(), ", but got rank ", slice_shape.dims(), ".");
408     }
409     for (int i = 1; i < num_slices_; ++i) {
410       const TensorShape& slice_shape_i = input_shapes[i];
411       if (slice_shape != slice_shape_i) {
412         return errors::InvalidArgument(
413             "'inputs' must all have the same expected shape ", slice_shape,
414             ", but got ", slice_shape_i, " at index ", i, ".");
415       }
416     }
417 
418     const int rank = input_shapes[0].dims();
419     for (int dim = 0; dim < rank; ++dim) {
420       const int max_dim_size = slice_shape.dim_size(dim) * num_concats_[dim];
421       if (paddings_[dim] > max_dim_size) {
422         return errors::InvalidArgument(
423             "'paddings' must not exceed expected output shape dimension ",
424             max_dim_size, " at index ", dim, ", but got ", paddings_[dim], ".");
425       }
426       output_shape.push_back(max_dim_size - paddings_[dim]);
427     }
428 
429     return OkStatus();
430   }
431 
432   std::vector<int64_t> num_concats_;
433   int num_slices_ = 1;
434   std::vector<int64_t> paddings_;
435   bool has_paddings_ = false;
436 };
437 
438 class XlaConcatNDOp : public XlaConcatNDBaseOp {
439  public:
XlaConcatNDOp(OpKernelConstruction * ctx)440   explicit XlaConcatNDOp(OpKernelConstruction* ctx) : XlaConcatNDBaseOp(ctx) {}
441 
Compile(XlaOpKernelContext * ctx)442   void Compile(XlaOpKernelContext* ctx) override {
443     auto output_or = this->CompileInternal(ctx);
444     OP_REQUIRES_OK(ctx, output_or.status());
445     ctx->SetOutput(/*index=*/0, output_or.ValueOrDie());
446   }
447 };
448 
449 REGISTER_XLA_OP(Name("XlaConcatND"), XlaConcatNDOp);
450 
451 class AssignVariableXlaConcatNDOp : public XlaConcatNDBaseOp {
452  public:
AssignVariableXlaConcatNDOp(OpKernelConstruction * ctx)453   explicit AssignVariableXlaConcatNDOp(OpKernelConstruction* ctx)
454       : XlaConcatNDBaseOp(ctx) {}
455 
Compile(XlaOpKernelContext * ctx)456   void Compile(XlaOpKernelContext* ctx) override {
457     auto output_or = this->CompileInternal(ctx);
458     OP_REQUIRES_OK(ctx, output_or.status());
459     OP_REQUIRES_OK(ctx,
460                    ctx->AssignVariable("resource", dtype_, output_or.value()));
461   }
462 };
463 
464 REGISTER_XLA_OP(Name("AssignVariableXlaConcatND"), AssignVariableXlaConcatNDOp);
465 
466 }  // namespace
467 }  // namespace tensorflow
468