1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ 17 #define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ 18 19 #include "tensorflow/compiler/jit/xla_compilation_cache.h" 20 #include "tensorflow/compiler/jit/xla_device.h" 21 #include "tensorflow/stream_executor/tf_allocator_adapter.h" 22 23 namespace tensorflow { 24 25 // Holds some information about the platform on which an 26 // XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of 27 // abstraction for normal and XLA devices. 28 class XlaPlatformInfo { 29 public: XlaPlatformInfo()30 XlaPlatformInfo() : device_type_("") {} 31 XlaPlatformInfo(XlaPlatformInfo&&) = default; XlaPlatformInfo(const DeviceType device_type,se::Platform::Id platform_id,const XlaDevice::Metadata * xla_device_metadata,se::DeviceMemoryAllocator * device_allocator)32 explicit XlaPlatformInfo(const DeviceType device_type, 33 se::Platform::Id platform_id, 34 const XlaDevice::Metadata* xla_device_metadata, 35 se::DeviceMemoryAllocator* device_allocator) 36 : device_type_(device_type), 37 platform_id_(platform_id), 38 xla_device_metadata_(xla_device_metadata), 39 device_allocator_(device_allocator) {} 40 41 XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default; 42 UseMultipleStreams()43 bool UseMultipleStreams() const { 44 return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams(); 45 } 46 47 // Non-null only when run on an XLA device. custom_allocator()48 se::DeviceMemoryAllocator* custom_allocator() const { 49 return device_allocator_; 50 } 51 device_type()52 DeviceType device_type() const { return device_type_; } 53 54 // This is equal to xla_device_metadata()->platform()->id() if 55 // xla_device_metadata() is not nullptr. platform_id()56 se::Platform::Id platform_id() const { return platform_id_; } 57 58 // This may be null if the op this XlaPlatformInfo is for was not placed on an 59 // XLA device. xla_device_metadata()60 const XlaDevice::Metadata* xla_device_metadata() const { 61 return xla_device_metadata_; 62 } is_on_xla_device()63 bool is_on_xla_device() const { return xla_device_metadata() != nullptr; } 64 65 private: 66 DeviceType device_type_; 67 se::Platform::Id platform_id_; 68 69 // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the 70 // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the 71 // XlaLaunch/_XlaCompile/_XlaRun OpKernel. 72 const XlaDevice::Metadata* xla_device_metadata_; 73 74 // If the op associated with this XlaPlatformInfo is placed on an XLA device 75 // then device_allocator_ is the xla::Backend's memory allocator. If the op 76 // is placed on a regular CPU or GPU device then device_allocator_ is null. 77 se::DeviceMemoryAllocator* device_allocator_; 78 79 TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo); 80 }; 81 82 // Returns created XLA compilation cache. 83 Status BuildXlaCompilationCache(DeviceBase* dev, 84 const XlaPlatformInfo& platform_info, 85 XlaCompilationCache** cache); 86 87 // Returns information about the platform from kernel context. 88 XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device); 89 90 // Returns allocator from platform info if non-null, or populate and return a 91 // pointer to the allocator adapter with allocator from context. 92 // 93 // This is necessary because for XLA devices the underlying TF allocator returns 94 // dummy tensors. 95 // 96 // `stream` parameter is nullable when running on host. 97 se::DeviceMemoryAllocator* GetAllocator( 98 absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter, 99 DeviceBase* device, se::Stream* stream, 100 const XlaPlatformInfo& platform_info); 101 102 // Returns created options for the XLA compiler, and writes the used allocator 103 // into `tf_allocator_adapter`. 104 XlaCompiler::Options GenerateCompilerOptions( 105 const XlaCompilationCache& cache, 106 const FunctionLibraryRuntime& function_library, DeviceBase* device, 107 se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars, 108 absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter); 109 110 } // namespace tensorflow 111 112 #endif // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_ 113