/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_

#include <array>
#include <functional>
#include <memory>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/common_runtime/local_device.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

// Names of the XLA compilation devices. These are not user-visible, and are
// used internally by the Tensorflow/XLA bridge to perform symbolic execution
// of a Tensorflow graph.

extern const char* const DEVICE_CPU_XLA_JIT;  // "CPU_XLA_JIT"
extern const char* const DEVICE_GPU_XLA_JIT;  // "GPU_XLA_JIT"

extern const char* const DEVICE_XLA_CPU;
extern const char* const DEVICE_XLA_GPU;

constexpr std::array<DataType, 4> kFloatTypes = {
    {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_BFLOAT16}};
constexpr std::array<DataType, 12> kNumericTypes = {
    {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF,
     DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_COMPLEX128, DT_BFLOAT16}};

constexpr std::array<DataType, 16> kCpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64,
     DT_COMPLEX128, DT_BOOL, DT_BFLOAT16}};

constexpr std::array<DataType, 15> kGpuAllTypes = {
    {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32,
     DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
     DT_BFLOAT16}};
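// These arrays are convenient `supported_types` arguments for backend
// registration (see XlaOpRegistry::RegisterBackend() and the
// REGISTER_XLA_BACKEND macro below) and for per-op type constraints, e.g.:
//
//   REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
//
// where GpuOpFilter is a BackendOpFilter that prunes or adjusts kernel
// registrations for the GPU backend.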
// Class that manages registrations of operators and devices for the XLA JIT.
// Not thread-safe.
class XlaOpRegistry {
 public:
  typedef OpKernel* (*Factory)(OpKernelConstruction*);

  enum class AutoclusteringPolicy {
    // Enable autoclustering if the user requests it, e.g., via
    // experimental_jit_scope. Does not autocluster if the JIT is enabled
    // globally (e.g., via the OptimizerOptions in the TF session
    // configuration).
    kIfExplicitlyRequested,
    // Enable autoclustering if explicitly requested, or if the JIT is enabled
    // globally in the session options, or via
    // TF_XLA_FLAGS=--tf_xla_auto_jit=N.
    kIfEnabledGlobally,
    // Always try to autocluster ops placed on this device.
    kAlways,
  };

  // Describes how to compile operators assigned to a device.
  struct DeviceRegistration {
    // The name of the XLA compilation device to use to compile code.
    string compilation_device_name;

    // When should we autocluster operators assigned to this device?
    AutoclusteringPolicy autoclustering_policy;

    // Enable compilation of operators that use DT_RESOURCE types?
    bool compile_all_resource_ops = false;
  };

  // Registers an XLA backend. `compilation_device_name` is the name of the
  // device used for symbolic execution during compilation. `supported_types`
  // is the list of non-resource types supported by the device. Each operator
  // will be registered for the intersection of the operator's supported types
  // and the device's supported types. `op_filter` is a function used to
  // exclude or modify operator registrations on the device; it may be
  // nullptr, in which case all ops are included. `op_filter` should return
  // true if the op should be registered on the device; it may optionally
  // modify the KernelDef.
  typedef bool (*BackendOpFilter)(KernelDef* kdef);
  static void RegisterBackend(const string& compilation_device_name,
                              absl::Span<const DataType> supported_types,
                              BackendOpFilter op_filter);

  // Returns the names of the registered backends.
  static std::vector<string> BackendNames();

  // Returns true iff a backend with the given name is registered.
  static bool IsBackendRegistered(const string& name);

  // Registers `device_name` for XLA compilation, using information from
  // `registration`. Does nothing if a registration for `device_name` already
  // exists.
  static void RegisterCompilationDevice(const string& device_name,
                                        const DeviceRegistration& registration);

  // Sets `*registration` to the DeviceRegistration for `device_name`, which
  // names the JIT device to compile with. Returns false and leaves
  // `*registration` unchanged if no matching registration exists.
  static bool GetCompilationDevice(const string& device_name,
                                   const DeviceRegistration** registration);

  // Registers all JIT kernels on JIT devices, if they have not already been
  // registered; otherwise does nothing.
  static void RegisterCompilationKernels();

  // Returns KernelDefs for compilation ops registered on
  // `compilation_device_name`. Kernels registered as CompilationOnly are
  // omitted unless `include_compilation_only_kernels` is true.
  static std::vector<const KernelDef*> DeviceKernels(
      const string& compilation_device_name,
      bool include_compilation_only_kernels);

  // Returns all operations for which there are XLA kernels on any device.
  static std::vector<string> GetAllRegisteredOps();
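  // For illustration only, a caller could (hypothetically) map a device onto
  // its JIT device as follows; the "CPU" key is an assumption about how
  // compilation devices are registered:
  //
  //   const XlaOpRegistry::DeviceRegistration* registration;
  //   if (XlaOpRegistry::GetCompilationDevice("CPU", &registration)) {
  //     // Ops assigned to "CPU" compile via
  //     // registration->compilation_device_name, e.g. DEVICE_CPU_XLA_JIT.
  //   }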
  // Returns (via `result`) the indices of inputs to `node_def` that must be
  // compile-time constants. Returns an empty vector if the op is not
  // registered.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpDef& op_def,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(node_def, /*op_kernel=*/nullptr, &op_def,
                                     result);
  }

  // Returns (via `result`) the indices of inputs to `op_kernel` that must be
  // compile-time constants.
  //
  // `result` is sorted.
  static Status CompileTimeConstantInputs(const OpKernel& op_kernel,
                                          std::vector<int>* result) {
    return CompileTimeConstantInputs(op_kernel.def(), /*op_kernel=*/&op_kernel,
                                     /*op_def=*/nullptr, result);
  }

  // Returns true if `op` is a "metadata" op, one that only looks at the
  // shapes of its operands and not their values.
  static bool IsMetadataOp(const string& op);
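  // For illustration only, a (hypothetical) compilation pass might collect
  // the constant-input indices of a graph node like so:
  //
  //   std::vector<int> constant_input_indices;
  //   TF_RETURN_IF_ERROR(XlaOpRegistry::CompileTimeConstantInputs(
  //       node.def(), node.op_def(), &constant_input_indices));
  //
  // Because the returned indices are sorted, they can be scanned in order
  // against the node's inputs.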
 private:
  friend class XlaBackendRegistrar;
  friend class XlaOpRegistrar;
  friend class XlaOpRegistrationBuilder;

  static XlaOpRegistry& Instance();

  XlaOpRegistry();
  ~XlaOpRegistry();

  mutex mutex_;

  // Describes an XLA backend.
  struct Backend {
    // Which types are supported by this device?
    std::set<DataType> supported_types;

    // The per-backend operator filter function. See the comment on
    // RegisterBackend() for details.
    BackendOpFilter op_filter;

    // KernelDefs built by RegisterCompilationKernels() for each op supported
    // by the device.
    std::vector<std::unique_ptr<KernelDef>> kernel_defs;
  };

  // Map from compilation device names to a description of the backend.
  std::unordered_map<string, Backend> backends_ GUARDED_BY(mutex_);

  // Map from Tensorflow device names to the corresponding JIT device
  // metadata.
  std::unordered_map<string, DeviceRegistration> compilation_devices_
      GUARDED_BY(mutex_);

  // A description of a Tensorflow operator that can be compiled to XLA.
  struct OpRegistration {
    string name;

    // Should this operator be registered only on compilation devices, without
    // a dummy kernel registered on the corresponding XLA device?
    bool compilation_only = false;

    // Should we allow resource types for type attributes? Used by _Arg to
    // allow DT_RESOURCE.
    bool allow_resource_types = false;

    // Should we allow variant types for type attributes? Used by While to
    // allow TensorList, which has type DT_VARIANT.
    bool allow_variant_types = false;

    // Mapping from attribute name to a list of supported types.
    std::unordered_map<string, std::set<DataType>> type_constraints;

    // An optional whitelist of devices. If there is no whitelist, all devices
    // are permitted.
    bool has_device_whitelist = false;
    std::unordered_set<string> device_whitelist;

    // Names of arguments that must be compile-time constants.
    std::unordered_set<string> compile_time_constant_inputs;

    // True if this is a "metadata" op, one that only looks at the shapes of
    // its operands and not their values.
    bool is_metadata_op = false;

    // Factory used to build OpKernels that perform symbolic execution.
    Factory factory;
  };

  // Returns true if registrations x and y can both be added to the registry.
  // This is always the case if they refer to different ops. If they refer to
  // the same op name, they must: have the same values for compilation_only,
  // allow_resource_types, and allow_variant_types; both use a
  // device_whitelist; and their whitelists must not intersect.
  static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);

  static Status CompileTimeConstantInputs(const NodeDef& node_def,
                                          const OpKernel* op_kernel,
                                          const OpDef* op_def,
                                          std::vector<int>* result);

  // Map from operator name to OpRegistrations, populated by REGISTER_XLA_OP.
  // Registrations present under the same key must satisfy IsCompatible above,
  // and this is checked during registration.
  std::unordered_map<string, std::vector<std::unique_ptr<OpRegistration>>> ops_
      GUARDED_BY(mutex_);

  // Have we already registered the JIT kernels on the JIT devices?
  bool jit_kernels_registered_ = false;

  // Holds ownership of OpKernelRegistrars that represent the Tensorflow
  // kernel registrations created by RegisterCompilationKernels() and
  // RegisterDeviceKernels().
  std::vector<std::unique_ptr<kernel_factory::OpKernelRegistrar>>
      kernel_registrars_ GUARDED_BY(mutex_);
};

// REGISTER_XLA_OP() registers an XLA OpKernel by name, for example:
//   REGISTER_XLA_OP(Name("Add"), AddOp);
// where 'AddOp' is the name of a JIT OpKernel class that implements "Add".
//
// We don't use a variadic macro here because we don't expect JIT operators to
// be templated.

#define REGISTER_XLA_OP(NAME, OP) \
  REGISTER_XLA_OP_UNIQ_HELPER(__COUNTER__, NAME, OP)

class XlaOpRegistrationBuilder {
 public:
  // Starts an operator registration chain.
  static XlaOpRegistrationBuilder Name(absl::string_view name);

  // Specifies a whitelist of devices on which the operator may run.
  XlaOpRegistrationBuilder& Device(absl::string_view devices);
  XlaOpRegistrationBuilder& Device(
      absl::Span<const absl::string_view> devices);

  // Specifies a type constraint for a type variable attribute. Each
  // constraint specifies the set of types that the type variable may assume.
  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           DataType allowed);

  XlaOpRegistrationBuilder& TypeConstraint(absl::string_view attr_name,
                                           absl::Span<const DataType> allowed);

  // Specifies that a dummy copy of this operator should not be registered on
  // XLA_* devices, but may be used during compilation.
  XlaOpRegistrationBuilder& CompilationOnly();

  // Allow DT_RESOURCE types for type parameters.
  XlaOpRegistrationBuilder& AllowResourceTypes();

  // Allow DT_VARIANT types for type parameters.
  XlaOpRegistrationBuilder& AllowVariantTypes();

  // Mark 'input_name' as an argument whose value must be known at
  // compile-time.
  XlaOpRegistrationBuilder& CompileTimeConstantInput(
      absl::string_view input_name);

  // Mark this op as a "metadata" op, one that only looks at the shapes of its
  // operands and not their values.
  XlaOpRegistrationBuilder& IsMetadataOp();

  std::unique_ptr<XlaOpRegistry::OpRegistration> Build(
      XlaOpRegistry::Factory factory);

 private:
  XlaOpRegistrationBuilder(absl::string_view name);

  std::unique_ptr<XlaOpRegistry::OpRegistration> registration_;
};
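// For illustration only: combining the builder methods above, a hypothetical
// operator "MyOp" whose "T" attribute is restricted to floating-point types
// and whose "axis" input must be a compile-time constant could be registered
// as:
//
//   REGISTER_XLA_OP(Name("MyOp")
//                       .TypeConstraint("T", kFloatTypes)
//                       .CompileTimeConstantInput("axis"),
//                   MyOpKernel);
//
// where MyOpKernel is a JIT OpKernel class that implements "MyOp".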
// REGISTER_XLA_BACKEND() registers an XLA backend. Example usage:
//   REGISTER_XLA_BACKEND(DEVICE_GPU_XLA_JIT, kGpuAllTypes, GpuOpFilter);
#define REGISTER_XLA_BACKEND(NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ_HELPER(__COUNTER__, NAME, __VA_ARGS__)

// Implementation details.

class XlaOpRegistrar {
 public:
  XlaOpRegistrar(std::unique_ptr<XlaOpRegistry::OpRegistration> registration);
};

#define REGISTER_XLA_OP_UNIQ_HELPER(COUNTER, BUILDER, OP) \
  REGISTER_XLA_OP_UNIQ(COUNTER, BUILDER, OP)

#define REGISTER_XLA_OP_UNIQ(CTR, BUILDER, OP)                                 \
  static ::tensorflow::XlaOpRegistrar xla_op_registrar__body__##CTR##__object( \
      ::tensorflow::XlaOpRegistrationBuilder::BUILDER.Build(                   \
          [](::tensorflow::OpKernelConstruction* context)                      \
              -> ::tensorflow::OpKernel* { return new OP(context); }));

class XlaBackendRegistrar {
 public:
  XlaBackendRegistrar(absl::string_view name, absl::Span<const DataType> types,
                      XlaOpRegistry::BackendOpFilter op_filter = nullptr);
};

#define REGISTER_XLA_BACKEND_UNIQ_HELPER(COUNTER, NAME, ...) \
  REGISTER_XLA_BACKEND_UNIQ(COUNTER, NAME, __VA_ARGS__)

#define REGISTER_XLA_BACKEND_UNIQ(CTR, NAME, ...) \
  static ::tensorflow::XlaBackendRegistrar        \
      xla_backend_registrar__body__##CTR##__object(NAME, __VA_ARGS__);

}  // namespace tensorflow

#endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_OP_REGISTRY_H_