/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <vector>

#include "absl/algorithm/container.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/comparators.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/comparison_util.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/ops_util.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/tpu_defs.h"

namespace tensorflow {
namespace {

class DynamicPartitionOp : public XlaOpKernel {
 public:
  explicit DynamicPartitionOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
    OP_REQUIRES_OK(ctx, ctx->GetAttr("num_partitions", &num_partitions_));
  }

  // Returns an S32 scalar representing how many items in `input` are equal
  // to `target`.
  xla::XlaOp CountS32(XlaOpKernelContext* ctx, xla::XlaOp input,
                      int64_t target) {
    xla::XlaOp equal_dim =
        xla::Compare(input, xla::ConstantR0<int32>(ctx->builder(), target), {},
                     xla::ComparisonDirection::kEq);
    xla::XlaOp casted = xla::ConvertElementType(equal_dim, xla::S32);
    return xla::ReduceAll(
        casted, xla::Zero(ctx->builder(), xla::S32),
        xla::CreateScalarAddComputation(xla::S32, ctx->builder()));
  }

  std::pair<std::vector<xla::XlaOp>, std::vector<xla::XlaOp>>
  DynamicPartition1D(XlaOpKernelContext* ctx, xla::XlaOp data_1d,
                     xla::XlaOp partitions_1d, const xla::Shape& data_1d_shape,
                     const xla::Shape& partition_1d_shape) {
    int64_t input_count = data_1d_shape.dimensions(0);
    std::vector<xla::XlaOp> to_sort = {partitions_1d, data_1d};
    std::vector<xla::PrimitiveType> types_to_sort = {
        partition_1d_shape.element_type(), data_1d_shape.element_type()};
    xla::XlaOp sorted = xla::Sort(
        to_sort, xla::CreateScalarLtComputation(types_to_sort, ctx->builder()),
        /*dimension=*/0,
        /*is_stable=*/true);
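    // A sketch of the key-value sort above (values are illustrative): with
    // partitions_1d = [1, 0, 1, 0] and data_1d = [a, b, c, d], sorting keyed
    // on partitions_1d yields
    //   sorted_partitions = [0, 0, 1, 1]
    //   sorted_data       = [b, d, a, c]
    // so each partition's elements become contiguous, and /*is_stable=*/true
    // preserves their original relative order within each partition.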
    xla::XlaOp sorted_partitions = xla::GetTupleElement(sorted, 0);
    xla::XlaOp sorted_data = xla::GetTupleElement(sorted, 1);

    // `partition_length[i]` is the length of partition_i.
    std::vector<xla::XlaOp> partition_length(num_partitions_);
    // `partition_start[i]` is sum(partition_length[0:i]), the offset of
    // partition_i within the sorted data.
    std::vector<xla::XlaOp> partition_start(num_partitions_);
    xla::XlaOp count_so_far = xla::Zero(ctx->builder(), xla::S32);
    for (int64_t i = 0; i < num_partitions_; ++i) {
      xla::XlaOp count = CountS32(ctx, sorted_partitions, /*target=*/i);
      partition_length[i] = count;
      partition_start[i] = count_so_far;
      count_so_far = xla::Add(count_so_far, count);
    }

    // Pad the input with `input_count` zeros to avoid OOB access -- a
    // dynamic slice with an OOB start index produces an undefined result.
    xla::PaddingConfig padding_config;
    auto* dims = padding_config.add_dimensions();
    dims->set_edge_padding_low(0);
    dims->set_edge_padding_high(input_count);
    dims->set_interior_padding(0);
    auto padded_data =
        xla::Pad(sorted_data, xla::Zero(ctx->builder(), ctx->input_xla_type(0)),
                 padding_config);
    std::vector<xla::XlaOp> output(num_partitions_);
    for (int64_t i = 0; i < num_partitions_; ++i) {
      // The dynamic size will be set later, after this function returns.
      padded_data = xla::RemoveDynamicDimension(padded_data, 0);
      // Slice a full-size window out of the input, starting at the offset.
      auto sliced =
          xla::DynamicSlice(padded_data, {partition_start[i]}, {input_count});
      output[i] = sliced;
    }
    return {output, partition_length};
  }

  void Compile(XlaOpKernelContext* ctx) override {
    xla::Shape data_shape = ctx->InputXlaShape(0).value();
    xla::Shape partition_shape = ctx->InputXlaShape(1).value();
    xla::XlaOp data = ctx->Input(0);
    xla::XlaOp partitions = ctx->Input(1);
    std::vector<int64_t> partitions_static;
    bool partitions_are_static =
        ctx->ConstantInputReshapedToIntVector(1, &partitions_static).ok();
    // We know how to solve DynamicPartition on 1D inputs using
    // DynamicPartition1D. For inputs of other ranks, we do three things
    // (illustrated with an example after this comment):
    //
    // 1. If partition_shape has lower rank than data_shape, we broadcast
    //    partitions so that its shape matches data_shape.
    //
    // 2. If data_shape has rank higher than 1, we reshape both data and
    //    partitions to R1. This reduces the problem to 1D, which we have
    //    already solved with DynamicPartition1D.
    //
    // 3. We reshape each result of DynamicPartition1D from 1D back to the
    //    output shape.
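    // A hypothetical instance of steps 1-3: with data of shape [2, 3] and
    // partitions of shape [2], step 1 broadcasts partitions to shape [2, 3]
    // (every element of a row gets that row's partition id), step 2 flattens
    // both operands to R1 of size 6, and step 3 reshapes each 1D result back
    // to the output bound [2, 3]; the flattened per-partition element count
    // is later divided by 3 (`count_diff` below) to recover the row count.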
    if (data_shape.rank() > partition_shape.rank()) {
      // Broadcast partitions so that its shape matches data_shape.
      std::vector<int64_t> broadcasted_dims;
      auto rank = partition_shape.rank();
      broadcasted_dims.reserve(rank);
      for (int64_t i = 0; i < rank; ++i) {
        broadcasted_dims.push_back(i);
      }
      partitions = xla::BroadcastInDim(partitions, data_shape.dimensions(),
                                       broadcasted_dims);
    }

    // The bound on the output shape is calculated as
    //   [count(partitions)] + data.shape[partitions.ndim:]
    // See also the output shape calculation at
    // https://www.tensorflow.org/api_docs/python/tf/dynamic_partition
    std::vector<int64_t> output_shape_bound_dims;
    output_shape_bound_dims.push_back(
        xla::ShapeUtil::ElementsIn(partition_shape));
    int64_t count_diff = 1;
    for (int64_t i = partition_shape.rank(); i < data_shape.rank(); ++i) {
      output_shape_bound_dims.push_back(data_shape.dimensions(i));
      count_diff *= data_shape.dimensions(i);
    }

    int64_t input_count = xla::ShapeUtil::ElementsIn(data_shape);
    auto data_1d = xla::Reshape(data, {input_count});
    auto partitions_1d = xla::Reshape(partitions, {input_count});
    xla::Shape data_1d_shape =
        xla::ShapeUtil::MakeShape(data_shape.element_type(), {input_count});
    xla::Shape partitions_1d_shape = xla::ShapeUtil::MakeShape(
        partition_shape.element_type(), {input_count});

    std::vector<xla::XlaOp> output, partition_length;
    std::tie(output, partition_length) = DynamicPartition1D(
        ctx, data_1d, partitions_1d, data_1d_shape, partitions_1d_shape);
    for (int64_t i = 0; i < num_partitions_; ++i) {
      auto reshape = xla::Reshape(output[i], output_shape_bound_dims);
      if (partitions_are_static) {
        // The partition sizes are known at compile time, so slice each
        // output to its exact size.
        int64_t size = absl::c_count(partitions_static, i);
        ctx->SetOutput(i, xla::SliceInDim(reshape, 0, size, 1, 0));
      } else {
        xla::XlaOp length;
        if (count_diff != 0) {
          // Convert the flattened element count into a row count.
          length = xla::Div(partition_length[i],
                            xla::ConstantR0<int32>(ctx->builder(), count_diff));
        } else {
          // The trailing dimensions are zero-sized; count the partition ids
          // in the original input directly.
          length = CountS32(ctx, ctx->Input(1), /*target=*/i);
        }
        ctx->SetOutput(i, xla::SetDimensionSize(reshape, length, 0));
      }
    }
  }

 private:
  int64_t num_partitions_;
};

REGISTER_XLA_OP(Name("DynamicPartition"), DynamicPartitionOp);

}  // namespace
}  // namespace tensorflow
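// Usage sketch (op semantics per the tf.dynamic_partition documentation,
// values illustrative): with
//   data           = [10, 20, 30, 40, 50]
//   partitions     = [0, 0, 1, 1, 0]
//   num_partitions = 2
// the op produces
//   output[0] = [10, 20, 50]
//   output[1] = [30, 40]
// When `partitions` is not a compile-time constant, this kernel emits each
// output at its bound size, with xla::SetDimensionSize carrying the true
// dynamic length of dimension 0.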