• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h"
17 #include "tensorflow/compiler/tf2xla/kernels/shape_util.h"
18 #include "tensorflow/compiler/tf2xla/lib/scatter.h"
19 #include "tensorflow/compiler/tf2xla/shape_util.h"
20 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/core/framework/kernel_def_builder.h"
25 #include "tensorflow/core/framework/types.h"
26 
27 namespace tensorflow {
28 namespace {
29 
30 class VarIsInitializedOp : public XlaOpKernel {
31  public:
VarIsInitializedOp(OpKernelConstruction * ctx)32   explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)33   void Compile(XlaOpKernelContext* ctx) override {
34     XlaResource* variable;
35     OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable));
36     ctx->SetOutput(
37         0, xla::ConstantR0<bool>(ctx->builder(), variable->initialized()));
38   }
39 };
40 REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp);
41 
42 class VariableShapeOp : public XlaOpKernel {
43  public:
VariableShapeOp(OpKernelConstruction * ctx)44   explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
45     OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
46   }
47 
Compile(XlaOpKernelContext * ctx)48   void Compile(XlaOpKernelContext* ctx) override {
49     DataType variable_dtype;
50     TensorShape shape;
51     OP_REQUIRES_OK(ctx,
52                    ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape));
53     Tensor shape_constant(out_dtype_, TensorShape({shape.dims()}));
54     OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant));
55     ctx->SetConstantOutput(0, shape_constant);
56   }
57 
58  private:
59   DataType out_dtype_;
60 };
61 REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp);
62 
63 class ReadVariableOp : public XlaOpKernel {
64  public:
ReadVariableOp(OpKernelConstruction * ctx)65   explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
66     OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
67   }
68 
Compile(XlaOpKernelContext * ctx)69   void Compile(XlaOpKernelContext* ctx) override {
70     xla::XlaOp handle;
71     OP_REQUIRES_OK(
72         ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle));
73     ctx->SetOutput(0, handle);
74   }
75 
76  private:
77   DataType dtype_;
78 };
79 REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp);
80 
81 class AssignVariableOp : public XlaOpKernel {
82  public:
AssignVariableOp(OpKernelConstruction * ctx)83   explicit AssignVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)84   void Compile(XlaOpKernelContext* ctx) override {
85     OP_REQUIRES_OK(ctx,
86                    ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1)));
87   }
88 };
89 REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp);
90 
91 class AssignAddVariableOp : public XlaOpKernel {
92  public:
AssignAddVariableOp(OpKernelConstruction * ctx)93   explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)94   void Compile(XlaOpKernelContext* ctx) override {
95     DataType type = ctx->input_type(1);
96     xla::XlaOp handle;
97     OP_REQUIRES_OK(ctx,
98                    ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
99     handle = xla::Add(handle, ctx->Input(1));
100     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
101   }
102 };
103 REGISTER_XLA_OP(
104     Name("AssignAddVariableOp").TypeConstraint("dtype", kNumericTypes),
105     AssignAddVariableOp);
106 
107 class AssignSubVariableOp : public XlaOpKernel {
108  public:
AssignSubVariableOp(OpKernelConstruction * ctx)109   explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)110   void Compile(XlaOpKernelContext* ctx) override {
111     DataType type = ctx->input_type(1);
112     xla::XlaOp handle;
113     OP_REQUIRES_OK(ctx,
114                    ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle));
115     handle = xla::Sub(handle, ctx->Input(1));
116     OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle));
117   }
118 };
119 REGISTER_XLA_OP(
120     Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes),
121     AssignSubVariableOp);
122 
123 class ResourceGatherOp : public XlaOpKernel {
124  public:
ResourceGatherOp(OpKernelConstruction * ctx)125   explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
Compile(XlaOpKernelContext * ctx)126   void Compile(XlaOpKernelContext* ctx) override {
127     xla::XlaBuilder* builder = ctx->builder();
128 
129     DataType type = ctx->expected_output_dtype(0);
130 
131     TensorShape resource_shape;
132     xla::XlaOp resource_handle;
133     OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape,
134                                                &resource_handle));
135 
136     auto indices = ctx->Input(1);
137     auto indices_shape = ctx->InputShape(1);
138     DataType index_type = ctx->input_type(1);
139     xla::XlaOp gather;
140     OP_REQUIRES_OK(
141         ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape,
142                        /*axis=*/0, /*indices_are_nd=*/false, type, index_type,
143                        builder, &gather));
144     ctx->SetOutput(0, gather);
145   }
146 };
147 REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp);
148 
149 class ResourceScatterOp : public XlaOpKernel {
150  public:
ResourceScatterOp(OpKernelConstruction * context,bool indices_are_vectors,std::function<xla::XlaOp (const xla::XlaOp &,const xla::XlaOp &,xla::XlaBuilder *)> combiner)151   explicit ResourceScatterOp(
152       OpKernelConstruction* context, bool indices_are_vectors,
153       std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
154                                xla::XlaBuilder*)>
155           combiner)
156       : XlaOpKernel(context),
157         indices_are_vectors_(indices_are_vectors),
158         combiner_(std::move(combiner)) {}
159 
Compile(XlaOpKernelContext * context)160   void Compile(XlaOpKernelContext* context) override {
161     xla::XlaBuilder* builder = context->builder();
162 
163     DataType dtype = context->input_type(2);
164     TensorShape var_shape;
165     xla::XlaOp var_value;
166     OP_REQUIRES_OK(
167         context, context->ReadVariableInput(0, dtype, &var_shape, &var_value));
168 
169     const xla::XlaOp indices = context->Input(1);
170     const xla::XlaOp updates = context->Input(2);
171 
172     auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_,
173                              combiner_, builder);
174     OP_REQUIRES_OK(context, result.status());
175     OP_REQUIRES_OK(context,
176                    context->AssignVariable(0, dtype, result.ValueOrDie()));
177   }
178 
179  private:
180   const bool indices_are_vectors_;
181   const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
182                                  xla::XlaBuilder*)>
183       combiner_;
184 };
185 
186 class ResourceScatterAddOp : public ResourceScatterOp {
187  public:
ResourceScatterAddOp(OpKernelConstruction * context)188   explicit ResourceScatterAddOp(OpKernelConstruction* context)
189       : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
190 
191  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)192   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
193                             xla::XlaBuilder* builder) {
194     return xla::Add(x, y);
195   }
196 };
197 REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp);
198 
199 class ResourceScatterSubOp : public ResourceScatterOp {
200  public:
ResourceScatterSubOp(OpKernelConstruction * context)201   explicit ResourceScatterSubOp(OpKernelConstruction* context)
202       : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
203 
204  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)205   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
206                             xla::XlaBuilder* builder) {
207     return xla::Sub(x, y);
208   }
209 };
210 REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp);
211 
212 class ResourceScatterMulOp : public ResourceScatterOp {
213  public:
ResourceScatterMulOp(OpKernelConstruction * context)214   explicit ResourceScatterMulOp(OpKernelConstruction* context)
215       : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
216 
217  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)218   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
219                             xla::XlaBuilder* builder) {
220     return xla::Mul(x, y);
221   }
222 };
223 REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp);
224 
225 class ResourceScatterDivOp : public ResourceScatterOp {
226  public:
ResourceScatterDivOp(OpKernelConstruction * context)227   explicit ResourceScatterDivOp(OpKernelConstruction* context)
228       : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
229 
230  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)231   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
232                             xla::XlaBuilder* builder) {
233     return xla::Div(x, y);
234   }
235 };
236 REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp);
237 
238 class ResourceScatterMinOp : public ResourceScatterOp {
239  public:
ResourceScatterMinOp(OpKernelConstruction * context)240   explicit ResourceScatterMinOp(OpKernelConstruction* context)
241       : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
242 
243  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)244   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
245                             xla::XlaBuilder* builder) {
246     return xla::Min(x, y);
247   }
248 };
249 REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp);
250 
251 class ResourceScatterMaxOp : public ResourceScatterOp {
252  public:
ResourceScatterMaxOp(OpKernelConstruction * context)253   explicit ResourceScatterMaxOp(OpKernelConstruction* context)
254       : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {}
255 
256  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)257   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
258                             xla::XlaBuilder* builder) {
259     return xla::Max(x, y);
260   }
261 };
262 REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp);
263 
264 class ResourceScatterUpdateOp : public ResourceScatterOp {
265  public:
ResourceScatterUpdateOp(OpKernelConstruction * context)266   explicit ResourceScatterUpdateOp(OpKernelConstruction* context)
267       : ResourceScatterOp(context, /*indices_are_vectors=*/false,
268                           /*combiner=*/{}) {}
269 };
270 REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp);
271 
272 class ResourceScatterNdUpdateOp : public ResourceScatterOp {
273  public:
ResourceScatterNdUpdateOp(OpKernelConstruction * context)274   explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context)
275       : ResourceScatterOp(context, /*indices_are_vectors=*/true,
276                           /*combiner=*/{}) {}
277 };
278 REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp);
279 
280 class ResourceScatterNdAddOp : public ResourceScatterOp {
281  public:
ResourceScatterNdAddOp(OpKernelConstruction * context)282   explicit ResourceScatterNdAddOp(OpKernelConstruction* context)
283       : ResourceScatterOp(context, /*indices_are_vectors=*/true,
284                           /*combiner=*/Combine) {}
285 
286  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)287   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
288                             xla::XlaBuilder* builder) {
289     return xla::Add(x, y);
290   }
291 };
292 REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp);
293 
294 class ResourceScatterNdSubOp : public ResourceScatterOp {
295  public:
ResourceScatterNdSubOp(OpKernelConstruction * context)296   explicit ResourceScatterNdSubOp(OpKernelConstruction* context)
297       : ResourceScatterOp(context, /*indices_are_vectors=*/true,
298                           /*combiner=*/Combine) {}
299 
300  private:
Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)301   static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y,
302                             xla::XlaBuilder* builder) {
303     return xla::Sub(x, y);
304   }
305 };
306 REGISTER_XLA_OP(Name("ResourceScatterNdSub"), ResourceScatterNdSubOp);
307 
308 }  // namespace
309 }  // namespace tensorflow
310