• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/lib/scatter.h"
17 #include "tensorflow/compiler/tf2xla/shape_util.h"
18 #include "tensorflow/compiler/tf2xla/type_util.h"
19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/primitive_util.h"
24 #include "tensorflow/compiler/xla/status_macros.h"
25 #include "tensorflow/compiler/xla/xla_data.pb.h"
26 #include "tensorflow/core/framework/kernel_def_builder.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 // Check whether updates.shape = indices.shape[:batch_dim] +
33 // buffer_shape[num_index_dims:]
ValidateUpdateShape(const TensorShape & buffer_shape,const TensorShape & indices_shape,const TensorShape & updates_shape,bool broadcast_scalar_update)34 Status ValidateUpdateShape(const TensorShape& buffer_shape,
35                            const TensorShape& indices_shape,
36                            const TensorShape& updates_shape,
37                            bool broadcast_scalar_update) {
38   if (indices_shape.dims() < 1) {
39     return errors::InvalidArgument(
40         "indices shape must have >= 1 dimension; got ",
41         indices_shape.DebugString());
42   }
43 
44   const int64_t num_index_dims =
45       indices_shape.dim_size(indices_shape.dims() - 1);
46   const int64_t batch_dim = indices_shape.dims() - 1;
47 
48   auto shape_err = [&]() {
49     return errors::InvalidArgument(
50         "Must have updates.shape = indices.shape[:batch_dim] + ",
51         "buffer_shape[num_index_dims:], got updates.shape: ",
52         updates_shape.DebugString(),
53         ", indices.shape: ", indices_shape.DebugString(),
54         ", buffer_shape: ", buffer_shape.DebugString(),
55         ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
56   };
57 
58   if (updates_shape.dims() == 0 && broadcast_scalar_update) {
59     return OkStatus();
60   }
61 
62   if (updates_shape.dims() < batch_dim) return shape_err();
63   if (buffer_shape.dims() <
64       num_index_dims + (updates_shape.dims() - batch_dim)) {
65     return shape_err();
66   }
67   if (updates_shape.dims() !=
68       batch_dim + buffer_shape.dims() - num_index_dims) {
69     return shape_err();
70   }
71   for (int d = 0; d < batch_dim; ++d) {
72     if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) {
73       return shape_err();
74     }
75   }
76   for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
77     if (updates_shape.dim_size(d + batch_dim) !=
78         buffer_shape.dim_size(d + num_index_dims)) {
79       return shape_err();
80     }
81   }
82   return OkStatus();
83 }
84 
85 class ScatterNdOp : public XlaOpKernel {
86  public:
ScatterNdOp(OpKernelConstruction * context)87   explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}
88 
Compile(XlaOpKernelContext * context)89   void Compile(XlaOpKernelContext* context) override {
90     DataType dtype = context->input_type(1);
91 
92     TensorShape indices_shape = context->InputShape(0);
93     TensorShape updates_shape = context->InputShape(1);
94 
95     TensorShape buffer_shape;
96     OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape));
97 
98     OP_REQUIRES(
99         context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
100         errors::InvalidArgument("Output must be at least 1-D, ",
101                                 "got shape: ", buffer_shape.DebugString()));
102 
103     OP_REQUIRES(
104         context,
105         buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
106                                             updates_shape.num_elements() == 0),
107         errors::InvalidArgument(
108             "Indices and updates specified for empty output. indices shape: ",
109             indices_shape.DebugString()));
110 
111     OP_REQUIRES_OK(
112         context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape,
113                                      /*broadcast_scalar_update=*/false));
114 
115     xla::XlaBuilder* builder = context->builder();
116     auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype),
117                                  buffer_shape.dim_sizes());
118     auto indices = context->Input(0);
119     auto updates = context->Input(1);
120     auto combine =
121         context->input_xla_type(1) == xla::PRED ? CombineBool : CombineNum;
122     auto result =
123         XlaScatter(buffer, updates, indices,
124                    /*indices_are_vectors=*/true, /*combiner=*/combine, builder);
125     OP_REQUIRES_OK(context, result.status());
126     context->SetOutput(0, result.ValueOrDie());
127   }
128 
129  private:
CombineNum(const xla::XlaOp x,const xla::XlaOp y,xla::XlaBuilder * builder)130   static xla::XlaOp CombineNum(const xla::XlaOp x, const xla::XlaOp y,
131                                xla::XlaBuilder* builder) {
132     (void)builder;
133     return xla::Add(x, y);
134   }
CombineBool(const xla::XlaOp x,const xla::XlaOp y,xla::XlaBuilder * builder)135   static xla::XlaOp CombineBool(const xla::XlaOp x, const xla::XlaOp y,
136                                 xla::XlaBuilder* builder) {
137     (void)builder;
138     return xla::Or(x, y);
139   }
140 };
141 
142 REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"),
143                 ScatterNdOp);
144 
CompileTensorScatter(XlaOpKernelContext * context,const std::function<xla::XlaOp (xla::XlaOp,xla::XlaOp,xla::XlaBuilder *)> & combiner,bool broadcast_scalar_update)145 void CompileTensorScatter(
146     XlaOpKernelContext* context,
147     const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
148         combiner,
149     bool broadcast_scalar_update) {
150   TensorShape buffer_shape = context->InputShape(0);
151   TensorShape indices_shape = context->InputShape(1);
152   TensorShape updates_shape = context->InputShape(2);
153 
154   OP_REQUIRES(
155       context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
156       errors::InvalidArgument("Output must be at least 1-D, ",
157                               "got shape: ", buffer_shape.DebugString()));
158 
159   OP_REQUIRES(
160       context,
161       buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
162                                           updates_shape.num_elements() == 0),
163       errors::InvalidArgument(
164           "Indices and updates specified for empty output. indices shape: ",
165           indices_shape.DebugString()));
166 
167   OP_REQUIRES_OK(context,
168                  ValidateUpdateShape(buffer_shape, indices_shape, updates_shape,
169                                      broadcast_scalar_update));
170 
171   xla::XlaBuilder* builder = context->builder();
172   auto buffer = context->Input(0);
173   auto indices = context->Input(1);
174   auto updates = context->Input(2);
175   auto result = XlaScatter(buffer, updates, indices,
176                            /*indices_are_vectors=*/true, combiner, builder);
177   OP_REQUIRES_OK(context, result.status());
178   context->SetOutput(0, result.ValueOrDie());
179 }
180 
181 class TensorScatterAddOp : public XlaOpKernel {
182  public:
TensorScatterAddOp(OpKernelConstruction * context)183   explicit TensorScatterAddOp(OpKernelConstruction* context)
184       : XlaOpKernel(context) {}
185 
Compile(XlaOpKernelContext * context)186   void Compile(XlaOpKernelContext* context) override {
187     CompileTensorScatter(
188         context,
189         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
190           return xla::Add(x, y);
191         },
192         /*broadcast_scalar_update=*/true);
193   }
194 };
195 
196 class TensorScatterMaxOp : public XlaOpKernel {
197  public:
TensorScatterMaxOp(OpKernelConstruction * context)198   explicit TensorScatterMaxOp(OpKernelConstruction* context)
199       : XlaOpKernel(context) {}
200 
Compile(XlaOpKernelContext * context)201   void Compile(XlaOpKernelContext* context) override {
202     CompileTensorScatter(
203         context,
204         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
205           return xla::Max(x, y);
206         },
207         /*broadcast_scalar_update=*/false);
208   }
209 };
210 
211 class TensorScatterMinOp : public XlaOpKernel {
212  public:
TensorScatterMinOp(OpKernelConstruction * context)213   explicit TensorScatterMinOp(OpKernelConstruction* context)
214       : XlaOpKernel(context) {}
215 
Compile(XlaOpKernelContext * context)216   void Compile(XlaOpKernelContext* context) override {
217     CompileTensorScatter(
218         context,
219         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
220           return xla::Min(x, y);
221         },
222         /*broadcast_scalar_update=*/false);
223   }
224 };
225 
226 class TensorScatterSubOp : public XlaOpKernel {
227  public:
TensorScatterSubOp(OpKernelConstruction * context)228   explicit TensorScatterSubOp(OpKernelConstruction* context)
229       : XlaOpKernel(context) {}
230 
Compile(XlaOpKernelContext * context)231   void Compile(XlaOpKernelContext* context) override {
232     CompileTensorScatter(
233         context,
234         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
235           return xla::Sub(x, y);
236         },
237         /*broadcast_scalar_update=*/false);
238   }
239 };
240 
241 class TensorScatterUpdateOp : public XlaOpKernel {
242  public:
TensorScatterUpdateOp(OpKernelConstruction * context)243   explicit TensorScatterUpdateOp(OpKernelConstruction* context)
244       : XlaOpKernel(context) {}
245 
Compile(XlaOpKernelContext * context)246   void Compile(XlaOpKernelContext* context) override {
247     CompileTensorScatter(
248         context, [](xla::XlaOp, xla::XlaOp y, xla::XlaBuilder*) { return y; },
249         /*broadcast_scalar_update=*/true);
250   }
251 };
252 
253 REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
254 REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp);
255 REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp);
256 REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
257 REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);
258 
259 }  // namespace
260 }  // namespace tensorflow
261