/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {

// Check whether updates.shape = indices.shape[:batch_dim] +
// buffer_shape[num_index_dims:]
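//
// For example, buffer_shape = [4, 4, 4] and indices_shape = [2, 1] give
// num_index_dims = 1 and batch_dim = 1, so updates_shape must be
// [2] + [4, 4] = [2, 4, 4].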
Status ValidateUpdateShape(const TensorShape& buffer_shape,
                           const TensorShape& indices_shape,
                           const TensorShape& updates_shape,
                           bool broadcast_scalar_update) {
  if (indices_shape.dims() < 1) {
    return errors::InvalidArgument(
        "indices shape must have >= 1 dimension; got ",
        indices_shape.DebugString());
  }

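  // The last dimension of indices holds the index vectors; the leading
  // indices_shape.dims() - 1 dimensions are batch dimensions.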
  const int64_t num_index_dims =
      indices_shape.dim_size(indices_shape.dims() - 1);
  const int64_t batch_dim = indices_shape.dims() - 1;

  auto shape_err = [&]() {
    return errors::InvalidArgument(
        "Must have updates.shape = indices.shape[:batch_dim] + ",
        "buffer_shape[num_index_dims:], got updates.shape: ",
        updates_shape.DebugString(),
        ", indices.shape: ", indices_shape.DebugString(),
        ", buffer_shape: ", buffer_shape.DebugString(),
        ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
  };

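  // A scalar update is allowed to broadcast across all scattered slices for
  // the kernels that opt in (TensorScatterAdd and TensorScatterUpdate below).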
  if (updates_shape.dims() == 0 && broadcast_scalar_update) {
    return OkStatus();
  }

  if (updates_shape.dims() < batch_dim) return shape_err();
  if (buffer_shape.dims() <
      num_index_dims + (updates_shape.dims() - batch_dim)) {
    return shape_err();
  }
  if (updates_shape.dims() !=
      batch_dim + buffer_shape.dims() - num_index_dims) {
    return shape_err();
  }
  for (int d = 0; d < batch_dim; ++d) {
    if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) {
      return shape_err();
    }
  }
  for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
    if (updates_shape.dim_size(d + batch_dim) !=
        buffer_shape.dim_size(d + num_index_dims)) {
      return shape_err();
    }
  }
  return OkStatus();
}

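// Lowers ScatterNd to XLA: materializes a zero-initialized buffer of the
// requested shape and scatters `updates` into it at `indices`.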
class ScatterNdOp : public XlaOpKernel {
 public:
  explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    DataType dtype = context->input_type(1);

    TensorShape indices_shape = context->InputShape(0);
    TensorShape updates_shape = context->InputShape(1);

    TensorShape buffer_shape;
    OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape));

    OP_REQUIRES(
        context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
        errors::InvalidArgument("Output must be at least 1-D, ",
                                "got shape: ", buffer_shape.DebugString()));

    OP_REQUIRES(
        context,
        buffer_shape.num_elements() > 0 ||
            (indices_shape.num_elements() == 0 &&
             updates_shape.num_elements() == 0),
        errors::InvalidArgument(
            "Indices and updates specified for empty output. indices shape: ",
            indices_shape.DebugString()));

    OP_REQUIRES_OK(context,
                   ValidateUpdateShape(buffer_shape, indices_shape,
                                       updates_shape,
                                       /*broadcast_scalar_update=*/false));

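    // scatter_nd has no input tensor to update, so start from a buffer of
    // zeros and scatter the updates into it.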
    xla::XlaBuilder* builder = context->builder();
    auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype),
                                 buffer_shape.dim_sizes());
    auto indices = context->Input(0);
    auto updates = context->Input(1);
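    // Duplicate indices combine elementwise: Or for booleans, Add otherwise.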
    auto combine =
        context->input_xla_type(1) == xla::PRED ? CombineBool : CombineNum;
    auto result =
        XlaScatter(buffer, updates, indices,
                   /*indices_are_vectors=*/true, /*combiner=*/combine, builder);
    OP_REQUIRES_OK(context, result.status());
    context->SetOutput(0, result.ValueOrDie());
  }

 private:
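  // Combiners for duplicate indices. `builder` is unused but required by the
  // combiner signature that XlaScatter expects.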
  static xla::XlaOp CombineNum(const xla::XlaOp x, const xla::XlaOp y,
                               xla::XlaBuilder* builder) {
    (void)builder;
    return xla::Add(x, y);
  }
  static xla::XlaOp CombineBool(const xla::XlaOp x, const xla::XlaOp y,
                                xla::XlaBuilder* builder) {
    (void)builder;
    return xla::Or(x, y);
  }
};

REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"),
                ScatterNdOp);

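// Shared lowering for the TensorScatter* ops below. Unlike ScatterNd, these
// ops scatter `updates` into a copy of an existing input tensor; `combiner`
// decides how an update is merged with the value already at its target index.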
void CompileTensorScatter(
    XlaOpKernelContext* context,
    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
        combiner,
    bool broadcast_scalar_update) {
  TensorShape buffer_shape = context->InputShape(0);
  TensorShape indices_shape = context->InputShape(1);
  TensorShape updates_shape = context->InputShape(2);

  OP_REQUIRES(
      context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
      errors::InvalidArgument("Output must be at least 1-D, ",
                              "got shape: ", buffer_shape.DebugString()));

  OP_REQUIRES(
      context,
      buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
                                          updates_shape.num_elements() == 0),
      errors::InvalidArgument(
          "Indices and updates specified for empty output. indices shape: ",
          indices_shape.DebugString()));

  OP_REQUIRES_OK(context,
                 ValidateUpdateShape(buffer_shape, indices_shape, updates_shape,
                                     broadcast_scalar_update));

  xla::XlaBuilder* builder = context->builder();
  auto buffer = context->Input(0);
  auto indices = context->Input(1);
  auto updates = context->Input(2);
  auto result = XlaScatter(buffer, updates, indices,
                           /*indices_are_vectors=*/true, combiner, builder);
  OP_REQUIRES_OK(context, result.status());
  context->SetOutput(0, result.ValueOrDie());
}

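// Implements tf.tensor_scatter_nd_add: adds each update to the existing value
// at its target index. A scalar update may broadcast across all slices.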
class TensorScatterAddOp : public XlaOpKernel {
 public:
  explicit TensorScatterAddOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(
        context,
        [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
          return xla::Add(x, y);
        },
        /*broadcast_scalar_update=*/true);
  }
};

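// Implements tf.tensor_scatter_nd_max: keeps the elementwise maximum of the
// existing value and the update.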
class TensorScatterMaxOp : public XlaOpKernel {
 public:
  explicit TensorScatterMaxOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(
        context,
        [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
          return xla::Max(x, y);
        },
        /*broadcast_scalar_update=*/false);
  }
};

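// Implements tf.tensor_scatter_nd_min: keeps the elementwise minimum of the
// existing value and the update.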
class TensorScatterMinOp : public XlaOpKernel {
 public:
  explicit TensorScatterMinOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(
        context,
        [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
          return xla::Min(x, y);
        },
        /*broadcast_scalar_update=*/false);
  }
};

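// Implements tf.tensor_scatter_nd_sub: subtracts each update from the
// existing value at its target index.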
class TensorScatterSubOp : public XlaOpKernel {
 public:
  explicit TensorScatterSubOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(
        context,
        [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
          return xla::Sub(x, y);
        },
        /*broadcast_scalar_update=*/false);
  }
};

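// Implements tf.tensor_scatter_nd_update: replaces the existing value with
// the update. A scalar update may broadcast across all slices.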
class TensorScatterUpdateOp : public XlaOpKernel {
 public:
  explicit TensorScatterUpdateOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(
        context, [](xla::XlaOp, xla::XlaOp y, xla::XlaBuilder*) { return y; },
        /*broadcast_scalar_update=*/true);
  }
};

REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp);
REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp);
REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);

}  // namespace
}  // namespace tensorflow