1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/lib/scatter.h" 17 #include "tensorflow/compiler/tf2xla/type_util.h" 18 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" 20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 21 #include "tensorflow/compiler/xla/client/lib/constants.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 24 namespace tensorflow { 25 namespace { 26 27 class UnsortedSegmentReduce : public XlaOpKernel { 28 public: UnsortedSegmentReduce(OpKernelConstruction * ctx)29 explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 30 DataType dtype; 31 OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype)); 32 OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_)); 33 } 34 35 // The initial value to initialize elements of the output to. 36 virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0; 37 38 // A function to combine two scalars with the same index (e.g., sum). 39 virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0; 40 Compile(XlaOpKernelContext * ctx)41 void Compile(XlaOpKernelContext* ctx) override { 42 // output = unsorted_segment_sum(data, indices, num_segments) 43 // Compute a tensor such that: 44 // output[i] = sum over {j where indices[j] == i} of data[j] 45 // output[i] == 0 if i does not appear in indices 46 // 47 // Contrast with segment_sum(), which assumes indices are sorted and that 48 // max(indices)+1 is the desired size of the output. 49 // 50 // The returned output tensor has the same type as data, and the same shape 51 // as data with the first indices.rank dimensions are replaced 52 // by a single dimension with size num_segments. 53 auto data = ctx->Input(0); 54 TensorShape data_shape = ctx->InputShape(0); 55 56 auto indices = ctx->Input(1); 57 TensorShape indices_shape = ctx->InputShape(1); 58 59 int64 num_segments; 60 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments)); 61 62 OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(), 63 errors::InvalidArgument(type_string(), 64 " requires that indices' rank be" 65 " less than or equal to data's rank.")); 66 // Validate that indices.shape is a prefix of data.shape. 67 for (int d = 0; d < indices_shape.dims(); ++d) { 68 OP_REQUIRES( 69 ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)), 70 errors::InvalidArgument(type_string(), 71 " requires indices shape to be prefix" 72 " of data_shape, but dimension ", 73 d, " differs ", data_shape.dim_size(d), 74 " vs. ", indices_shape.dim_size(d))); 75 } 76 xla::XlaBuilder* builder = ctx->builder(); 77 TensorShape buffer_shape = data_shape; 78 buffer_shape.RemoveDimRange(0, indices_shape.dims()); 79 buffer_shape.InsertDim(0, num_segments); 80 auto buffer = 81 xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes()); 82 83 auto combiner = [this](xla::XlaOp a, xla::XlaOp b, 84 xla::XlaBuilder* builder) { return Combine(a, b); }; 85 86 auto result = XlaScatter(buffer, /*updates=*/data, indices, 87 /*indices_are_vectors=*/false, combiner, builder); 88 OP_REQUIRES_OK(ctx, result.status()); 89 ctx->SetOutput(0, result.ValueOrDie()); 90 } 91 92 protected: 93 xla::PrimitiveType type_; 94 }; 95 96 class UnsortedSegmentSum : public UnsortedSegmentReduce { 97 public: UnsortedSegmentSum(OpKernelConstruction * ctx)98 explicit UnsortedSegmentSum(OpKernelConstruction* ctx) 99 : UnsortedSegmentReduce(ctx) {} 100 InitialValue(xla::XlaBuilder * builder)101 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 102 return xla::Zero(builder, type_); 103 }; Combine(xla::XlaOp a,xla::XlaOp b)104 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; 105 }; 106 107 REGISTER_XLA_OP( 108 Name("UnsortedSegmentSum").CompileTimeConstantInput("num_segments"), 109 UnsortedSegmentSum); 110 111 class UnsortedSegmentProd : public UnsortedSegmentReduce { 112 public: UnsortedSegmentProd(OpKernelConstruction * ctx)113 explicit UnsortedSegmentProd(OpKernelConstruction* ctx) 114 : UnsortedSegmentReduce(ctx) {} 115 InitialValue(xla::XlaBuilder * builder)116 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 117 return xla::One(builder, type_); 118 }; Combine(xla::XlaOp a,xla::XlaOp b)119 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; 120 }; 121 122 REGISTER_XLA_OP( 123 Name("UnsortedSegmentProd").CompileTimeConstantInput("num_segments"), 124 UnsortedSegmentProd); 125 126 class UnsortedSegmentMin : public UnsortedSegmentReduce { 127 public: UnsortedSegmentMin(OpKernelConstruction * ctx)128 explicit UnsortedSegmentMin(OpKernelConstruction* ctx) 129 : UnsortedSegmentReduce(ctx) {} 130 InitialValue(xla::XlaBuilder * builder)131 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 132 return xla::MaxFiniteValue(builder, type_); 133 }; Combine(xla::XlaOp a,xla::XlaOp b)134 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { 135 return xla::Min(a, b); 136 }; 137 }; 138 139 REGISTER_XLA_OP( 140 Name("UnsortedSegmentMin").CompileTimeConstantInput("num_segments"), 141 UnsortedSegmentMin); 142 143 class UnsortedSegmentMax : public UnsortedSegmentReduce { 144 public: UnsortedSegmentMax(OpKernelConstruction * ctx)145 explicit UnsortedSegmentMax(OpKernelConstruction* ctx) 146 : UnsortedSegmentReduce(ctx) {} 147 InitialValue(xla::XlaBuilder * builder)148 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { 149 return xla::MinFiniteValue(builder, type_); 150 }; Combine(xla::XlaOp a,xla::XlaOp b)151 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { 152 return xla::Max(a, b); 153 }; 154 }; 155 156 REGISTER_XLA_OP( 157 Name("UnsortedSegmentMax").CompileTimeConstantInput("num_segments"), 158 UnsortedSegmentMax); 159 160 } // namespace 161 } // namespace tensorflow 162