• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // Common kernel registrations for XLA devices.
17 
18 #ifndef TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
19 #define TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
20 
21 #include "tensorflow/core/framework/op_kernel.h"
22 #include "tensorflow/core/framework/resource_mgr.h"
23 #include "tensorflow/core/kernels/constant_op.h"
24 #include "tensorflow/core/kernels/data/finalize_dataset_op.h"
25 #include "tensorflow/core/kernels/data/generator_dataset_op.h"
26 #include "tensorflow/core/kernels/data/iterator_ops.h"
27 #include "tensorflow/core/kernels/data/optional_ops.h"
28 #include "tensorflow/core/kernels/data/options_dataset_op.h"
29 #include "tensorflow/core/kernels/data/prefetch_dataset_op.h"
30 #include "tensorflow/core/kernels/fifo_queue.h"
31 #include "tensorflow/core/kernels/function_ops.h"
32 #include "tensorflow/core/kernels/identity_op.h"
33 #include "tensorflow/core/kernels/resource_variable_ops.h"
34 #include "tensorflow/core/kernels/shape_ops.h"
35 #include "tensorflow/core/kernels/variable_ops.h"
36 
37 namespace tensorflow {
38 
39 // Dummy OpKernel, used for kernels assigned to an XLA device that should be
40 // compiled. Should never be called at runtime since such ops should be
41 // rewritten to a XlaLaunch op. If it is called, it means the placer placed an
42 // operator on an XLA device but the compiler did not compile it.
43 class XlaDeviceDummyOp : public OpKernel {
44  public:
45   explicit XlaDeviceDummyOp(OpKernelConstruction* ctx);
46   void Compute(OpKernelContext* ctx) override;
47 };
48 
49 class XlaAssignVariableOp : public OpKernel {
50  public:
51   explicit XlaAssignVariableOp(OpKernelConstruction* c);
52   void Compute(OpKernelContext* context) override;
53 
54  private:
55   DataType dtype_;
56 };
57 
// Registers KERNEL as the "XlaLaunch" implementation on DEVICE, with the
// "constants" and "resources" inputs placed in host memory. TYPES is unused;
// it appears to exist only so the REGISTER_XLA_*_KERNEL macros share one
// signature.
#define REGISTER_XLA_LAUNCH_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(                                \
      Name("XlaLaunch")                                   \
          .Device(DEVICE)                                 \
          .HostMemory("constants")                        \
          .HostMemory("resources"),                       \
      KERNEL);
64 
// Registers KERNEL as the "_XlaCompile" implementation on DEVICE. The
// "constants", "key", "compilation_successful" and "resources" arguments are
// placed in host memory. TYPES is unused; it appears to exist only so the
// REGISTER_XLA_*_KERNEL macros share one signature.
#define REGISTER_XLA_COMPILE_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(                                 \
      Name("_XlaCompile")                                  \
          .Device(DEVICE)                                  \
          .HostMemory("constants")                         \
          .HostMemory("key")                               \
          .HostMemory("compilation_successful")            \
          .HostMemory("resources"),                        \
      KERNEL);
73 
// Registers KERNEL as the "_XlaRun" implementation on DEVICE. No arguments are
// pinned to host memory. TYPES is unused; it appears to exist only so the
// REGISTER_XLA_*_KERNEL macros share one signature.
#define REGISTER_XLA_RUN_KERNEL(DEVICE, KERNEL, TYPES) \
  REGISTER_KERNEL_BUILDER(                             \
      Name("_XlaRun")                                  \
          .Device(DEVICE),                             \
      KERNEL);
76 
// Registers the common set of kernels provided on an XLA device:
// Const/Identity, resource-variable ops, shape queries, AssignVariableOp,
// FIFOQueueV2, function argument/return-value ops, RemoteCall, and the tf.data
// dataset, iterator and optional ops.
//
// DEVICE is the device type to register on. TYPES is the type list applied to
// the "dtype"/"T" type constraints below. Arguments named in HostMemory(...)
// are placed in host memory.
//
// All `_Arg`/`_Retval` registrations (including the tstring ones) are grouped
// together and spelled with the same kArgOp/kRetOp/kDeviceRetOp names.
//
// Comments inside the macro body use /* */ because a // comment would swallow
// the backslash-continued lines that follow it.
#define REGISTER_XLA_DEVICE_KERNELS(DEVICE, TYPES)                             \
  /* Constants and identity. */                                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Const").Device(DEVICE).TypeConstraint("dtype", TYPES),             \
      ConstantOp);                                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Identity").Device(DEVICE).TypeConstraint("T", TYPES), IdentityOp); \
                                                                               \
  /* Resource variable handles and reads. */                                   \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("VarHandleOp").Device(DEVICE).HostMemory("resource"), VarHandleOp); \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_VarHandlesOp").Device(DEVICE).HostMemory("resources"),            \
      ResourceHandlesOp<Var>);                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("ReadVariableOp").Device(DEVICE).HostMemory("resource"),            \
      ReadVariableOp);                                                         \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("_ReadVariablesOp").Device(DEVICE).HostMemory("resources"),         \
      ReadVariablesOp);                                                        \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("DestroyResourceOp").Device(DEVICE).HostMemory("resource"),         \
      DestroyResourceOp);                                                      \
                                                                               \
  /* Shape queries; results are produced in host memory. */                    \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int32>);                                     \
  REGISTER_KERNEL_BUILDER(Name("Shape")                                        \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64_t>("out_type")             \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeOp<int64_t>);                                   \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int32>);                                    \
  REGISTER_KERNEL_BUILDER(Name("ShapeN")                                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64_t>("out_type")             \
                              .TypeConstraint("T", TYPES),                     \
                          ShapeNOp<int64_t>);                                  \
  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<int32>("out_type")               \
                              .HostMemory("output")                            \
                              .HostMemory("input"),                            \
                          VariableShapeOp<int32>);                             \
  REGISTER_KERNEL_BUILDER(Name("VariableShape")                                \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<int64_t>("out_type")             \
                              .HostMemory("output")                            \
                              .HostMemory("input"),                            \
                          VariableShapeOp<int64_t>);                           \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int32>("out_type")               \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int32>);                                      \
  REGISTER_KERNEL_BUILDER(Name("Size")                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<int64_t>("out_type")             \
                              .TypeConstraint("T", TYPES),                     \
                          SizeOp<int64_t>);                                    \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("Rank").Device(DEVICE).HostMemory("output").TypeConstraint("T",     \
                                                                      TYPES),  \
      RankOp);                                                                 \
                                                                               \
  /* Resource variable assignment. */                                          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("AssignVariableOp").Device(DEVICE).HostMemory("resource"),          \
      XlaAssignVariableOp);                                                    \
                                                                               \
  /* Queues. */                                                                \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("FIFOQueueV2").Device(DEVICE).HostMemory("handle"), FIFOQueueOp);   \
                                                                               \
  /* Function arguments (_Arg). */                                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint("T", TYPES), ArgOp);          \
  REGISTER_KERNEL_BUILDER(Name(kArgOp)                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<ResourceHandle>("T"),            \
                          ArgOp);                                              \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kArgOp).Device(DEVICE).TypeConstraint<Variant>("T"), ArgOp);        \
  REGISTER_KERNEL_BUILDER(Name(kArgOp)                                         \
                              .Device(DEVICE)                                  \
                              .HostMemory("output")                            \
                              .TypeConstraint<tstring>("T"),                   \
                          ArgOp);                                              \
                                                                               \
  /* Function return values (_Retval / _DeviceRetval). */                      \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kRetOp).Device(DEVICE).TypeConstraint("T", TYPES), RetvalOp);       \
  REGISTER_KERNEL_BUILDER(Name(kRetOp)                                         \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<ResourceHandle>("T")             \
                              .HostMemory("input"),                            \
                          RetvalOp);                                           \
  REGISTER_KERNEL_BUILDER(Name(kRetOp)                                         \
                              .Device(DEVICE)                                  \
                              .TypeConstraint<tstring>("T")                    \
                              .HostMemory("input"),                            \
                          RetvalOp);                                           \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name(kDeviceRetOp).Device(DEVICE).TypeConstraint<int32>("T"), RetvalOp); \
                                                                               \
  /* Remote function calls. */                                                 \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("RemoteCall").Device(DEVICE).HostMemory("target"), RemoteCallOp);   \
                                                                               \
  /* tf.data datasets. */                                                      \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("GeneratorDataset").Device(DEVICE).HostMemory("handle"),            \
      data::GeneratorDatasetOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("PrefetchDataset")                              \
                              .Device(DEVICE)                                  \
                              .HostMemory("buffer_size")                       \
                              .HostMemory("input_dataset")                     \
                              .HostMemory("handle"),                           \
                          data::PrefetchDatasetOp);                            \
  REGISTER_KERNEL_BUILDER(Name("OptionsDataset")                               \
                              .Device(DEVICE)                                  \
                              .HostMemory("input_dataset")                     \
                              .HostMemory("handle"),                           \
                          data::OptionsDatasetOp);                             \
  REGISTER_KERNEL_BUILDER(Name("FinalizeDataset")                              \
                              .Device(DEVICE)                                  \
                              .HostMemory("input_dataset")                     \
                              .HostMemory("handle"),                           \
                          data::FinalizeDatasetOp);                            \
                                                                               \
  /* tf.data iterators. */                                                     \
  REGISTER_KERNEL_BUILDER(Name("IteratorV2").Device(DEVICE),                   \
                          data::IteratorHandleOp);                             \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("MakeIterator").Device(DEVICE).HostMemory("dataset"),               \
      data::MakeIteratorOp);                                                   \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIterator").Device(DEVICE),            \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV2").Device(DEVICE),          \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("AnonymousIteratorV3").Device(DEVICE),          \
                          data::AnonymousIteratorHandleOp);                    \
  REGISTER_KERNEL_BUILDER(Name("DeleteIterator").Device(DEVICE),               \
                          data::DeleteIteratorOp);                             \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNext").Device(DEVICE),              \
                          data::IteratorGetNextOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextAsOptional").Device(DEVICE),    \
                          data::IteratorGetNextAsOptionalOp);                  \
  REGISTER_KERNEL_BUILDER(Name("IteratorGetNextSync").Device(DEVICE),          \
                          data::IteratorGetNextOp);                            \
  REGISTER_KERNEL_BUILDER(Name("IteratorToStringHandle")                       \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorToStringHandleOp);                     \
  REGISTER_KERNEL_BUILDER(Name("IteratorFromStringHandleV2")                   \
                              .Device(DEVICE)                                  \
                              .HostMemory("string_handle"),                    \
                          data::IteratorFromStringHandleOp);                   \
                                                                               \
  /* tf.data optionals. */                                                     \
  REGISTER_KERNEL_BUILDER(Name("OptionalNone").Device(DEVICE),                 \
                          data::OptionalNoneOp);                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalFromValue").Device(DEVICE),            \
                          data::OptionalFromValueOp);                          \
  REGISTER_KERNEL_BUILDER(                                                     \
      Name("OptionalHasValue").Device(DEVICE).HostMemory("has_value"),         \
      data::OptionalHasValueOp);                                               \
  REGISTER_KERNEL_BUILDER(Name("OptionalGetValue").Device(DEVICE),             \
                          data::OptionalGetValueOp);
246 
247 // TODO(b/118881356): currently we do not register the QueueEnqueueMany,
248 // QueueDequeueMany, or QueueDequeueUpTo kernels because they attempt to read
249 // and write the tensors they access in order to concatenate them into a batch.
250 // We would need either to call out to an XLA computation to perform the
251 // concatenation, or we would need to refactor those kernels so the splitting
252 // or merging is done in a separate operator that can be compiled.
253 
254 }  // namespace tensorflow
255 
256 #endif  // TENSORFLOW_COMPILER_JIT_XLA_DEVICE_OPS_H_
257