• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
18 
#include <array>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
35 
36 namespace tensorflow {
37 
// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the TensorFlow/XLA bridge to perform symbolic execution
// of a TensorFlow graph.
extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

// Names of the XLA_* devices. These are distinct from the compilation devices
// declared above.
// NOTE(review): the string values live in the defining .cc file and are
// presumably "XLA_CPU" / "XLA_GPU" — confirm there.
extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;
47 
48 constexpr std::array<DataType, 4> kFloatTypes = {
49     {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
50 constexpr std::array<DataType, 6> kFloatAndComplexTypes = {
51     {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128}};
52 constexpr std::array<DataType, 14> kNumericTypes = {
53     {DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32,
54      DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
55      DT_BFLOAT16}};
56 
57 constexpr std::array<DataType, 18> kCpuAllTypes = {
58     {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
59      DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
60      DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
61 
62 constexpr std::array<DataType, 18> kGpuAllTypes = {
63     {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
64      DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
65      DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
66 
67 // Class that manages registrations of operators and devices for the XLA JIT.
68 // Not thread-safe.
69 class XlaOpRegistry {
70  public:
71   typedef OpKernel* (*Factory)(OpKernelConstruction*);
72 
73   enum class AutoclusteringPolicy {
74     // Enable autoclustering if the user requests it, e.g., via
75     // experimental_jit_scope. Does not autocluster if the JIT is enabled
76     // globally (e.g., via the OptimizerOptions in the TF session
77     // configuration.)
78     kIfExplicitlyRequested,
79     // Enable autoclustering if explicitly requested, or if the JIT is enabled
80     // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
81     kIfEnabledGlobally,
82     // Always try to autocluster ops placed on this device.
83     kAlways,
84   };
85 
86   // Describes how to compile operators assigned to a device.
87   struct DeviceRegistration {
88     // The name of the an XLA compilation device to use to compile code.
89     string compilation_device_name;
90 
91     // When should we autocluster operators assigned to this device?
92     AutoclusteringPolicy autoclustering_policy;
93 
94     // If we should ignore the resource variable memory model when clustering
95     // resource variable reads and writes placed on this device.
96     bool cluster_resource_variable_ops_unsafely = false;
97 
98     // If we should auto-cluster Stack operations placed on this device.
99     bool cluster_stack_ops = false;
100 
101     // If we should auto-cluster TensorArray operations placed on this device.
102     bool cluster_tensor_array_ops = false;
103 
104     // If we should auto-cluster stateful RNG operations placed on this device.
105     // Stateful RNG semantics are not properly supported by XLA so it is not
106     // necessarily correct to auto-cluster stateful RNG ops in general.
107     bool cluster_stateful_rng_ops = false;
108 
109     // If we should auto-cluster ControlTrigger operations placed on this
110     // device.  ControlTrigger operations are not necessarily safe to cluster
111     // since they affect deadness (a dead ControlTrigger produces a live
112     // output).
113     bool cluster_control_trigger = false;
114 
115     // If we should cluster Assert and CheckNumerics by eliding them (XLA does
116     // not natively support Assert or CheckNumerics).
117     bool elide_assert_and_checknumerics = false;
118 
119     // If we should cluster operations returning DT_VARIANT.
120     bool cluster_variant_ops = false;
121 
122     // Whether ops known to be slow should be auto-clustered.
123     bool cluster_slow_ops = false;
124 
125     // Whether ops known to have numerical accuracy issues should be
126     // auto-clustered.
127     bool cluster_inaccurate_ops = false;
128   };
129 
130   // Registers an XLA backend. `compilation_device_name` is the name of the
131   // device used for symbolic execution during compilation. `supported_types`
132   // is the list of non-resource types supported by the device. Each operators
133   // will be registered for the intersection of the operator's supported types
134   // and the device's supported types. `backend_op_filter` is a function used
135   // to exclude or modify operator registrations on the device; it may be
136   // nullptr, in which case all ops are included.
137   // `backend_op_filter` should return true if the op should be registered on
138   // the device; it may optionally modify the KernelDef.
139   typedef bool (*BackendOpFilter)(KernelDef* kdef);
140   static void RegisterBackend(const string& compilation_device_name,
141                               absl::Span<const DataType> supported_types,
142                               BackendOpFilter op_filter);
143 
144   // Returns the names of the registered backends.
145   static std::vector<string> BackendNames();
146 
147   // Returns true iff a backend with the given name is registered.
148   static bool IsBackendRegistered(const string& name);
149 
150   // Registers `device_name` for XLA compilation, using information from
151   // `registration`.
152   // Does nothing if a registration for `device_name` already exists.
153   static void RegisterCompilationDevice(const string& device_name,
154                                         const DeviceRegistration& registration);
155 
156   // Returns whether the device name is for the JIT device used exclusively for
157   // TF2XLA conversion.
158   static bool IsCompilationDevice(const string& device_name);
159 
160   // Returns the JIT device name associated with 'device_name', setting
161   // 'jit_device_name', 'requires_jit', and 'enabled_jit_by_default', if they
162   // are not null. Returns false and leaves the outputs unchanged if no matching
163   // JIT device is registered.
164   // '*enable_jit_by_default' is set to true if we should try to JIT using this
165   // device when the JIT is enabled via the Session OptimizerOptions.
166   static bool GetCompilationDevice(const string& device_name,
167                                    const DeviceRegistration** registration);
168 
169   // Registers all JIT kernels on JIT devices, if not already registered.
170   // Does nothing otherwise.
171   static void RegisterCompilationKernels();
172 
173   // Returns KernelDefs for compilation ops registered on
174   // 'compilation_device_name'.  Does not include kernels registered as
175   // CompilationOnly, iff include_compilation_only_kernels=false.
176   static std::vector<const KernelDef*> DeviceKernels(
177       const string& compilation_device_name,
178       bool include_compilation_only_kernels);
179 
180   // Returns all operations for which there are XLA kernels on any device.
181   static std::vector<string> GetAllRegisteredOps();
182 
183   // Returns (via `result`) the indices of inputs to `node_def` that must be
184   // compile-time constants. Returns an empty vector if the op is not
185   // registered.
186   //
187   // `result` is sorted.
CompileTimeConstantInputs(const NodeDef & node_def,const OpDef & op_def,std::vector<int> * result)188   static Status CompileTimeConstantInputs(const NodeDef& node_def,
189                                           const OpDef& op_def,
190                                           std::vector<int>* result) {
191     return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def,
192                                      result);
193   }
194 
195   // Returns (via `result`) the indices of inputs to `op_kernel` that must be
196   // compile-time constants.
197   //
198   // `result` is sorted.
CompileTimeConstantInputs(const OpKernel & op_kernel,std::vector<int> * result)199   static Status CompileTimeConstantInputs(const OpKernel& op_kernel,
200                                           std::vector<int>* result) {
201     return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel,
202                                      /*op_def=*/nullptr, result);
203   }
204 
205   // Return names of arguments for a given op which are supposed to be
206   // constants.
207   static const std::unordered_set<std::string>*
208   CompileTimeConstantInputArgNames(const string& op);
209 
210   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
211   // of its operands and not their values.
212   static bool IsMetadataOp(const string& op);
213 
214  private:
215   friend class XlaBackendRegistrar;
216   friend class XlaOpRegistrar;
217   friend class XlaOpRegistrationBuilder;
218 
219   static XlaOpRegistry& Instance();
220 
221   XlaOpRegistry();
222   ~XlaOpRegistry();
223 
224   mutex mutex_;
225 
226   // Describes an XLA backend.
227   struct Backend {
228     // Which types are supported by this device?
229     std::set<DataType> supported_types;
230 
231     // The per-backend operator filter function. See the comment on
232     // RegisterBackend() for details.
233     BackendOpFilter op_filter;
234 
235     // KernelDefs built by RegisterCompilationKernels() for each op supported
236     // by the device.
237     std::vector<std::unique_ptr<KernelDef>> kernel_defs;
238   };
239 
240   // Map from compilation device names to a description of the backend.
241   std::unordered_map<string, Backend> backends_ TF_GUARDED_BY(mutex_);
242 
243   // Map from Tensorflow device names to the corresponding JIT device metadata.
244   std::unordered_map<string, DeviceRegistration> compilation_devices_
245       TF_GUARDED_BY(mutex_);
246 
247   // A description of a Tensorflow operator that can be compiled to XLA.
248   struct OpRegistration {
249     string name;
250 
251     // Should this operator be registered only on compilation devices, without a
252     // dummy kernel registered on the corresponding XLA device?
253     bool compilation_only = false;
254 
255     // Should we allow resource types for type attributes? Used by _Arg to
256     // allow DT_RESOURCE.
257     bool allow_resource_types = false;
258 
259     // Should we allow variant types for type attributes? Used by While to
260     // allow TensorList which is of type DT_VARIANT.
261     bool allow_variant_types = false;
262 
263     // Should we allow string type for type attributes? Used by PartitionedCall
264     // to allow DT_STRING.
265     bool allow_string_type = false;
266 
267     // Mapping from attribute name to a list of supported types.
268     std::unordered_map<string, std::set<DataType>> type_constraints;
269 
270     // An optional allowlist of devices. If there is no allowlist, all devices
271     // are permitted.
272     bool has_device_allowlist = false;
273     std::unordered_set<string> device_allowlist;
274 
275     // Names of arguments that must be compile-time constants.
276     std::unordered_set<string> compile_time_constant_inputs;
277 
278     // True if this is a "metadata" op, one that only looks at the shapes of its
279     // operands and not their values.
280     bool is_metadata_op = false;
281 
282     std::string label;
283 
284     // Factory used to build OpKernels that perform symbolic execution.
285     Factory factory;
286   };
287 
288   // Returns true if registrations x and y can both be added to the registry.
289   // This is always the case if they refer to different ops. If they refer to
290   // the same op name, they must: have the same values for compilation_only,
291   // allow_resource_types and allow_variant_types; use a device_allowlist; and
292   // their allowlists must not intersect.
293   static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
294 
295   static Status CompileTimeConstantInputs(const NodeDef& node_def,
296                                           const OpKernel* op_kernel,
297                                           const OpDef* op_def,
298                                           std::vector<int>* result);
299 
300   // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
301   // Registrations present under the same key must satisfy IsCompatible above,
302   // and this is checked during registration.
303   std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
304       TF_GUARDED_BY(mutex_);
305 
306   // Have we already registered the JIT kernels on the JIT devices?
307   bool jit_kernels_registered_ = false;
308 
309   // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
310   // registrations created by RegisterCompilationKernels() and
311   // RegisterDeviceKernels().
312   std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
313       kernel_registrars_ TF_GUARDED_BY(mutex_);
314 };
315 
// REGISTER_XLA_OP(BUILDER, OP_CLASS) registers an XLA OpKernel by name, e.g.
//   REGISTER_XLA_OP(Name("Add"), AddOp);
// Here 'AddOp' is the JIT OpKernel class that implements "Add". BUILDER is
// appended to `::tensorflow::XlaOpRegistrationBuilder::`, so it must be a
// builder chain that starts with Name(...).
//
// The macro is deliberately not variadic: JIT operators are not expected to
// be templated, so OP_CLASS should never contain a top-level comma.
#define REGISTER_XLA_OP(BUILDER, OP_CLASS) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, BUILDER, OP_CLASS)
325 
326 class XlaOpRegistrationBuilder {
327  public:
328   // Starts an operator registration chain.
329   static XlaOpRegistrationBuilder Name(absl::string_view name);
330 
331   // Specifies a allowlist of devices on which the operator may run.
332   XlaOpRegistrationBuilder& Device(absl::string_view devices);
333   XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);
334 
335   // Specifies a type constraint for a type variable attribute. Each constraint
336   // specifies the set of types that the type variable may assume.
337   XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
338                                            DataType allowed);
339 
340   XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
341                                            absl::Span<const DataType> allowed);
342 
343   // Specifies that a dummy copy of this operator should not be registered on
344   // XLA_* devices, but may be used during compilation.
345   XlaOpRegistrationBuilder& CompilationOnly();
346 
347   // Allow DT_RESOURCE types for type parameters.
348   XlaOpRegistrationBuilder& AllowResourceTypes();
349 
350   // Allow DT_VARIANT types for type parameters.
351   XlaOpRegistrationBuilder& AllowVariantTypes();
352 
353   // Allow DT_STRING type for type parameters.
354   XlaOpRegistrationBuilder& AllowStringType();
355 
356   // Mark 'input_name' as an argument whose value must be known at compile-time.
357   XlaOpRegistrationBuilder& CompileTimeConstantInput(
358       absl::string_view input_name);
359 
360   // Mark this op as a "metadata" op, one that only looks at the shapes of its
361   // operands and not their values.
362   XlaOpRegistrationBuilder& IsMetadataOp();
363 
364   // Specifies a particular value for the "_kernel" attr.
365   XlaOpRegistrationBuilder& Label(std::string label);
366 
367   std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
368       XlaOpRegistry::Factory factory);
369 
370  private:
371   XlaOpRegistrationBuilder(absl::string_view name);
372 
373   std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
374 };
375 
// REGISTER_XLA_BACKEND(BACKEND_NAME, TYPES[, FILTER]) registers an XLA
// backend, e.g.
//   REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
// Everything after the backend name is forwarded unchanged to the
// XlaBackendRegistrar constructor, whose op-filter argument is optional.
#define REGISTER_XLA_BACKEND(BACKEND_NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, BACKEND_NAME, __VA_ARGS__)
380 
381 // Implementation details.
382 
383 class XlaOpRegistrar {
384  public:
385   XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
386 };
387 
// Indirection layer. Macro arguments are fully macro-expanded before
// substitution unless they are operands of # or ##, so routing through this
// helper turns __COUNTER__ into a number before REGISTER_XLA_OP_UNIQ pastes it
// into an identifier with ##.
#define REGISTER_XLA_OP_UNIQ_HELPER(UNIQUE_ID, BUILDER, OP_CLASS) \
  REGISTER_XLA_OP_UNIQ(UNIQUE_ID, BUILDER, OP_CLASS)

// Defines a uniquely named static XlaOpRegistrar. Its constructor receives the
// OpRegistration built from `XlaOpRegistrationBuilder::BUILDER`, with a
// captureless lambda (convertible to XlaOpRegistry::Factory) as the factory.
// The lambda constructs OP_CLASS from the OpKernelConstruction.
#define REGISTER_XLA_OP_UNIQ(UNIQUE_ID, BUILDER, OP_CLASS)              \
  static ::tensorflow::XlaOpRegistrar                                   \
      xla_op_registrar__body__##UNIQUE_ID##__object(                    \
          ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(        \
              [](::tensorflow::OpKernelConstruction* context)           \
                  -> ::tensorflow::OpKernel* {                          \
                return new OP_CLASS(context);                           \
              }));
396 
397 class XlaBackendRegistrar {
398  public:
399   XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
400                       XlaOpRegistry::BackendOpFilter op_filter = nullptr);
401 };
402 
// Forces UNIQUE_ID (typically __COUNTER__) to be macro-expanded before
// REGISTER_XLA_BACKEND_UNIQ pastes it into an identifier with ##.
#define REGISTER_XLA_BACKEND_UNIQ_HELPER(UNIQUE_ID, BACKEND_NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(UNIQUE_ID, BACKEND_NAME, __VA_ARGS__)

// Defines a uniquely named static XlaBackendRegistrar. It is constructed from
// the backend name followed by the remaining arguments: the supported types
// and an optional op filter.
#define REGISTER_XLA_BACKEND_UNIQ(UNIQUE_ID, BACKEND_NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar                      \
      xla_backend_registrar__body__##UNIQUE_ID##__object(       \
          BACKEND_NAME, __VA_ARGS__);
409 
410 }  // namespace tensorflow
411 
412 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
413