• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Common kernel registrations for XLA devices.
17 
18 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
19 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/kernels/constant_op.h"
24 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
25 #include "tensorflow/core/kernels/data/iterator_ops.h"
26 #include "tensorflow/core/kernels/data/optional_ops.h"
27 #include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
28 #include "tensorflow/core/kernels/fifo_queue.h"
29 #include "tensorflow/core/kernels/function_ops.h"
30 #include "tensorflow/core/kernels/identity_op.h"
31 #include "tensorflow/core/kernels/resource_variable_ops.h"
32 #include "tensorflow/core/kernels/shape_ops.h"
33 #include "tensorflow/core/kernels/variable_ops.h"
34 
35 namespace tensorflow {
36 
37 // Dummy OpKernel, used for kernels assigned to an XLA device that should be
38 // compiled. Should never be called at runtime since such ops should be
39 // rewritten to a XlaLaunch op. If it is called, it means the placer placed an
40 // operator on an XLA device but the compiler did not compile it.
41 class XlaDeviceDummyOp : public OpKernel {
42  public:
43   explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
44   void Compute(OpKernelContext* ctx) override;
45 };
46 
47 class XlaAssignVariableOp : public OpKernel {
48  public:
49   explicit XlaAssignVariableOp(OpKernelConstruction* c);
50   void Compute(OpKernelContext* context) override;
51 
52  private:
53   DataType dtype_;
54 };
55 
// Registers KERNEL as the "XlaLaunch" implementation for DEVICE, keeping the
// "constants" and "resources" arguments in host memory. TYPES is not used by
// the expansion (presumably accepted so all REGISTER_XLA_* macros share one
// signature).
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(                                \
      Name("XlaLaunch")                                   \
          .Device(DEVICE)                                 \
          .HostMemory("constants")                        \
          .HostMemory("resources"),                       \
      KERNEL);
62 
// Registers KERNEL as the "_XlaCompile" implementation for DEVICE. The
// "constants", "key", "compilation_successful" and "resources" arguments are
// kept in host memory. TYPES is not used by the expansion.
#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(                                 \
      Name("_XlaCompile")                                  \
          .Device(DEVICE)                                  \
          .HostMemory("constants")                         \
          .HostMemory("key")                               \
          .HostMemory("compilation_successful")            \
          .HostMemory("resources"),                        \
      KERNEL);
71 
// Registers KERNEL as the "_XlaRun" implementation for DEVICE. No arguments
// are pinned to host memory. TYPES is not used by the expansion.
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(                             \
      Name("_XlaRun").Device(DEVICE),                  \
      KERNEL);
74 
// Registers the common (non-compiled) kernels for the XLA device type DEVICE.
// TYPES is the list of data types used by the type-polymorphic registrations
// ("dtype" for Const, "T" elsewhere). The registrations cover:
//   * Const and Identity;
//   * resource-variable handle/read/assign/destroy ops (AssignVariableOp is
//     served by XlaAssignVariableOp);
//   * Shape, ShapeN, VariableShape, Size and Rank, with outputs in host
//     memory;
//   * FIFOQueueV2;
//   * function argument/return-value kernels (kArgOp, kRetOp) for TYPES,
//     ResourceHandle, tstring (and Variant for kArgOp), plus kDeviceRetOp
//     for int32;
//   * RemoteCall;
//   * GeneratorDataset, PrefetchDataset, iterator ops and optional ops.
// Comments inside the macro body must use /* */: line splicing happens before
// comment removal, so a // comment would swallow the continued next line.
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES)                             \
  /* Constants and identity. */                                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES),             \
      ConstantOp);                                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
                                                                               \
  /* Resource variables. */                                                    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), VarHandleOp); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"),            \
      ResourceHandlesOp<Var>);                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"),            \
      ReadVariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"),         \
      ReadVariablesOp);                                                        \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"),          \
      XlaAssignVariableOp);                                                    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"),         \
      DestroyResourceOp);                                                      \
                                                                               \
  /* Shape, size and rank queries (outputs in host memory). */                 \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int32>);                                     \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int64>);                                     \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int32>);                                    \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int64>);                                    \
  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<int32>("out_type")               \
                              .HostMemory("output")                            \
                              .HostMemory("input"),                            \
                          VariableShapeOp<int32>);                             \
  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<int64>("out_type")               \
                              .HostMemory("output")                            \
                              .HostMemory("input"),                            \
                          VariableShapeOp<int64>);                             \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int32>);                                      \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int64>);                                      \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T",     \
                                                                      TYPES),  \
      RankOp);                                                                 \
                                                                               \
  /* Queues. */                                                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);   \
                                                                               \
  /* Function arguments and return values. */                                  \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp);          \
  REGISTER_KERNEL_BUILDER(Name(kArgOp)                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<ResourceHandle>("T"),            \
                          ArgOp);                                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp);        \
  REGISTER_KERNEL_BUILDER(Name(kArgOp)                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<tstring>("T"),                   \
                          ArgOp);                                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp);       \
  REGISTER_KERNEL_BUILDER(Name(kRetOp)                                         \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<ResourceHandle>("T")             \
                              .HostMemory("input"),                            \
                          RetvalOp);                                           \
  REGISTER_KERNEL_BUILDER(Name(kRetOp)                                         \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<tstring>("T")                    \
                              .HostMemory("input"),                            \
                          RetvalOp);                                           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kDeviceRetOp).Device(DEVICE).TypeConstraint<int32>("T"), RetvalOp); \
                                                                               \
  /* Remote function calls. */                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp);   \
                                                                               \
  /* Datasets, iterators and optionals. */                                     \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"),            \
      data::GeneratorDatasetOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")                              \
                              .Device(DEVICE)                                  \
                              .HostMemory("buffer_size")                       \
                              .HostMemory("input_dataset")                     \
                              .HostMemory("handle"),                           \
                          data::PrefetchDatasetOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE),                   \
                          data::IteratorHandleOp);                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MakeIterator").Device(DEVICE).HostMemory("dataset"),               \
      data::MakeIteratorOp);                                                   \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE),            \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE),          \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE),               \
                          data::DeleteIteratorOp);                             \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),              \
                          data::IteratorGetNextOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE),    \
                          data::IteratorGetNextAsOptionalOp);                  \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),          \
                          data::IteratorGetNextOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorToStringHandleOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")                   \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorFromStringHandleOp);                   \
  REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE),                 \
                          data::OptionalNoneOp);                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE),            \
                          data::OptionalFromValueOp);                          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"),         \
      data::OptionalHasValueOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE),             \
                          data::OptionalGetValueOp);
232 
233 // TODO(b/118881356): currently we do not register the QueueEnqueueMany,
234 // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
235 // and write the tensors they access in order to concatenate them into a batch.
236 // We would need either to call out to an XLA computation to perform the
237 // concatenation, or we would need to refactor those kernels so the splitting
238 // or merging is done in a separate operator that can be compiled.
239 
240 }  // namespace tensorflow
241 
242 #endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
243