/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_

#include <array>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the Tensorflow/XLA bridge to perform symbolic execution of
// a Tensorflow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

// Names of the XLA devices; values are defined in the corresponding .cc
// (presumably "XLA_CPU"/"XLA_GPU" — confirm there).
extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

// Floating-point types supported by the XLA backends.
constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
// Integer, floating-point and complex types (no quantized/bool types).
constexpr std::array<DataType, 12> kNumericTypes = {
    {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
     DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}};

// Full set of types registered for the CPU JIT device.
constexpr std::array<DataType, 16> kCpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

// Full set of types registered for the GPU JIT device: kCpuAllTypes minus
// DT_COMPLEX128.
constexpr std::array<DataType, 15> kGpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
     DT_BFLOAT16}};
63 
64 // Class that manages registrations of operators and devices for the XLA JIT.
65 // Not thread-safe.
66 class XlaOpRegistry {
67  public:
68   typedef OpKernel* (*Factory)(OpKernelConstruction*);
69 
70   enum class AutoclusteringPolicy {
71     // Enable autoclustering if the user requests it, e.g., via
72     // experimental_jit_scope. Does not autocluster if the JIT is enabled
73     // globally (e.g., via the OptimizerOptions in the TF session
74     // configuration.)
75     kIfExplicitlyRequested,
76     // Enable autoclustering if explicitly requested, or if the JIT is enabled
77     // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
78     kIfEnabledGlobally,
79     // Always try to autocluster ops placed on this device.
80     kAlways,
81   };
82 
83   // Describes how to compile operators assigned to a device.
84   struct DeviceRegistration {
85     // The name of the an XLA compilation device to use to compile code.
86     string compilation_device_name;
87 
88     // When should we autocluster operators assigned to this device?
89     AutoclusteringPolicy autoclustering_policy;
90 
91     // Enable compilation of operators that use DT_RESOURCE types?
92     bool compile_all_resource_ops = false;
93   };
94 
95   // Registers an XLA backend. `compilation_device_name` is the name of the
96   // device used for symbolic execution during compilation. `supported_types`
97   // is the list of non-resource types supported by the device. Each operators
98   // will be registered for the intersection of the operator's supported types
99   // and the device's supported types. `backend_op_filter` is a function used
100   // to exclude or modify operator registrations on the device; it may be
101   // nullptr, in which case all ops are included.
102   // `backend_op_filter` should return true if the op should be registered on
103   // the device; it may optionally modify the KernelDef.
104   typedef bool (*BackendOpFilter)(KernelDef* kdef);
105   static void RegisterBackend(const string& compilation_device_name,
106                               absl::Span<const DataType> supported_types,
107                               BackendOpFilter op_filter);
108 
109   // Returns the names of the registered backends.
110   static std::vector<string> BackendNames();
111 
112   // Returns true iff a backend with the given name is registered.
113   static bool IsBackendRegistered(const string& name);
114 
115   // Registers `device_name` for XLA compilation, using information from
116   // `registration`.
117   // Does nothing if a registration for `device_name` already exists.
118   static void RegisterCompilationDevice(const string& device_name,
119                                         const DeviceRegistration& registration);
120 
121   // Returns the JIT device name associated with 'device_name', setting
122   // 'jit_device_name', 'requires_jit', and 'enabled_jit_by_default', if they
123   // are not null. Returns false and leaves the outputs unchanged if no matching
124   // JIT device is registered.
125   // '*enable_jit_by_default' is set to true if we should try to JIT using this
126   // device when the JIT is enabled via the Session OptimizerOptions.
127   static bool GetCompilationDevice(const string& device_name,
128                                    const DeviceRegistration** registration);
129 
130   // Registers all JIT kernels on JIT devices, if not already registered.
131   // Does nothing otherwise.
132   static void RegisterCompilationKernels();
133 
134   // Returns KernelDefs for compilation ops registered on
135   // 'compilation_device_name'.  Does not include kernels registered as
136   // CompilationOnly, iff include_compilation_only_kernels=false.
137   static std::vector<const KernelDef*> DeviceKernels(
138       const string& compilation_device_name,
139       bool include_compilation_only_kernels);
140 
141   // Returns all operations for which there are XLA kernels on any device.
142   static std::vector<string> GetAllRegisteredOps();
143 
144   // Returns (via `result`) the indices of inputs to `node_def` that must be
145   // compile-time constants. Returns an empty vector if the op is not
146   // registered.
147   //
148   // `result` is sorted.
CompileTimeConstantInputs(const NodeDef & node_def,const OpDef & op_def,std::vector<int> * result)149   static Status CompileTimeConstantInputs(const NodeDef& node_def,
150                                           const OpDef& op_def,
151                                           std::vector<int>* result) {
152     return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def,
153                                      result);
154   }
155 
156   // Returns (via `result`) the indices of inputs to `op_kernel` that must be
157   // compile-time constants.
158   //
159   // `result` is sorted.
CompileTimeConstantInputs(const OpKernel & op_kernel,std::vector<int> * result)160   static Status CompileTimeConstantInputs(const OpKernel& op_kernel,
161                                           std::vector<int>* result) {
162     return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel,
163                                      /*op_def=*/nullptr, result);
164   }
165 
166   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
167   // of its operands and not their values.
168   static bool IsMetadataOp(const string& op);
169 
170  private:
171   friend class XlaBackendRegistrar;
172   friend class XlaOpRegistrar;
173   friend class XlaOpRegistrationBuilder;
174 
175   static XlaOpRegistry& Instance();
176 
177   XlaOpRegistry();
178   ~XlaOpRegistry();
179 
180   mutex mutex_;
181 
182   // Describes an XLA backend.
183   struct Backend {
184     // Which types are supported by this device?
185     std::set<DataType> supported_types;
186 
187     // The per-backend operator filter function. See the comment on
188     // RegisterBackend() for details.
189     BackendOpFilter op_filter;
190 
191     // KernelDefs built by RegisterCompilationKernels() for each op supported
192     // by the device.
193     std::vector<std::unique_ptr<KernelDef>> kernel_defs;
194   };
195 
196   // Map from compilation device names to a description of the backend.
197   std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);
198 
199   // Map from Tensorflow device names to the corresponding JIT device metadata.
200   std::unordered_map<string, DeviceRegistration> compilation_devices_
201       GUARDED_BY(mutex_);
202 
203   // A description of a Tensorflow operator that can be compiled to XLA.
204   struct OpRegistration {
205     string name;
206 
207     // Should this operator be registered only on compilation devices, without a
208     // dummy kernel registered on the corresponding XLA device?
209     bool compilation_only = false;
210 
211     // Should we allow resource types for type attributes? Used by _Arg to
212     // allow DT_RESOURCE.
213     bool allow_resource_types = false;
214 
215     // Should we allow variant types for type attributes? Used by While to
216     // allow TensorList which is of type DT_VARIANT.
217     bool allow_variant_types = false;
218 
219     // Mapping from attribute name to a list of supported types.
220     std::unordered_map<string, std::set<DataType>> type_constraints;
221 
222     // An optional whitelist of devices. If there is no whitelist, all devices
223     // are permitted.
224     bool has_device_whitelist = false;
225     std::unordered_set<string> device_whitelist;
226 
227     // Names of arguments that must be compile-time constants.
228     std::unordered_set<string> compile_time_constant_inputs;
229 
230     // True if this is a "metadata" op, one that only looks at the shapes of its
231     // operands and not their values.
232     bool is_metadata_op = false;
233 
234     // Factory used to build OpKernels that perform symbolic execution.
235     Factory factory;
236   };
237 
238   // Returns true if registrations x and y can both be added to the registry.
239   // This is always the case if they refer to different ops. If they refer to
240   // the same op name, they must: have the same values for compilation_only,
241   // allow_resource_types and allow_variant_types; use a device_whitelist; and
242   // their whitelists must not intersect.
243   static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
244 
245   static Status CompileTimeConstantInputs(const NodeDef& node_def,
246                                           const OpKernel* op_kernel,
247                                           const OpDef* op_def,
248                                           std::vector<int>* result);
249 
250   // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
251   // Registrations present under the same key must satisfy IsCompatible above,
252   // and this is checked during registration.
253   std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
254       GUARDED_BY(mutex_);
255 
256   // Have we already registered the JIT kernels on the JIT devices?
257   bool jit_kernels_registered_ = false;
258 
259   // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
260   // registrations created by RegisterCompilationKernels() and
261   // RegisterDeviceKernels().
262   std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
263       kernel_registrars_ GUARDED_BY(mutex_);
264 };
265 
// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
// REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

// __COUNTER__ gives each expansion a unique static registrar object name.
#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
275 
// Fluent builder for an XlaOpRegistry::OpRegistration, used via the
// REGISTER_XLA_OP macro. A chain starts with Name() and is finalized by
// Build().
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(absl::string_view name);

  // Specifies a whitelist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(absl::string_view devices);
  XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           DataType allowed);

  // Same as above, but constrains `attr_name` to a set of allowed types.
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           absl::Span<const DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Allow DT_VARIANT types for type parameters.
  XlaOpRegistrationBuilder& AllowVariantTypes();

  // Mark 'input_name' as an argument whose value must be known at compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstantInput(
      absl::string_view input_name);

  // Mark this op as a "metadata" op, one that only looks at the shapes of its
  // operands and not their values.
  XlaOpRegistrationBuilder& IsMetadataOp();

  // Finalizes the chain, attaching `factory` as the kernel factory and
  // releasing ownership of the accumulated registration to the caller.
  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  // Only reachable via Name().
  XlaOpRegistrationBuilder(absl::string_view name);

  // Registration being accumulated by the chained setter calls above.
  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
319 
// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
// __COUNTER__ gives each expansion a unique static registrar object name.
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)
324 
// Implementation details.

// Static-initialization hook for REGISTER_XLA_OP: the constructor is expected
// to record `registration` in XlaOpRegistry (it is a friend of that class;
// the actual recording happens in the .cc).
class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};
331 
// Extra indirection so that __COUNTER__ (passed in as COUNTER) expands to its
// numeric value before being token-pasted into the object name below.
#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

// Defines a uniquely-named static XlaOpRegistrar; the lambda adapts OP's
// constructor to the XlaOpRegistry::Factory signature.
#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(                   \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));
340 
// Static-initialization hook for REGISTER_XLA_BACKEND: the constructor is
// expected to register the backend with XlaOpRegistry (it is a friend of that
// class; the actual registration happens in the .cc).
class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
346 
// Extra indirection so __COUNTER__ expands before token pasting (see the
// matching REGISTER_XLA_OP_UNIQ_HELPER).
#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

// Defines a uniquely-named static XlaBackendRegistrar for backend NAME.
#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);
353 
}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_