/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

class DynamicPartitionOp : public XlaOpKernel {
 public:
  explicit DynamicPartitionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_partitions", &num_partitions_));
  }

  // Returns an S32 scalar representing how many items in `input` are equal to
  // `target`.
  xla::XlaOp CountS32(XlaOpKernelContext* ctx, xla::XlaOp input,
                      int64_t target) {
    xla::XlaOp equal_dim =
        xla::Compare(input, xla::ConstantR0<int32>(ctx->builder(), target), {},
                     xla::ComparisonDirection::kEq);
    xla::XlaOp casted = xla::ConvertElementType(equal_dim, xla::S32);
    return xla::ReduceAll(
        casted, xla::Zero(ctx->builder(), xla::S32),
        xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
  }

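  // Partitions the 1-D `data_1d` according to `partitions_1d` by stably
  // sorting (partition index, value) pairs and slicing one bounded-size
  // segment per partition. Returns, for each partition, a slice padded to the
  // full input length together with its actual (dynamic) length.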
  std::pair<std::vector<xla::XlaOp>, std::vector<xla::XlaOp>>
  DynamicPartition1D(XlaOpKernelContext* ctx, xla::XlaOp data_1d,
                     xla::XlaOp partitions_1d, const xla::Shape& data_1d_shape,
                     const xla::Shape& partition_1d_shape) {
    int64_t input_count = data_1d_shape.dimensions(0);
    std::vector<xla::XlaOp> to_sort = {partitions_1d, data_1d};
    std::vector<xla::PrimitiveType> types_to_sort = {
        partition_1d_shape.element_type(), data_1d_shape.element_type()};
    xla::XlaOp sorted = xla::Sort(
        to_sort, xla::CreateScalarLtComputation(types_to_sort, ctx->builder()),
        /*dimension=*/0,
        /*is_stable=*/true);
    xla::XlaOp sorted_partitions = xla::GetTupleElement(sorted, 0);
    xla::XlaOp sorted_data = xla::GetTupleElement(sorted, 1);

    // `partition_length[i]` is the length of partition_i.
    std::vector<xla::XlaOp> partition_length(num_partitions_);
    // `partition_start[i]` is sum(partition_length[0:i]).
    std::vector<xla::XlaOp> partition_start(num_partitions_);
    xla::XlaOp count_so_far = xla::Zero(ctx->builder(), xla::S32);
    for (int64_t i = 0; i < num_partitions_; ++i) {
      xla::XlaOp count = CountS32(ctx, sorted_partitions, /*target=*/i);
      partition_length[i] = count;
      partition_start[i] = count_so_far;
      count_so_far = xla::Add(count_so_far, count);
    }

    // Pad the input with `input_count` zeros to avoid OOB -- a dynamic slice
    // with an OOB start produces an undefined result.
    xla::PaddingConfig padding_config;
    auto* dims = padding_config.add_dimensions();
    dims->set_edge_padding_low(0);
    dims->set_edge_padding_high(input_count);
    dims->set_interior_padding(0);
    auto padded_data =
        xla::Pad(sorted_data, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
                 padding_config);
    std::vector<xla::XlaOp> output(num_partitions_);
    for (int64_t i = 0; i < num_partitions_; ++i) {
      // The dynamic size will be set later, after this function returns.
      padded_data = xla::RemoveDynamicDimension(padded_data, 0);
      // Slice a full-size window out of the input starting from the offset.
      auto sliced =
          xla::DynamicSlice(padded_data, {partition_start[i]}, {input_count});
      output[i] = sliced;
    }
    return {output, partition_length};
  }

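  // Lowers DynamicPartition by flattening `data` and `partitions` to 1-D,
  // partitioning the flattened input with DynamicPartition1D, and reshaping
  // each result back to the bounded output shape with its dynamic size set.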
  void Compile(XlaOpKernelContext* ctx) override {
    xla::Shape data_shape = ctx->InputXlaShape(0).value();
    xla::Shape partition_shape = ctx->InputXlaShape(1).value();
    xla::XlaOp data = ctx->Input(0);
    xla::XlaOp partitions = ctx->Input(1);
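    // If the partition indices are compile-time constants, the exact output
    // sizes can be computed statically instead of via SetDimensionSize.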
    std::vector<int64_t> partitions_static;
    bool partitions_are_static =
        ctx->ConstantInputReshapedToIntVector(1, &partitions_static).ok();
    // We know how to solve DynamicPartition on 1D inputs using
    // DynamicPartition1D. For other inputs, we do three things:
    //
    // 1. If partition_shape has lower rank than data_shape, we broadcast
    // partitions so that its shape matches data_shape.
    //
    // 2. If data_shape has rank higher than 1, we reshape both data and
    // partitions to R1. This reduces the problem to 1D, which we've already
    // solved using DynamicPartition1D.
    //
    // 3. We reshape the result of DynamicPartition1D from 1D back to the
    // output shape.
    if (data_shape.rank() > partition_shape.rank()) {
      // Broadcast partitions so that its shape is the same as data_shape.
      std::vector<int64_t> broadcasted_dims;
      auto rank = partition_shape.rank();
      broadcasted_dims.reserve(rank);
      for (int64_t i = 0; i < rank; ++i) {
        broadcasted_dims.push_back(i);
      }
      partitions = xla::BroadcastInDim(partitions, data_shape.dimensions(),
                                       broadcasted_dims);
    }

    // The bound on the output shape is calculated as
    // [count(partitions)] + data.shape[partitions.ndim:]
    // See also the output shape calculation at
    // https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
    std::vector<int64_t> output_shape_bound_dims;
    output_shape_bound_dims.push_back(
        xla::ShapeUtil::ElementsIn(partition_shape));
    int64_t count_diff = 1;
    for (int64_t i = partition_shape.rank(); i < data_shape.rank(); ++i) {
      output_shape_bound_dims.push_back(data_shape.dimensions(i));
      count_diff *= data_shape.dimensions(i);
    }

    int64_t input_count = xla::ShapeUtil::ElementsIn(data_shape);
    auto data_1d = xla::Reshape(data, {input_count});
    auto partitions_1d = xla::Reshape(partitions, {input_count});
    xla::Shape data_1d_shape =
        xla::ShapeUtil::MakeShape(data_shape.element_type(), {input_count});

    xla::Shape partitions_1d_shape = xla::ShapeUtil::MakeShape(
        partition_shape.element_type(), {input_count});

    std::vector<xla::XlaOp> output, partition_length;
    std::tie(output, partition_length) = DynamicPartition1D(
        ctx, data_1d, partitions_1d, data_1d_shape, partitions_1d_shape);
    for (int64_t i = 0; i < num_partitions_; ++i) {
      auto reshape = xla::Reshape(output[i], output_shape_bound_dims);
      if (partitions_are_static) {
        int64_t size = absl::c_count(partitions_static, i);
        ctx->SetOutput(i, xla::SliceInDim(reshape, 0, size, 1, 0));
      } else {
        xla::XlaOp length;
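        // `partition_length[i]` counts flattened data elements; dividing by
        // `count_diff` (the number of data elements per partitioned row)
        // yields the number of rows in this output partition.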
        if (count_diff != 0) {
          length = xla::Div(partition_length[i],
                            xla::ConstantR0<int32>(ctx->builder(), count_diff));
        } else {
          length = CountS32(ctx, ctx->Input(1), /*target=*/i);
        }
        ctx->SetOutput(i, xla::SetDimensionSize(reshape, length, 0));
      }
    }
  }

 private:
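  // Number of output partitions, taken from the "num_partitions" attribute.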
  int64_t num_partitions_;
};

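// Register this kernel as the XLA lowering of the DynamicPartition op.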
REGISTER_XLA_OP(Name("DynamicPartition"), DynamicPartitionOp);

}  // namespace
}  // namespace tensorflow