/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/lib/scatter.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
namespace {

// Checks whether updates.shape = indices.shape[:batch_dim] +
// buffer_shape[num_index_dims:].
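//
// For example, buffer_shape = [4, 5, 6] and indices_shape = [2, 3, 2] give
// num_index_dims = 2 and batch_dim = 2, so updates_shape must equal
// indices.shape[:2] + buffer_shape[2:] = [2, 3, 6].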
Status ValidateUpdateShape(const TensorShape& buffer_shape,
                           const TensorShape& indices_shape,
                           const TensorShape& updates_shape) {
  if (indices_shape.dims() < 1) {
    return errors::InvalidArgument(
        "indices shape must have >= 1 dimension; got ",
        indices_shape.DebugString());
  }

  const int64_t num_index_dims =
      indices_shape.dim_size(indices_shape.dims() - 1);
  const int64_t batch_dim = indices_shape.dims() - 1;
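  // num_index_dims is the length of each index vector (how many leading
  // buffer dimensions a single index addresses); batch_dim is the number of
  // leading batch dimensions shared by indices and updates.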

  auto shape_err = [&]() {
    return errors::InvalidArgument(
        "Must have updates.shape = indices.shape[:batch_dim] + ",
        "buffer_shape[num_index_dims:], got updates.shape: ",
        updates_shape.DebugString(),
        ", indices.shape: ", indices_shape.DebugString(),
        ", buffer_shape: ", buffer_shape.DebugString(),
        ", num_index_dims: ", num_index_dims, ", and batch_dim: ", batch_dim);
  };

  if (updates_shape.dims() < batch_dim) return shape_err();
  if (buffer_shape.dims() <
      num_index_dims + (updates_shape.dims() - batch_dim)) {
    return shape_err();
  }
  if (updates_shape.dims() !=
      batch_dim + buffer_shape.dims() - num_index_dims) {
    return shape_err();
  }
  for (int d = 0; d < batch_dim; ++d) {
    if (updates_shape.dim_size(d) != indices_shape.dim_size(d)) {
      return shape_err();
    }
  }
  for (int d = 0; d < updates_shape.dims() - batch_dim; ++d) {
    if (updates_shape.dim_size(d + batch_dim) !=
        buffer_shape.dim_size(d + num_index_dims)) {
      return shape_err();
    }
  }
  return Status::OK();
}

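// Implements the ScatterNd op: scatters `updates` into a freshly
// zero-initialized buffer of the requested shape.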
class ScatterNdOp : public XlaOpKernel {
 public:
  explicit ScatterNdOp(OpKernelConstruction* context) : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    DataType dtype = context->input_type(1);

    TensorShape indices_shape = context->InputShape(0);
    TensorShape updates_shape = context->InputShape(1);

    TensorShape buffer_shape;
    OP_REQUIRES_OK(context, context->ConstantInputAsShape(2, &buffer_shape));

    OP_REQUIRES(
        context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
        errors::InvalidArgument("Output must be at least 1-D, ",
                                "got shape: ", buffer_shape.DebugString()));

    OP_REQUIRES(
        context,
        buffer_shape.num_elements() > 0 ||
            (indices_shape.num_elements() == 0 &&
             updates_shape.num_elements() == 0),
        errors::InvalidArgument(
            "Indices and updates specified for empty output. indices shape: ",
            indices_shape.DebugString()));

    OP_REQUIRES_OK(context, ValidateUpdateShape(buffer_shape, indices_shape,
                                                updates_shape));

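    // Build a zero-filled buffer of the output shape and scatter the updates
    // into it.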
    xla::XlaBuilder* builder = context->builder();
    auto buffer = xla::Broadcast(XlaHelpers::Zero(builder, dtype),
                                 buffer_shape.dim_sizes());
    auto indices = context->Input(0);
    auto updates = context->Input(1);
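    // When the same index appears more than once, PRED updates combine with
    // logical OR and numeric updates accumulate with addition.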
    auto combine =
        context->input_xla_type(1) == xla::PRED ? CombineBool : CombineNum;
    auto result =
        XlaScatter(buffer, updates, indices,
                   /*indices_are_vectors=*/true, /*combiner=*/combine, builder);
    OP_REQUIRES_OK(context, result.status());
    context->SetOutput(0, result.ValueOrDie());
  }

 private:
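  // Combiners applied when multiple updates target the same index. The
  // builder parameter is unused but required by the XlaScatter combiner
  // signature.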
  static xla::XlaOp CombineNum(const xla::XlaOp x, const xla::XlaOp y,
                               xla::XlaBuilder* /*builder*/) {
    return xla::Add(x, y);
  }
  static xla::XlaOp CombineBool(const xla::XlaOp x, const xla::XlaOp y,
                                xla::XlaBuilder* /*builder*/) {
    return xla::Or(x, y);
  }
};

REGISTER_XLA_OP(Name("ScatterNd").CompileTimeConstantInput("shape"),
                ScatterNdOp);

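// Shared lowering for the TensorScatter* ops: reads the buffer, indices, and
// updates from the op's inputs and emits an XlaScatter that merges duplicate
// indices with `combiner`.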
void CompileTensorScatter(
    XlaOpKernelContext* context,
    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp, xla::XlaBuilder*)>&
        combiner) {
  TensorShape buffer_shape = context->InputShape(0);
  TensorShape indices_shape = context->InputShape(1);
  TensorShape updates_shape = context->InputShape(2);

  OP_REQUIRES(
      context, TensorShapeUtils::IsVectorOrHigher(buffer_shape),
      errors::InvalidArgument("Output must be at least 1-D, ",
                              "got shape: ", buffer_shape.DebugString()));

  OP_REQUIRES(
      context,
      buffer_shape.num_elements() > 0 || (indices_shape.num_elements() == 0 &&
                                          updates_shape.num_elements() == 0),
      errors::InvalidArgument(
          "Indices and updates specified for empty output. indices shape: ",
          indices_shape.DebugString()));

  OP_REQUIRES_OK(
      context, ValidateUpdateShape(buffer_shape, indices_shape, updates_shape));

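  // Unlike ScatterNd, the initial buffer is taken from the op's first input
  // rather than being zero-initialized.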
  xla::XlaBuilder* builder = context->builder();
  auto buffer = context->Input(0);
  auto indices = context->Input(1);
  auto updates = context->Input(2);
  auto result = XlaScatter(buffer, updates, indices,
                           /*indices_are_vectors=*/true, combiner, builder);
  OP_REQUIRES_OK(context, result.status());
  context->SetOutput(0, result.ValueOrDie());
}

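// The TensorScatter* kernels below all share CompileTensorScatter and differ
// only in the combiner applied when an index appears more than once.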
class TensorScatterAddOp : public XlaOpKernel {
 public:
  explicit TensorScatterAddOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(context,
                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
                           return xla::Add(x, y);
                         });
  }
};

class TensorScatterMaxOp : public XlaOpKernel {
 public:
  explicit TensorScatterMaxOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(context,
                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
                           return xla::Max(x, y);
                         });
  }
};

class TensorScatterMinOp : public XlaOpKernel {
 public:
  explicit TensorScatterMinOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(context,
                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
                           return xla::Min(x, y);
                         });
  }
};

class TensorScatterSubOp : public XlaOpKernel {
 public:
  explicit TensorScatterSubOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(context,
                         [](xla::XlaOp x, xla::XlaOp y, xla::XlaBuilder*) {
                           return xla::Sub(x, y);
                         });
  }
};

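// TensorScatterUpdate overwrites rather than accumulates: its combiner
// ignores the existing buffer value and returns the update.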
class TensorScatterUpdateOp : public XlaOpKernel {
 public:
  explicit TensorScatterUpdateOp(OpKernelConstruction* context)
      : XlaOpKernel(context) {}

  void Compile(XlaOpKernelContext* context) override {
    CompileTensorScatter(
        context, [](xla::XlaOp, xla::XlaOp y, xla::XlaBuilder*) { return y; });
  }
};

REGISTER_XLA_OP(Name("TensorScatterAdd"), TensorScatterAddOp);
REGISTER_XLA_OP(Name("TensorScatterMax"), TensorScatterMaxOp);
REGISTER_XLA_OP(Name("TensorScatterMin"), TensorScatterMinOp);
REGISTER_XLA_OP(Name("TensorScatterSub"), TensorScatterSubOp);
REGISTER_XLA_OP(Name("TensorScatterUpdate"), TensorScatterUpdateOp);

}  // namespace
}  // namespace tensorflow