• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/lib/scatter.h"
17 #include "tensorflow/compiler/tf2xla/type_util.h"
18 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
19 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
21 #include "tensorflow/compiler/xla/client/lib/constants.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 
24 namespace tensorflow {
25 namespace {
26 
27 class UnsortedSegmentReduce : public XlaOpKernel {
28  public:
UnsortedSegmentReduce(OpKernelConstruction * ctx)29   explicit UnsortedSegmentReduce(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
30     DataType dtype;
31     OP_REQUIRES_OK(ctx, ctx->GetAttr("T", &dtype));
32     OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(dtype, &type_));
33   }
34 
35   // The initial value to initialize elements of the output to.
36   virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
37 
38   // A function to combine two scalars with the same index (e.g., sum).
39   virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0;
40 
Compile(XlaOpKernelContext * ctx)41   void Compile(XlaOpKernelContext* ctx) override {
42     // output = unsorted_segment_sum(data, indices, num_segments)
43     // Compute a tensor such that:
44     //    output[i] = sum over {j where indices[j] == i} of data[j]
45     //    output[i] == 0 if i does not appear in indices
46     //
47     // Contrast with segment_sum(), which assumes indices are sorted and that
48     // max(indices)+1 is the desired size of the output.
49     //
50     // The returned output tensor has the same type as data, and the same shape
51     // as data with the first indices.rank dimensions are replaced
52     // by a single dimension with size num_segments.
53     auto data = ctx->Input(0);
54     TensorShape data_shape = ctx->InputShape(0);
55 
56     auto indices = ctx->Input(1);
57     TensorShape indices_shape = ctx->InputShape(1);
58 
59     int64 num_segments;
60     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &num_segments));
61 
62     OP_REQUIRES(ctx, data_shape.dims() >= indices_shape.dims(),
63                 errors::InvalidArgument(type_string(),
64                                         " requires that indices' rank be"
65                                         " less than or equal to data's rank."));
66     // Validate that indices.shape is a prefix of data.shape.
67     for (int d = 0; d < indices_shape.dims(); ++d) {
68       OP_REQUIRES(
69           ctx, (data_shape.dim_size(d) == indices_shape.dim_size(d)),
70           errors::InvalidArgument(type_string(),
71                                   " requires indices shape to be prefix"
72                                   " of data_shape, but dimension ",
73                                   d, " differs ", data_shape.dim_size(d),
74                                   " vs. ", indices_shape.dim_size(d)));
75     }
76     xla::XlaBuilder* builder = ctx->builder();
77     TensorShape buffer_shape = data_shape;
78     buffer_shape.RemoveDimRange(0, indices_shape.dims());
79     buffer_shape.InsertDim(0, num_segments);
80     auto buffer =
81         xla::Broadcast(InitialValue(builder), buffer_shape.dim_sizes());
82 
83     auto combiner = [this](xla::XlaOp a, xla::XlaOp b,
84                            xla::XlaBuilder* builder) { return Combine(a, b); };
85 
86     auto result = XlaScatter(buffer, /*updates=*/data, indices,
87                              /*indices_are_vectors=*/false, combiner, builder);
88     OP_REQUIRES_OK(ctx, result.status());
89     ctx->SetOutput(0, result.ValueOrDie());
90   }
91 
92  protected:
93   xla::PrimitiveType type_;
94 };
95 
96 class UnsortedSegmentSum : public UnsortedSegmentReduce {
97  public:
UnsortedSegmentSum(OpKernelConstruction * ctx)98   explicit UnsortedSegmentSum(OpKernelConstruction* ctx)
99       : UnsortedSegmentReduce(ctx) {}
100 
InitialValue(xla::XlaBuilder * builder)101   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
102     return xla::Zero(builder, type_);
103   };
Combine(xla::XlaOp a,xla::XlaOp b)104   xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; };
105 };
106 
107 REGISTER_XLA_OP(
108     Name("UnsortedSegmentSum").CompileTimeConstantInput("num_segments"),
109     UnsortedSegmentSum);
110 
111 class UnsortedSegmentProd : public UnsortedSegmentReduce {
112  public:
UnsortedSegmentProd(OpKernelConstruction * ctx)113   explicit UnsortedSegmentProd(OpKernelConstruction* ctx)
114       : UnsortedSegmentReduce(ctx) {}
115 
InitialValue(xla::XlaBuilder * builder)116   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
117     return xla::One(builder, type_);
118   };
Combine(xla::XlaOp a,xla::XlaOp b)119   xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; };
120 };
121 
122 REGISTER_XLA_OP(
123     Name("UnsortedSegmentProd").CompileTimeConstantInput("num_segments"),
124     UnsortedSegmentProd);
125 
126 class UnsortedSegmentMin : public UnsortedSegmentReduce {
127  public:
UnsortedSegmentMin(OpKernelConstruction * ctx)128   explicit UnsortedSegmentMin(OpKernelConstruction* ctx)
129       : UnsortedSegmentReduce(ctx) {}
130 
InitialValue(xla::XlaBuilder * builder)131   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
132     return xla::MaxFiniteValue(builder, type_);
133   };
Combine(xla::XlaOp a,xla::XlaOp b)134   xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
135     return xla::Min(a, b);
136   };
137 };
138 
139 REGISTER_XLA_OP(
140     Name("UnsortedSegmentMin").CompileTimeConstantInput("num_segments"),
141     UnsortedSegmentMin);
142 
143 class UnsortedSegmentMax : public UnsortedSegmentReduce {
144  public:
UnsortedSegmentMax(OpKernelConstruction * ctx)145   explicit UnsortedSegmentMax(OpKernelConstruction* ctx)
146       : UnsortedSegmentReduce(ctx) {}
147 
InitialValue(xla::XlaBuilder * builder)148   xla::XlaOp InitialValue(xla::XlaBuilder* builder) override {
149     return xla::MinFiniteValue(builder, type_);
150   };
Combine(xla::XlaOp a,xla::XlaOp b)151   xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override {
152     return xla::Max(a, b);
153   };
154 };
155 
156 REGISTER_XLA_OP(
157     Name("UnsortedSegmentMax").CompileTimeConstantInput("num_segments"),
158     UnsortedSegmentMax);
159 
160 }  // namespace
161 }  // namespace tensorflow
162