/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_

#include <array>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the TensorFlow/XLA bridge to perform symbolic execution
// of a TensorFlow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
constexpr std::array<DataType, 6> kFloatAndComplexTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16, DT_COMPLEX64, DT_COMPLEX128}};
constexpr std::array<DataType, 14> kNumericTypes = {
    {DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_INT16, DT_INT32,
     DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128,
     DT_BFLOAT16}};

constexpr std::array<DataType, 18> kCpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
     DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

constexpr std::array<DataType, 18> kGpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT16, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8,
     DT_INT16, DT_INT32, DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
     DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
class XlaOpRegistry {
 public:
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  enum class AutoclusteringPolicy {
    // Enable autoclustering if the user requests it, e.g., via
    // experimental_jit_scope. Does not autocluster if the JIT is enabled
    // globally (e.g., via the OptimizerOptions in the TF session
    // configuration).
    kIfExplicitlyRequested,
    // Enable autoclustering if explicitly requested, or if the JIT is enabled
    // globally in the session options, or via
    // TF_XLA_FLAGS=--tf_xla_auto_jit=N.
    kIfEnabledGlobally,
    // Always try to autocluster ops placed on this device.
    kAlways,
  };

  // Describes how to compile operators assigned to a device.
  struct DeviceRegistration {
    // The name of the XLA compilation device to use to compile code.
    string compilation_device_name;

    // When should we autocluster operators assigned to this device?
    AutoclusteringPolicy autoclustering_policy;

    // If we should ignore the resource variable memory model when clustering
    // resource variable reads and writes placed on this device.
    bool cluster_resource_variable_ops_unsafely = false;

    // If we should auto-cluster Stack operations placed on this device.
    bool cluster_stack_ops = false;

    // If we should auto-cluster TensorArray operations placed on this device.
    bool cluster_tensor_array_ops = false;

    // If we should auto-cluster stateful RNG operations placed on this device.
    // Stateful RNG semantics are not properly supported by XLA, so it is not
    // necessarily correct to auto-cluster stateful RNG ops in general.
    bool cluster_stateful_rng_ops = false;

    // If we should auto-cluster ControlTrigger operations placed on this
    // device. ControlTrigger operations are not necessarily safe to cluster
    // since they affect deadness (a dead ControlTrigger produces a live
    // output).
    bool cluster_control_trigger = false;

    // If we should cluster Assert and CheckNumerics by eliding them (XLA does
    // not natively support Assert or CheckNumerics).
    bool elide_assert_and_checknumerics = false;

    // If we should cluster operations returning DT_VARIANT.
    bool cluster_variant_ops = false;

    // Whether ops known to be slow should be auto-clustered.
    bool cluster_slow_ops = false;

    // Whether ops known to have numerical accuracy issues should be
    // auto-clustered.
    bool cluster_inaccurate_ops = false;
  };

  // Registers an XLA backend. `compilation_device_name` is the name of the
  // device used for symbolic execution during compilation. `supported_types`
  // is the list of non-resource types supported by the device. Each operator
  // will be registered for the intersection of the operator's supported types
  // and the device's supported types. `op_filter` is a function used to
  // exclude or modify operator registrations on the device; it may be
  // nullptr, in which case all ops are included.
  // `op_filter` should return true if the op should be registered on the
  // device; it may optionally modify the KernelDef.
  typedef bool (*BackendOpFilter)(KernelDef* kdef);
  static void RegisterBackend(const string& compilation_device_name,
                              absl::Span<const DataType> supported_types,
                              BackendOpFilter op_filter);

  // Returns the names of the registered backends.
  static std::vector<string> BackendNames();

  // Returns true iff a backend with the given name is registered.
  static bool IsBackendRegistered(const string& name);

  // Registers `device_name` for XLA compilation, using information from
  // `registration`.
  // Does nothing if a registration for `device_name` already exists.
  static void RegisterCompilationDevice(const string& device_name,
                                        const DeviceRegistration& registration);
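
  // Example (a hedged sketch, not code from this file): a backend typically
  // pairs a RegisterBackend() call for its JIT device with a
  // RegisterCompilationDevice() call mapping a TensorFlow device onto it.
  // "XLA_MYDEV" and "MYDEV_XLA_JIT" below are hypothetical device names:
  //
  //   XlaOpRegistry::RegisterBackend("MYDEV_XLA_JIT", kCpuAllTypes,
  //                                  /*op_filter=*/nullptr);
  //
  //   XlaOpRegistry::DeviceRegistration registration;
  //   registration.compilation_device_name = "MYDEV_XLA_JIT";
  //   registration.autoclustering_policy =
  //       XlaOpRegistry::AutoclusteringPolicy::kIfEnabledGlobally;
  //   XlaOpRegistry::RegisterCompilationDevice("XLA_MYDEV", registration);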
  // Returns whether the device name is for the JIT device used exclusively
  // for TF2XLA conversion.
  static bool IsCompilationDevice(const string& device_name);

  // Sets `*registration` to the registration information for `device_name`,
  // if any was registered via RegisterCompilationDevice(). Returns false and
  // leaves `*registration` unchanged if no matching registration is found.
  static bool GetCompilationDevice(const string& device_name,
                                   const DeviceRegistration** registration);

  // Registers all JIT kernels on JIT devices, if they are not already
  // registered; otherwise does nothing.
  static void RegisterCompilationKernels();

  // Returns KernelDefs for compilation ops registered on
  // 'compilation_device_name'. Excludes kernels registered as CompilationOnly
  // unless include_compilation_only_kernels is true.
  static std::vector<const KernelDef*> DeviceKernels(
      const string& compilation_device_name,
      bool include_compilation_only_kernels);

  // Returns all operations for which there are XLA kernels on any device.
  static std::vector<string> GetAllRegisteredOps();

  // Returns (via `result`) the indices of inputs to `node_def` that must be
  // compile-time constants. Returns an empty vector if the op is not
  // registered.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpDef& op_def,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def,
                                     result);
  }

  static StatusOr<std::vector<int>> CompileTimeConstantInputs(
      const NodeDef& node_def, const OpDef& op_def) {
    std::vector<int> out;
    TF_RETURN_IF_ERROR(CompileTimeConstantInputs(node_def, op_def, &out));
    return out;
  }

  // Returns (via `result`) the indices of inputs to `op_kernel` that must be
  // compile-time constants.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const OpKernel& op_kernel,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel,
                                     /*op_def=*/nullptr, result);
  }

  // Returns the names of the arguments of the given op that must be
  // compile-time constants.
  static const std::unordered_set<std::string>*
  CompileTimeConstantInputArgNames(const string& op);

  // Returns true if `op` is a "metadata" op, one that only looks at the shapes
  // of its operands and not their values.
  static bool IsMetadataOp(const string& op);

 private:
  friend class XlaBackendRegistrar;
  friend class XlaOpRegistrar;
  friend class XlaOpRegistrationBuilder;

  static XlaOpRegistry& Instance();

  XlaOpRegistry();
  ~XlaOpRegistry();

  mutex mutex_;

  // Describes an XLA backend.
  struct Backend {
    // Which types are supported by this device?
    std::set<DataType> supported_types;

    // The per-backend operator filter function. See the comment on
    // RegisterBackend() for details.
    BackendOpFilter op_filter;

    // KernelDefs built by RegisterCompilationKernels() for each op supported
    // by the device.
    std::vector<std::unique_ptr<KernelDef>> kernel_defs;
  };

  // Map from compilation device names to a description of the backend.
  std::unordered_map<string, Backend> backends_ TF_GUARDED_BY(mutex_);

  // Map from TensorFlow device names to the corresponding JIT device metadata.
  std::unordered_map<string, DeviceRegistration> compilation_devices_
      TF_GUARDED_BY(mutex_);

  // A description of a TensorFlow operator that can be compiled to XLA.
  struct OpRegistration {
    string name;

    // Should this operator be registered only on compilation devices, without
    // a dummy kernel registered on the corresponding XLA device?
    bool compilation_only = false;

    // Should we allow resource types for type attributes? Used by _Arg to
    // allow DT_RESOURCE.
    bool allow_resource_types = false;

    // Should we allow variant types for type attributes? Used by While to
    // allow TensorList, which is of type DT_VARIANT.
    bool allow_variant_types = false;

    // Should we allow string type for type attributes? Used by PartitionedCall
    // to allow DT_STRING.
    bool allow_string_type = false;

    // Mapping from attribute name to a list of supported types.
    std::unordered_map<string, std::set<DataType>> type_constraints;

    // An optional allowlist of devices. If there is no allowlist, all devices
    // are permitted.
    bool has_device_allowlist = false;
    std::unordered_set<string> device_allowlist;

    // Names of arguments that must be compile-time constants.
    std::unordered_set<string> compile_time_constant_inputs;

    // True if this is a "metadata" op, one that only looks at the shapes of
    // its operands and not their values.
    bool is_metadata_op = false;

    std::string label;

    // Factory used to build OpKernels that perform symbolic execution.
    Factory factory;
  };

  // Returns true if registrations x and y can both be added to the registry.
  // This is always the case if they refer to different ops. If they refer to
  // the same op name, they must: have the same values for compilation_only,
  // allow_resource_types, and allow_variant_types; use a device_allowlist; and
  // their allowlists must not intersect.
  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);

  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpKernel* op_kernel,
                                          const OpDef* op_def,
                                          std::vector<int>* result);

  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
  // Registrations present under the same key must satisfy IsCompatible above,
  // and this is checked during registration.
  std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
      TF_GUARDED_BY(mutex_);

  // Have we already registered the JIT kernels on the JIT devices?
  bool jit_kernels_registered_ = false;

  // Holds ownership of OpKernelRegistrars that represent the TensorFlow kernel
  // registrations created by RegisterCompilationKernels() and
  // RegisterDeviceKernels().
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      kernel_registrars_ TF_GUARDED_BY(mutex_);
};
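
// Example (a minimal sketch; `node_def` and `op_def` stand for a NodeDef/OpDef
// pair the caller already has in hand, and are not defined here): querying
// which inputs of a node must be compile-time constants might look like:
//
//   std::vector<int> constant_input_indices;
//   TF_RETURN_IF_ERROR(XlaOpRegistry::CompileTimeConstantInputs(
//       node_def, op_def, &constant_input_indices));
//   // constant_input_indices now holds the sorted input positions whose
//   // values must be known when the op is compiled.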
// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
//   REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)

class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(absl::string_view name);

  // Specifies an allowlist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(absl::string_view devices);
  XlaOpRegistrationBuilder& Device(absl::Span<const absl::string_view> devices);

  // Specifies a type constraint for a type variable attribute. Each constraint
  // specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           absl::Span<const DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Allow DT_VARIANT types for type parameters.
  XlaOpRegistrationBuilder& AllowVariantTypes();

  // Allow DT_STRING type for type parameters.
  XlaOpRegistrationBuilder& AllowStringType();

  // Mark 'input_name' as an argument whose value must be known at
  // compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstantInput(
      absl::string_view input_name);

  // Mark this op as a "metadata" op, one that only looks at the shapes of its
  // operands and not their values.
  XlaOpRegistrationBuilder& IsMetadataOp();

  // Specifies a particular value for the "_kernel" attr.
  XlaOpRegistrationBuilder& Label(std::string label);

  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  XlaOpRegistrationBuilder(absl::string_view name);

  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};

// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
//   REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)

// Implementation details.

class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};

#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(                   \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));
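
// Example (a hedged sketch; `MyReluOp` is a hypothetical XlaOpKernel subclass,
// not part of this header): a kernel implementation typically chains
// XlaOpRegistrationBuilder methods inside REGISTER_XLA_OP in its .cc file:
//
//   class MyReluOp : public XlaOpKernel {
//    public:
//     explicit MyReluOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
//     void Compile(XlaOpKernelContext* ctx) override {
//       // Emit the XLA computation for the op here.
//     }
//   };
//   REGISTER_XLA_OP(Name("Relu").TypeConstraint("T", kFloatTypes), MyReluOp);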
class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};

#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_