• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
18 
#include <array>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"
35 
namespace tensorflow {

// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the Tensorflow/XLA bridge to perform symbolic execution of
// a Tensorflow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

// Names of the user-visible XLA devices (presumably "XLA_CPU" / "XLA_GPU";
// the string values are defined in the .cc file -- confirm there).
extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

// Floating-point types the bridge can compile.
constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
// Floating-point plus complex types.
constexpr std::array<DataType, 6> kFloatAndComplexTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128}};
// All numeric (integer, floating-point, and complex) types.
constexpr std::array<DataType, 14> kNumericTypes = {
    {DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32,
     DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
     DT_BFLOAT16}};

// Full set of types supported by the XLA CPU backend: kNumericTypes plus the
// quantized integer types and DT_BOOL.
constexpr std::array<DataType, 18> kCpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
     DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

// Full set of types supported by the XLA GPU backend (currently identical to
// kCpuAllTypes).
constexpr std::array<DataType, 18> kGpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
     DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
67 // Class that manages registrations of operators and devices for the XLA JIT.
68 // Not thread-safe.
69 class XlaOpRegistry {
70  public:
71   typedef OpKernel* (*Factory)(OpKernelConstruction*);
72 
73   enum class AutoclusteringPolicy {
74     // Enable autoclustering if the user requests it, e.g., via
75     // experimental_jit_scope. Does not autocluster if the JIT is enabled
76     // globally (e.g., via the OptimizerOptions in the TF session
77     // configuration.)
78     kIfExplicitlyRequested,
79     // Enable autoclustering if explicitly requested, or if the JIT is enabled
80     // globally in the session options, or via TF_XLA_FLAGS=--tf_xla_auto_jit=N.
81     kIfEnabledGlobally,
82     // Always try to autocluster ops placed on this device.
83     kAlways,
84   };
85 
86   // Describes how to compile operators assigned to a device.
87   struct DeviceRegistration {
88     // The name of the an XLA compilation device to use to compile code.
89     string compilation_device_name;
90 
91     // When should we autocluster operators assigned to this device?
92     AutoclusteringPolicy autoclustering_policy;
93 
94     // If we should ignore the resource variable memory model when clustering
95     // resource variable reads and writes placed on this device.
96     bool cluster_resource_variable_ops_unsafely = false;
97 
98     // If we should auto-cluster Stack operations placed on this device.
99     bool cluster_stack_ops = false;
100 
101     // If we should auto-cluster TensorArray operations placed on this device.
102     bool cluster_tensor_array_ops = false;
103 
104     // If we should auto-cluster stateful RNG operations placed on this device.
105     // Stateful RNG semantics are not properly supported by XLA so it is not
106     // necessarily correct to auto-cluster stateful RNG ops in general.
107     bool cluster_stateful_rng_ops = false;
108 
109     // If we should auto-cluster ControlTrigger operations placed on this
110     // device.  ControlTrigger operations are not necessarily safe to cluster
111     // since they affect deadness (a dead ControlTrigger produces a live
112     // output).
113     bool cluster_control_trigger = false;
114 
115     // If we should cluster Assert and CheckNumerics by eliding them (XLA does
116     // not natively support Assert or CheckNumerics).
117     bool elide_assert_and_checknumerics = false;
118 
119     // If we should cluster operations returning DT_VARIANT.
120     bool cluster_variant_ops = false;
121 
122     // Whether ops known to be slow should be auto-clustered.
123     bool cluster_slow_ops = false;
124 
125     // Whether ops known to have numerical accuracy issues should be
126     // auto-clustered.
127     bool cluster_inaccurate_ops = false;
128   };
129 
130   // Registers an XLA backend. `compilation_device_name` is the name of the
131   // device used for symbolic execution during compilation. `supported_types`
132   // is the list of non-resource types supported by the device. Each operators
133   // will be registered for the intersection of the operator's supported types
134   // and the device's supported types. `backend_op_filter` is a function used
135   // to exclude or modify operator registrations on the device; it may be
136   // nullptr, in which case all ops are included.
137   // `backend_op_filter` should return true if the op should be registered on
138   // the device; it may optionally modify the KernelDef.
139   typedef bool (*BackendOpFilter)(KernelDef* kdef);
140   static void RegisterBackend(const string& compilation_device_name,
141                               absl::Span<const DataType> supported_types,
142                               BackendOpFilter op_filter);
143 
144   // Returns the names of the registered backends.
145   static std::vector<string> BackendNames();
146 
147   // Returns true iff a backend with the given name is registered.
148   static bool IsBackendRegistered(const string& name);
149 
150   // Registers `device_name` for XLA compilation, using information from
151   // `registration`.
152   // Does nothing if a registration for `device_name` already exists.
153   static void RegisterCompilationDevice(const string& device_name,
154                                         const DeviceRegistration& registration);
155 
156   // Returns whether the device name is for the JIT device used exclusively for
157   // TF2XLA conversion.
158   static bool IsCompilationDevice(const string& device_name);
159 
160   // Returns the JIT device name associated with 'device_name', setting
161   // 'jit_device_name', 'requires_jit', and 'enabled_jit_by_default', if they
162   // are not null. Returns false and leaves the outputs unchanged if no matching
163   // JIT device is registered.
164   // '*enable_jit_by_default' is set to true if we should try to JIT using this
165   // device when the JIT is enabled via the Session OptimizerOptions.
166   static bool GetCompilationDevice(const string& device_name,
167                                    const DeviceRegistration** registration);
168 
169   // Registers all JIT kernels on JIT devices, if not already registered.
170   // Does nothing otherwise.
171   static void RegisterCompilationKernels();
172 
173   // Returns KernelDefs for compilation ops registered on
174   // 'compilation_device_name'.  Does not include kernels registered as
175   // CompilationOnly, iff include_compilation_only_kernels=false.
176   static std::vector<const KernelDef*> DeviceKernels(
177       const string& compilation_device_name,
178       bool include_compilation_only_kernels);
179 
180   // Returns all operations for which there are XLA kernels on any device.
181   static std::vector<string> GetAllRegisteredOps();
182 
183   // Returns (via `result`) the indices of inputs to `node_def` that must be
184   // compile-time constants. Returns an empty vector if the op is not
185   // registered.
186   //
187   // `result` is sorted.
CompileTimeConstantInputs(const NodeDef & node_def,const OpDef & op_def,std::vector<int> * result)188   static Status CompileTimeConstantInputs(const NodeDef& node_def,
189                                           const OpDef& op_def,
190                                           std::vector<int>* result) {
191     return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def,
192                                      result);
193   }
194 
CompileTimeConstantInputs(const NodeDef & node_def,const OpDef & op_def)195   static StatusOr<std::vector<int>> CompileTimeConstantInputs(
196       const NodeDef& node_def, const OpDef& op_def) {
197     std::vector<int> out;
198     TF_RETURN_IF_ERROR(CompileTimeConstantInputs(node_def, op_def, &out));
199     return out;
200   }
201 
202   // Returns (via `result`) the indices of inputs to `op_kernel` that must be
203   // compile-time constants.
204   //
205   // `result` is sorted.
CompileTimeConstantInputs(const OpKernel & op_kernel,std::vector<int> * result)206   static Status CompileTimeConstantInputs(const OpKernel& op_kernel,
207                                           std::vector<int>* result) {
208     return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel,
209                                      /*op_def=*/nullptr, result);
210   }
211 
212   // Return names of arguments for a given op which are supposed to be
213   // constants.
214   static const std::unordered_set<std::string>*
215   CompileTimeConstantInputArgNames(const string& op);
216 
217   // Returns true if `op` is a "metadata" op, one that only looks at the shapes
218   // of its operands and not their values.
219   static bool IsMetadataOp(const string& op);
220 
221  private:
222   friend class XlaBackendRegistrar;
223   friend class XlaOpRegistrar;
224   friend class XlaOpRegistrationBuilder;
225 
226   static XlaOpRegistry& Instance();
227 
228   XlaOpRegistry();
229   ~XlaOpRegistry();
230 
231   mutex mutex_;
232 
233   // Describes an XLA backend.
234   struct Backend {
235     // Which types are supported by this device?
236     std::set<DataType> supported_types;
237 
238     // The per-backend operator filter function. See the comment on
239     // RegisterBackend() for details.
240     BackendOpFilter op_filter;
241 
242     // KernelDefs built by RegisterCompilationKernels() for each op supported
243     // by the device.
244     std::vector<std::unique_ptr<KernelDef>> kernel_defs;
245   };
246 
247   // Map from compilation device names to a description of the backend.
248   std::unordered_map<string, Backend> backends_ TF_GUARDED_BY(mutex_);
249 
250   // Map from Tensorflow device names to the corresponding JIT device metadata.
251   std::unordered_map<string, DeviceRegistration> compilation_devices_
252       TF_GUARDED_BY(mutex_);
253 
254   // A description of a Tensorflow operator that can be compiled to XLA.
255   struct OpRegistration {
256     string name;
257 
258     // Should this operator be registered only on compilation devices, without a
259     // dummy kernel registered on the corresponding XLA device?
260     bool compilation_only = false;
261 
262     // Should we allow resource types for type attributes? Used by _Arg to
263     // allow DT_RESOURCE.
264     bool allow_resource_types = false;
265 
266     // Should we allow variant types for type attributes? Used by While to
267     // allow TensorList which is of type DT_VARIANT.
268     bool allow_variant_types = false;
269 
270     // Should we allow string type for type attributes? Used by PartitionedCall
271     // to allow DT_STRING.
272     bool allow_string_type = false;
273 
274     // Mapping from attribute name to a list of supported types.
275     std::unordered_map<string, std::set<DataType>> type_constraints;
276 
277     // An optional allowlist of devices. If there is no allowlist, all devices
278     // are permitted.
279     bool has_device_allowlist = false;
280     std::unordered_set<string> device_allowlist;
281 
282     // Names of arguments that must be compile-time constants.
283     std::unordered_set<string> compile_time_constant_inputs;
284 
285     // True if this is a "metadata" op, one that only looks at the shapes of its
286     // operands and not their values.
287     bool is_metadata_op = false;
288 
289     std::string label;
290 
291     // Factory used to build OpKernels that perform symbolic execution.
292     Factory factory;
293   };
294 
295   // Returns true if registrations x and y can both be added to the registry.
296   // This is always the case if they refer to different ops. If they refer to
297   // the same op name, they must: have the same values for compilation_only,
298   // allow_resource_types and allow_variant_types; use a device_allowlist; and
299   // their allowlists must not intersect.
300   static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
301 
302   static Status CompileTimeConstantInputs(const NodeDef& node_def,
303                                           const OpKernel* op_kernel,
304                                           const OpDef* op_def,
305                                           std::vector<int>* result);
306 
307   // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
308   // Registrations present under the same key must satisfy IsCompatible above,
309   // and this is checked during registration.
310   std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
311       TF_GUARDED_BY(mutex_);
312 
313   // Have we already registered the JIT kernels on the JIT devices?
314   bool jit_kernels_registered_ = false;
315 
316   // Holds ownership of OpKernelRegistrars that represent the Tensorflow kernel
317   // registrations created by RegisterCompilationKernels() and
318   // RegisterDeviceKernels().
319   std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
320       kernel_registrars_ TF_GUARDED_BY(mutex_);
321 };
322 
// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
// REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

// Routes through a helper so that __COUNTER__ produces a unique registrar
// object name per expansion within a translation unit.
#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
332 
// Builder used by REGISTER_XLA_OP to accumulate the properties of an XLA op
// registration before handing it to the registry via Build().
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(absl::string_view name);

  // Specifies an allowlist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(absl::string_view devices);
  XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           absl::Span<const DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Allow DT_VARIANT types for type parameters.
  XlaOpRegistrationBuilder& AllowVariantTypes();

  // Allow DT_STRING type for type parameters.
  XlaOpRegistrationBuilder& AllowStringType();

  // Mark 'input_name' as an argument whose value must be known at compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstantInput(
      absl::string_view input_name);

  // Mark this op as a "metadata" op, one that only looks at the shapes of its
  // operands and not their values.
  XlaOpRegistrationBuilder& IsMetadataOp();

  // Specifies a particular value for the "_kernel" attr.
  XlaOpRegistrationBuilder& Label(std::string label);

  // Finalizes the registration, attaching `factory` as the OpKernel factory,
  // and releases it to the caller.
  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  // Private; start a chain with Name().
  XlaOpRegistrationBuilder(absl::string_view name);

  // Registration being accumulated by the chained setters above.
  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
382 
// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
// REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
// Routes through a helper so __COUNTER__ yields a unique registrar name.
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)
387 
// Implementation details.

// Registrar type instantiated (as a static object) by REGISTER_XLA_OP; its
// constructor consumes an OpRegistration built by XlaOpRegistrationBuilder.
class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};
394 
// Two-level expansion so that __COUNTER__ is expanded to its numeric value
// before being token-pasted into the registrar object's name.
#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

// Defines a uniquely named static XlaOpRegistrar whose factory lambda
// constructs an OP OpKernel for symbolic execution.
#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(                   \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));
403 
// Registrar type instantiated (as a static object) by REGISTER_XLA_BACKEND;
// its constructor takes the backend name, supported types, and an optional
// per-backend op filter.
class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};
409 
// Two-level expansion so __COUNTER__ is evaluated before token pasting.
#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

// Defines a uniquely named static XlaBackendRegistrar for backend NAME.
#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);
416 
417 }  // namespace tensorflow
418 
419 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
420