/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_

#include <array>
#include <functional>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the TensorFlow/XLA bridge to perform symbolic execution
// of a TensorFlow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
constexpr std::array<DataType, 6> kFloatAndComplexTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128}};
constexpr std::array<DataType, 14> kNumericTypes = {
    {DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32,
     DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
     DT_BFLOAT16}};

constexpr std::array<DataType, 18> kCpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
     DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

constexpr std::array<DataType, 18> kGpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
     DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};
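
// For illustration, these arrays are typically passed where an
// absl::Span<const DataType> is expected, e.g. as a backend's supported-type
// list or as a TypeConstraint in a kernel registration (a hedged sketch using
// the REGISTER_XLA_OP macro defined below; "MyFloatOp" is a hypothetical
// kernel class, not one defined in this codebase):
//
//   REGISTER_XLA_OP(Name("MyFloatOp").TypeConstraint("T", kFloatTypes),
//                   MyFloatOp);
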
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
class XlaOpRegistry {
 public:
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  enum class AutoclusteringPolicy {
    // Enable autoclustering if the user requests it, e.g., via
    // experimental_jit_scope. Does not autocluster if the JIT is enabled
    // globally (e.g., via the OptimizerOptions in the TF session
    // configuration).
    kIfExplicitlyRequested,
    // Enable autoclustering if explicitly requested, or if the JIT is enabled
    // globally in the session options, or via
    // TF_XLA_FLAGS=--tf_xla_auto_jit=N.
    kIfEnabledGlobally,
    // Always try to autocluster ops placed on this device.
    kAlways,
  };

  // Describes how to compile operators assigned to a device.
  struct DeviceRegistration {
    // The name of the XLA compilation device to use to compile code.
    string compilation_device_name;

    // When should we autocluster operators assigned to this device?
    AutoclusteringPolicy autoclustering_policy;

    // Whether to ignore the resource variable memory model when clustering
    // resource variable reads and writes placed on this device.
    bool cluster_resource_variable_ops_unsafely = false;

    // Whether to auto-cluster Stack operations placed on this device.
    bool cluster_stack_ops = false;

    // Whether to auto-cluster TensorArray operations placed on this device.
    bool cluster_tensor_array_ops = false;

    // Whether to auto-cluster stateful RNG operations placed on this device.
    // Stateful RNG semantics are not properly supported by XLA, so it is not
    // necessarily correct to auto-cluster stateful RNG ops in general.
    bool cluster_stateful_rng_ops = false;

    // Whether to auto-cluster ControlTrigger operations placed on this
    // device. ControlTrigger operations are not necessarily safe to cluster
    // since they affect deadness (a dead ControlTrigger produces a live
    // output).
    bool cluster_control_trigger = false;

    // Whether to cluster Assert and CheckNumerics by eliding them (XLA does
    // not natively support Assert or CheckNumerics).
    bool elide_assert_and_checknumerics = false;

    // Whether to cluster operations returning DT_VARIANT.
    bool cluster_variant_ops = false;

    // Whether ops known to be slow should be auto-clustered.
    bool cluster_slow_ops = false;

    // Whether ops known to have numerical accuracy issues should be
    // auto-clustered.
    bool cluster_inaccurate_ops = false;
  };

  // Registers an XLA backend. `compilation_device_name` is the name of the
  // device used for symbolic execution during compilation. `supported_types`
  // is the list of non-resource types supported by the device. Each operator
  // will be registered for the intersection of the operator's supported types
  // and the device's supported types. `op_filter` is a function used to
  // exclude or modify operator registrations on the device; it may be
  // nullptr, in which case all ops are included. It should return true if the
  // op should be registered on the device, and it may optionally modify the
  // KernelDef.
  typedef bool (*BackendOpFilter)(KernelDef* kdef);
  static void RegisterBackend(const string& compilation_device_name,
                              absl::Span<const DataType> supported_types,
                              BackendOpFilter op_filter);

  // Returns the names of the registered backends.
  static std::vector<string> BackendNames();

  // Returns true iff a backend with the given name is registered.
  static bool IsBackendRegistered(const string& name);

  // Registers `device_name` for XLA compilation, using information from
  // `registration`.
  // Does nothing if a registration for `device_name` already exists.
  static void RegisterCompilationDevice(const string& device_name,
                                        const DeviceRegistration& registration);
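
  // For illustration, a BackendOpFilter might drop kernel registrations that
  // use a type the backend cannot lower, or rewrite the KernelDef before it
  // is registered (a hedged sketch; "RemoveDoubleKernels" and the backend
  // name are hypothetical, not part of this codebase):
  //
  //   bool RemoveDoubleKernels(KernelDef* kdef) {
  //     for (const auto& constraint : kdef->constraint()) {
  //       for (int dt : constraint.allowed_values().list().type()) {
  //         if (dt == DT_DOUBLE) return false;  // skip this registration
  //       }
  //     }
  //     return true;  // register the (possibly modified) KernelDef
  //   }
  //   ...
  //   XlaOpRegistry::RegisterBackend("MY_XLA_JIT", kFloatTypes,
  //                                  RemoveDoubleKernels);
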
  // Returns whether the device name is for the JIT device used exclusively
  // for TF2XLA conversion.
  static bool IsCompilationDevice(const string& device_name);

  // Returns, via 'registration', the DeviceRegistration associated with
  // 'device_name', if any. Returns false and leaves '*registration'
  // unchanged if no matching device is registered.
  static bool GetCompilationDevice(const string& device_name,
                                   const DeviceRegistration** registration);

  // Registers all JIT kernels on JIT devices, if not already registered.
  // Does nothing otherwise.
  static void RegisterCompilationKernels();

  // Returns KernelDefs for compilation ops registered on
  // 'compilation_device_name'. Kernels registered as CompilationOnly are
  // included only if include_compilation_only_kernels is true.
  static std::vector<const KernelDef*> DeviceKernels(
      const string& compilation_device_name,
      bool include_compilation_only_kernels);

  // Returns all operations for which there are XLA kernels on any device.
  static std::vector<string> GetAllRegisteredOps();

  // Returns (via `result`) the indices of inputs to `node_def` that must be
  // compile-time constants. Returns an empty vector if the op is not
  // registered.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpDef& op_def,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def,
                                     result);
  }

  // Returns (via `result`) the indices of inputs to `op_kernel` that must be
  // compile-time constants.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const OpKernel& op_kernel,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel,
                                     /*op_def=*/nullptr, result);
  }

  // Returns the names of arguments of the given op that must be compile-time
  // constants.
  static const std::unordered_set<std::string>*
  CompileTimeConstantInputArgNames(const string& op);

  // Returns true if `op` is a "metadata" op, one that only looks at the
  // shapes of its operands and not their values.
  static bool IsMetadataOp(const string& op);
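
  // For illustration, a pass that needs the constant inputs of a node might
  // call the NodeDef overload above (a hedged sketch; `node_def` and `op_def`
  // are assumed to be in scope, and error handling is elided):
  //
  //   std::vector<int> constant_input_indices;
  //   TF_RETURN_IF_ERROR(XlaOpRegistry::CompileTimeConstantInputs(
  //       node_def, op_def, &constant_input_indices));
  //   // constant_input_indices now holds the sorted input positions whose
  //   // values must be known at compile time.
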
 private:
  friend class XlaBackendRegistrar;
  friend class XlaOpRegistrar;
  friend class XlaOpRegistrationBuilder;

  static XlaOpRegistry& Instance();

  XlaOpRegistry();
  ~XlaOpRegistry();

  mutex mutex_;

  // Describes an XLA backend.
  struct Backend {
    // Which types are supported by this device?
    std::set<DataType> supported_types;

    // The per-backend operator filter function. See the comment on
    // RegisterBackend() for details.
    BackendOpFilter op_filter;

    // KernelDefs built by RegisterCompilationKernels() for each op supported
    // by the device.
    std::vector<std::unique_ptr<KernelDef>> kernel_defs;
  };

  // Map from compilation device names to a description of the backend.
  std::unordered_map<string, Backend> backends_ TF_GUARDED_BY(mutex_);

  // Map from TensorFlow device names to the corresponding JIT device
  // metadata.
  std::unordered_map<string, DeviceRegistration> compilation_devices_
      TF_GUARDED_BY(mutex_);

  // A description of a TensorFlow operator that can be compiled to XLA.
  struct OpRegistration {
    string name;

    // Should this operator be registered only on compilation devices,
    // without a dummy kernel registered on the corresponding XLA device?
    bool compilation_only = false;

    // Should we allow resource types for type attributes? Used by _Arg to
    // allow DT_RESOURCE.
    bool allow_resource_types = false;

    // Should we allow variant types for type attributes? Used by While to
    // allow TensorList, which is of type DT_VARIANT.
    bool allow_variant_types = false;

    // Should we allow string type for type attributes? Used by
    // PartitionedCall to allow DT_STRING.
    bool allow_string_type = false;

    // Mapping from attribute name to a list of supported types.
    std::unordered_map<string, std::set<DataType>> type_constraints;

    // An optional allowlist of devices. If there is no allowlist, all devices
    // are permitted.
    bool has_device_allowlist = false;
    std::unordered_set<string> device_allowlist;

    // Names of arguments that must be compile-time constants.
    std::unordered_set<string> compile_time_constant_inputs;

    // True if this is a "metadata" op, one that only looks at the shapes of
    // its operands and not their values.
    bool is_metadata_op = false;

    std::string label;

    // Factory used to build OpKernels that perform symbolic execution.
    Factory factory;
  };

  // Returns true if registrations x and y can both be added to the registry.
  // This is always the case if they refer to different ops. If they refer to
  // the same op name, they must: have the same values for compilation_only,
  // allow_resource_types and allow_variant_types; use a device_allowlist;
  // and their allowlists must not intersect.
  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);

  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpKernel* op_kernel,
                                          const OpDef* op_def,
                                          std::vector<int>* result);

  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
  // Registrations present under the same key must satisfy IsCompatible above,
  // and this is checked during registration.
  std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
      TF_GUARDED_BY(mutex_);

  // Have we already registered the JIT kernels on the JIT devices?
  bool jit_kernels_registered_ = false;

  // Holds ownership of OpKernelRegistrars that represent the TensorFlow
  // kernel registrations created by RegisterCompilationKernels() and
  // RegisterDeviceKernels().
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      kernel_registrars_ TF_GUARDED_BY(mutex_);
};

// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
//   REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)
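
// For illustration, a registration can chain the builder calls declared below
// to constrain types and mark compile-time constant inputs (a hedged sketch;
// "MyReshapeOp" is a hypothetical kernel class, not one defined in this
// codebase):
//
//   REGISTER_XLA_OP(Name("MyReshape")
//                       .TypeConstraint("T", kNumericTypes)
//                       .CompileTimeConstantInput("shape"),
//                   MyReshapeOp);
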
class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(absl::string_view name);

  // Specifies an allowlist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(absl::string_view devices);
  XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);

  // Specifies a type constraint for a type variable attribute. Each
  // constraint specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           absl::Span<const DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allows DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Allows DT_VARIANT types for type parameters.
  XlaOpRegistrationBuilder& AllowVariantTypes();

  // Allows DT_STRING type for type parameters.
  XlaOpRegistrationBuilder& AllowStringType();

  // Marks 'input_name' as an argument whose value must be known at
  // compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstantInput(
      absl::string_view input_name);

  // Marks this op as a "metadata" op, one that only looks at the shapes of
  // its operands and not their values.
  XlaOpRegistrationBuilder& IsMetadataOp();

  // Specifies a particular value for the "_kernel" attr.
  XlaOpRegistrationBuilder& Label(std::string label);

  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  XlaOpRegistrationBuilder(absl::string_view name);

  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};

// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
//   REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)

// Implementation details.

class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};

#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(                   \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));

class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};

#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_