1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" 17 #include "tensorflow/compiler/tf2xla/kernels/shape_util.h" 18 #include "tensorflow/compiler/tf2xla/lib/scatter.h" 19 #include "tensorflow/compiler/tf2xla/shape_util.h" 20 #include "tensorflow/compiler/tf2xla/xla_helpers.h" 21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h" 22 #include "tensorflow/compiler/xla/client/xla_builder.h" 23 #include "tensorflow/compiler/xla/literal.h" 24 #include "tensorflow/core/framework/kernel_def_builder.h" 25 #include "tensorflow/core/framework/types.h" 26 27 namespace tensorflow { 28 namespace { 29 30 class VarIsInitializedOp : public XlaOpKernel { 31 public: VarIsInitializedOp(OpKernelConstruction * ctx)32 explicit VarIsInitializedOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} Compile(XlaOpKernelContext * ctx)33 void Compile(XlaOpKernelContext* ctx) override { 34 XlaResource* variable; 35 OP_REQUIRES_OK(ctx, ctx->GetResourceInput(0, &variable)); 36 ctx->SetOutput( 37 0, xla::ConstantR0<bool>(ctx->builder(), variable->initialized())); 38 } 39 }; 40 REGISTER_XLA_OP(Name("VarIsInitializedOp"), VarIsInitializedOp); 41 42 class VariableShapeOp : public XlaOpKernel { 43 public: VariableShapeOp(OpKernelConstruction * ctx)44 explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 45 OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_)); 46 } 47 Compile(XlaOpKernelContext * ctx)48 void Compile(XlaOpKernelContext* ctx) override { 49 DataType variable_dtype; 50 TensorShape shape; 51 OP_REQUIRES_OK(ctx, 52 ctx->GetVariableTypeAndShape(0, &variable_dtype, &shape)); 53 Tensor shape_constant(out_dtype_, TensorShape({shape.dims()})); 54 OP_REQUIRES_OK(ctx, TensorShapeToConstant(shape, &shape_constant)); 55 ctx->SetConstantOutput(0, shape_constant); 56 } 57 58 private: 59 DataType out_dtype_; 60 }; 61 REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp); 62 63 class ReadVariableOp : public XlaOpKernel { 64 public: ReadVariableOp(OpKernelConstruction * ctx)65 explicit ReadVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) { 66 OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_)); 67 } 68 Compile(XlaOpKernelContext * ctx)69 void Compile(XlaOpKernelContext* ctx) override { 70 xla::XlaOp handle; 71 OP_REQUIRES_OK( 72 ctx, ctx->ReadVariableInput(0, dtype_, /*shape=*/nullptr, &handle)); 73 ctx->SetOutput(0, handle); 74 } 75 76 private: 77 DataType dtype_; 78 }; 79 REGISTER_XLA_OP(Name("ReadVariableOp").CompilationOnly(), ReadVariableOp); 80 81 class AssignVariableOp : public XlaOpKernel { 82 public: AssignVariableOp(OpKernelConstruction * ctx)83 explicit AssignVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} Compile(XlaOpKernelContext * ctx)84 void Compile(XlaOpKernelContext* ctx) override { 85 OP_REQUIRES_OK(ctx, 86 ctx->AssignVariable(0, ctx->input_type(1), ctx->Input(1))); 87 } 88 }; 89 REGISTER_XLA_OP(Name("AssignVariableOp").CompilationOnly(), AssignVariableOp); 90 91 class AssignAddVariableOp : public XlaOpKernel { 92 public: AssignAddVariableOp(OpKernelConstruction * ctx)93 explicit AssignAddVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} Compile(XlaOpKernelContext * ctx)94 void Compile(XlaOpKernelContext* ctx) override { 95 DataType type = ctx->input_type(1); 96 xla::XlaOp handle; 97 OP_REQUIRES_OK(ctx, 98 ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); 99 handle = xla::Add(handle, ctx->Input(1)); 100 OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); 101 } 102 }; 103 REGISTER_XLA_OP( 104 Name("AssignAddVariableOp").TypeConstraint("dtype", kNumericTypes), 105 AssignAddVariableOp); 106 107 class AssignSubVariableOp : public XlaOpKernel { 108 public: AssignSubVariableOp(OpKernelConstruction * ctx)109 explicit AssignSubVariableOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} Compile(XlaOpKernelContext * ctx)110 void Compile(XlaOpKernelContext* ctx) override { 111 DataType type = ctx->input_type(1); 112 xla::XlaOp handle; 113 OP_REQUIRES_OK(ctx, 114 ctx->ReadVariableInput(0, type, /*shape=*/nullptr, &handle)); 115 handle = xla::Sub(handle, ctx->Input(1)); 116 OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, handle)); 117 } 118 }; 119 REGISTER_XLA_OP( 120 Name("AssignSubVariableOp").TypeConstraint("dtype", kNumericTypes), 121 AssignSubVariableOp); 122 123 class ResourceGatherOp : public XlaOpKernel { 124 public: ResourceGatherOp(OpKernelConstruction * ctx)125 explicit ResourceGatherOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {} Compile(XlaOpKernelContext * ctx)126 void Compile(XlaOpKernelContext* ctx) override { 127 xla::XlaBuilder* builder = ctx->builder(); 128 129 DataType type = ctx->expected_output_dtype(0); 130 131 TensorShape resource_shape; 132 xla::XlaOp resource_handle; 133 OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, type, &resource_shape, 134 &resource_handle)); 135 136 auto indices = ctx->Input(1); 137 auto indices_shape = ctx->InputShape(1); 138 DataType index_type = ctx->input_type(1); 139 xla::XlaOp gather; 140 OP_REQUIRES_OK( 141 ctx, XlaGather(resource_handle, resource_shape, indices, indices_shape, 142 /*axis=*/0, /*indices_are_nd=*/false, type, index_type, 143 builder, &gather)); 144 ctx->SetOutput(0, gather); 145 } 146 }; 147 REGISTER_XLA_OP(Name("ResourceGather"), ResourceGatherOp); 148 149 class ResourceScatterOp : public XlaOpKernel { 150 public: ResourceScatterOp(OpKernelConstruction * context,bool indices_are_vectors,std::function<xla::XlaOp (const xla::XlaOp &,const xla::XlaOp &,xla::XlaBuilder *)> combiner)151 explicit ResourceScatterOp( 152 OpKernelConstruction* context, bool indices_are_vectors, 153 std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&, 154 xla::XlaBuilder*)> 155 combiner) 156 : XlaOpKernel(context), 157 indices_are_vectors_(indices_are_vectors), 158 combiner_(std::move(combiner)) {} 159 Compile(XlaOpKernelContext * context)160 void Compile(XlaOpKernelContext* context) override { 161 xla::XlaBuilder* builder = context->builder(); 162 163 DataType dtype = context->input_type(2); 164 TensorShape var_shape; 165 xla::XlaOp var_value; 166 OP_REQUIRES_OK( 167 context, context->ReadVariableInput(0, dtype, &var_shape, &var_value)); 168 169 const xla::XlaOp indices = context->Input(1); 170 const xla::XlaOp updates = context->Input(2); 171 172 auto result = XlaScatter(var_value, updates, indices, indices_are_vectors_, 173 combiner_, builder); 174 OP_REQUIRES_OK(context, result.status()); 175 OP_REQUIRES_OK(context, 176 context->AssignVariable(0, dtype, result.ValueOrDie())); 177 } 178 179 private: 180 const bool indices_are_vectors_; 181 const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&, 182 xla::XlaBuilder*)> 183 combiner_; 184 }; 185 186 class ResourceScatterAddOp : public ResourceScatterOp { 187 public: ResourceScatterAddOp(OpKernelConstruction * context)188 explicit ResourceScatterAddOp(OpKernelConstruction* context) 189 : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} 190 191 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)192 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 193 xla::XlaBuilder* builder) { 194 return xla::Add(x, y); 195 } 196 }; 197 REGISTER_XLA_OP(Name("ResourceScatterAdd"), ResourceScatterAddOp); 198 199 class ResourceScatterSubOp : public ResourceScatterOp { 200 public: ResourceScatterSubOp(OpKernelConstruction * context)201 explicit ResourceScatterSubOp(OpKernelConstruction* context) 202 : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} 203 204 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)205 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 206 xla::XlaBuilder* builder) { 207 return xla::Sub(x, y); 208 } 209 }; 210 REGISTER_XLA_OP(Name("ResourceScatterSub"), ResourceScatterSubOp); 211 212 class ResourceScatterMulOp : public ResourceScatterOp { 213 public: ResourceScatterMulOp(OpKernelConstruction * context)214 explicit ResourceScatterMulOp(OpKernelConstruction* context) 215 : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} 216 217 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)218 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 219 xla::XlaBuilder* builder) { 220 return xla::Mul(x, y); 221 } 222 }; 223 REGISTER_XLA_OP(Name("ResourceScatterMul"), ResourceScatterMulOp); 224 225 class ResourceScatterDivOp : public ResourceScatterOp { 226 public: ResourceScatterDivOp(OpKernelConstruction * context)227 explicit ResourceScatterDivOp(OpKernelConstruction* context) 228 : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} 229 230 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)231 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 232 xla::XlaBuilder* builder) { 233 return xla::Div(x, y); 234 } 235 }; 236 REGISTER_XLA_OP(Name("ResourceScatterDiv"), ResourceScatterDivOp); 237 238 class ResourceScatterMinOp : public ResourceScatterOp { 239 public: ResourceScatterMinOp(OpKernelConstruction * context)240 explicit ResourceScatterMinOp(OpKernelConstruction* context) 241 : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} 242 243 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)244 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 245 xla::XlaBuilder* builder) { 246 return xla::Min(x, y); 247 } 248 }; 249 REGISTER_XLA_OP(Name("ResourceScatterMin"), ResourceScatterMinOp); 250 251 class ResourceScatterMaxOp : public ResourceScatterOp { 252 public: ResourceScatterMaxOp(OpKernelConstruction * context)253 explicit ResourceScatterMaxOp(OpKernelConstruction* context) 254 : ResourceScatterOp(context, /*indices_are_vectors=*/false, Combine) {} 255 256 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)257 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 258 xla::XlaBuilder* builder) { 259 return xla::Max(x, y); 260 } 261 }; 262 REGISTER_XLA_OP(Name("ResourceScatterMax"), ResourceScatterMaxOp); 263 264 class ResourceScatterUpdateOp : public ResourceScatterOp { 265 public: ResourceScatterUpdateOp(OpKernelConstruction * context)266 explicit ResourceScatterUpdateOp(OpKernelConstruction* context) 267 : ResourceScatterOp(context, /*indices_are_vectors=*/false, 268 /*combiner=*/{}) {} 269 }; 270 REGISTER_XLA_OP(Name("ResourceScatterUpdate"), ResourceScatterUpdateOp); 271 272 class ResourceScatterNdUpdateOp : public ResourceScatterOp { 273 public: ResourceScatterNdUpdateOp(OpKernelConstruction * context)274 explicit ResourceScatterNdUpdateOp(OpKernelConstruction* context) 275 : ResourceScatterOp(context, /*indices_are_vectors=*/true, 276 /*combiner=*/{}) {} 277 }; 278 REGISTER_XLA_OP(Name("ResourceScatterNdUpdate"), ResourceScatterNdUpdateOp); 279 280 class ResourceScatterNdAddOp : public ResourceScatterOp { 281 public: ResourceScatterNdAddOp(OpKernelConstruction * context)282 explicit ResourceScatterNdAddOp(OpKernelConstruction* context) 283 : ResourceScatterOp(context, /*indices_are_vectors=*/true, 284 /*combiner=*/Combine) {} 285 286 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)287 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 288 xla::XlaBuilder* builder) { 289 return xla::Add(x, y); 290 } 291 }; 292 REGISTER_XLA_OP(Name("ResourceScatterNdAdd"), ResourceScatterNdAddOp); 293 294 class ResourceScatterNdSubOp : public ResourceScatterOp { 295 public: ResourceScatterNdSubOp(OpKernelConstruction * context)296 explicit ResourceScatterNdSubOp(OpKernelConstruction* context) 297 : ResourceScatterOp(context, /*indices_are_vectors=*/true, 298 /*combiner=*/Combine) {} 299 300 private: Combine(const xla::XlaOp & x,const xla::XlaOp & y,xla::XlaBuilder * builder)301 static xla::XlaOp Combine(const xla::XlaOp& x, const xla::XlaOp& y, 302 xla::XlaBuilder* builder) { 303 return xla::Sub(x, y); 304 } 305 }; 306 REGISTER_XLA_OP(Name("ResourceScatterNdSub"), ResourceScatterNdSubOp); 307 308 } // namespace 309 } // namespace tensorflow 310