• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/lib/scatter.h"
17 #include "tensorflow/compiler/tf2xla/shape_util.h"
18 #include "tensorflow/compiler/tf2xla/type_util.h"
19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/framework/kernel_def_builder.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 // Check whether updates.shape = indices.shape[:batch_dim] +
33 // buffer_shape[num_index_dims:]
ValidateUpdateShape(const TensorShape & buffer_shape,const TensorShape & indices_shape,const TensorShape & updates_shape)34 Status ValidateUpdateShape(const TensorShape& buffer_shape,
35                            const TensorShape& indices_shape,
36                            const TensorShape& updates_shape) {
37   if (indices_shape.dims() < 1) {
38     return errors::InvalidArgument(
39         "indices shape must have >= 1 dimension; got ",
40         indices_shape.DebugString());
41   }
42 
43   const int64_t num_index_dims =
44       indices_shape.dim_size(indices_shape.dims() - 1);
45   const int64_t batch_dim = indices_shape.dims() - 1;
46 
47   auto shape_err = [&]() {
48     return errors::InvalidArgument(
49         "Must have updates.shape = indices.shape[:batch_dim] + ",
50         "buffer_shape[num_index_dims:], got updates.shape: ",
51         updates_shape.DebugString(),
52         ", indices.shape: ", indices_shape.DebugString(),
53         ", buffer_shape: ", buffer_shape.DebugString(),
54         ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
55   };
56 
57   if (updates_shape.dims() < batch_dim) return shape_err();
58   if (buffer_shape.dims() <
59       num_index_dims + (updates_shape.dims() - batch_dim)) {
60     return shape_err();
61   }
62   if (updates_shape.dims() !=
63       batch_dim + buffer_shape.dims() - num_index_dims) {
64     return shape_err();
65   }
66   for (int d = 0; d < batch_dim; ++d) {
67     if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) {
68       return shape_err();
69     }
70   }
71   for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
72     if (updates_shape.dim_size(d + batch_dim) !=
73         buffer_shape.dim_size(d + num_index_dims)) {
74       return shape_err();
75     }
76   }
77   return Status::OK();
78 }
79 
80 class ScatterNdOp : public XlaOpKernel {
81  public:
ScatterNdOp(OpKernelConstruction * context)82   explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
83 
Compile(XlaOpKernelContext * context)84   void Compile(XlaOpKernelContext* context) override {
85     DataType dtype = context->input_type(1);
86 
87     TensorShape indices_shape = context->InputShape(0);
88     TensorShape updates_shape = context->InputShape(1);
89 
90     TensorShape buffer_shape;
91     OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape));
92 
93     OP_REQUIRES(
94         context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
95         errors::InvalidArgument("Output must be at least 1-D, ",
96                                 "got shape: ", buffer_shape.DebugString()));
97 
98     OP_REQUIRES(
99         context,
100         buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
101                                             updates_shape.num_elements() == 0),
102         errors::InvalidArgument(
103             "Indices and updates specified for empty output. indices shape: ",
104             indices_shape.DebugString()));
105 
106     OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape,
107                                                 updates_shape));
108 
109     xla::XlaBuilder* builder = context->builder();
110     auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype),
111                                  buffer_shape.dim_sizes());
112     auto indices = context->Input(0);
113     auto updates = context->Input(1);
114     auto combine =
115         context->input_xla_type(1) == xla::PRED ? CombineBool : CombineNum;
116     auto result =
117         XlaScatter(buffer, updates, indices,
118                    /*indices_are_vectors=*/true, /*combiner=*/combine, builder);
119     OP_REQUIRES_OK(context, result.status());
120     context->SetOutput(0, result.ValueOrDie());
121   }
122 
123  private:
CombineNum(const xla::XlaOp x,const xla::XlaOp y,xla::XlaBuilder * builder)124   static xla::XlaOp CombineNum(const xla::XlaOp x, const xla::XlaOp y,
125                                xla::XlaBuilder* builder) {
126     (void)builder;
127     return xla::Add(x, y);
128   }
CombineBool(const xla::XlaOp x,const xla::XlaOp y,xla::XlaBuilder * builder)129   static xla::XlaOp CombineBool(const xla::XlaOp x, const xla::XlaOp y,
130                                 xla::XlaBuilder* builder) {
131     (void)builder;
132     return xla::Or(x, y);
133   }
134 };
135 
136 REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"),
137                 ScatterNdOp);
138 
CompileTensorScatter(XlaOpKernelContext * context,const std::function<xla::XlaOp (xla::XlaOp,xla::XlaOp,xla::XlaBuilder *)> & combiner)139 void CompileTensorScatter(
140     XlaOpKernelContext* context,
141     const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
142         combiner) {
143   TensorShape buffer_shape = context->InputShape(0);
144   TensorShape indices_shape = context->InputShape(1);
145   TensorShape updates_shape = context->InputShape(2);
146 
147   OP_REQUIRES(
148       context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
149       errors::InvalidArgument("Output must be at least 1-D, ",
150                               "got shape: ", buffer_shape.DebugString()));
151 
152   OP_REQUIRES(
153       context,
154       buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
155                                           updates_shape.num_elements() == 0),
156       errors::InvalidArgument(
157           "Indices and updates specified for empty output. indices shape: ",
158           indices_shape.DebugString()));
159 
160   OP_REQUIRES_OK(
161       context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape));
162 
163   xla::XlaBuilder* builder = context->builder();
164   auto buffer = context->Input(0);
165   auto indices = context->Input(1);
166   auto updates = context->Input(2);
167   auto result = XlaScatter(buffer, updates, indices,
168                            /*indices_are_vectors=*/true, combiner, builder);
169   OP_REQUIRES_OK(context, result.status());
170   context->SetOutput(0, result.ValueOrDie());
171 }
172 
173 class TensorScatterAddOp : public XlaOpKernel {
174  public:
TensorScatterAddOp(OpKernelConstruction * context)175   explicit TensorScatterAddOp(OpKernelConstruction* context)
176       : XlaOpKernel(context) {}
177 
Compile(XlaOpKernelContext * context)178   void Compile(XlaOpKernelContext* context) override {
179     CompileTensorScatter(context,
180                          [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
181                            return xla::Add(x, y);
182                          });
183   }
184 };
185 
186 class TensorScatterMaxOp : public XlaOpKernel {
187  public:
TensorScatterMaxOp(OpKernelConstruction * context)188   explicit TensorScatterMaxOp(OpKernelConstruction* context)
189       : XlaOpKernel(context) {}
190 
Compile(XlaOpKernelContext * context)191   void Compile(XlaOpKernelContext* context) override {
192     CompileTensorScatter(context,
193                          [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
194                            return xla::Max(x, y);
195                          });
196   }
197 };
198 
199 class TensorScatterMinOp : public XlaOpKernel {
200  public:
TensorScatterMinOp(OpKernelConstruction * context)201   explicit TensorScatterMinOp(OpKernelConstruction* context)
202       : XlaOpKernel(context) {}
203 
Compile(XlaOpKernelContext * context)204   void Compile(XlaOpKernelContext* context) override {
205     CompileTensorScatter(context,
206                          [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
207                            return xla::Min(x, y);
208                          });
209   }
210 };
211 
212 class TensorScatterSubOp : public XlaOpKernel {
213  public:
TensorScatterSubOp(OpKernelConstruction * context)214   explicit TensorScatterSubOp(OpKernelConstruction* context)
215       : XlaOpKernel(context) {}
216 
Compile(XlaOpKernelContext * context)217   void Compile(XlaOpKernelContext* context) override {
218     CompileTensorScatter(context,
219                          [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
220                            return xla::Sub(x, y);
221                          });
222   }
223 };
224 
225 class TensorScatterUpdateOp : public XlaOpKernel {
226  public:
TensorScatterUpdateOp(OpKernelConstruction * context)227   explicit TensorScatterUpdateOp(OpKernelConstruction* context)
228       : XlaOpKernel(context) {}
229 
Compile(XlaOpKernelContext * context)230   void Compile(XlaOpKernelContext* context) override {
231     CompileTensorScatter(
232         context, [](xla::XlaOp, xla::XlaOp y, xla::XlaBuilder*) { return y; });
233   }
234 };
235 
236 REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
237 REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp);
238 REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp);
239 REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
240 REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);
241 
242 }  // namespace
243 }  // namespace tensorflow
244