• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
17 #define TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
18 
19 #include "tensorflow/compiler/jit/xla_compilation_cache.h"
20 #include "tensorflow/compiler/jit/xla_device.h"
21 #include "tensorflow/stream_executor/tf_allocator_adapter.h"
22 
23 namespace tensorflow {
24 
25 // Holds some information about the platform on which an
26 // XlaLaunch/_XlaCompile/_XlaRun op must run on. Provides a common layer of
27 // abstraction for normal and XLA devices.
28 class XlaPlatformInfo {
29  public:
XlaPlatformInfo()30   XlaPlatformInfo() : device_type_("") {}
31   XlaPlatformInfo(XlaPlatformInfo&&) = default;
XlaPlatformInfo(const DeviceType device_type,se::Platform::Id platform_id,const XlaDevice::Metadata * xla_device_metadata,se::DeviceMemoryAllocator * device_allocator)32   explicit XlaPlatformInfo(const DeviceType device_type,
33                            se::Platform::Id platform_id,
34                            const XlaDevice::Metadata* xla_device_metadata,
35                            se::DeviceMemoryAllocator* device_allocator)
36       : device_type_(device_type),
37         platform_id_(platform_id),
38         xla_device_metadata_(xla_device_metadata),
39         device_allocator_(device_allocator) {}
40 
41   XlaPlatformInfo& operator=(XlaPlatformInfo&& other) = default;
42 
UseMultipleStreams()43   bool UseMultipleStreams() const {
44     return xla_device_metadata_ && xla_device_metadata_->UseMultipleStreams();
45   }
46 
47   // Non-null only when run on an XLA device.
custom_allocator()48   se::DeviceMemoryAllocator* custom_allocator() const {
49     return device_allocator_;
50   }
51 
device_type()52   DeviceType device_type() const { return device_type_; }
53 
54   // This is equal to xla_device_metadata()->platform()->id() if
55   // xla_device_metadata() is not nullptr.
platform_id()56   se::Platform::Id platform_id() const { return platform_id_; }
57 
58   // This may be null if the op this XlaPlatformInfo is for was not placed on an
59   // XLA device.
xla_device_metadata()60   const XlaDevice::Metadata* xla_device_metadata() const {
61     return xla_device_metadata_;
62   }
is_on_xla_device()63   bool is_on_xla_device() const { return xla_device_metadata() != nullptr; }
64 
65  private:
66   DeviceType device_type_;
67   se::Platform::Id platform_id_;
68 
69   // xla_device_metadata_ lives in the tensorflow::DeviceBase in which the
70   // XlaLaunch/_XlaCompile/_XlaRun op is placed and thus does not die before the
71   // XlaLaunch/_XlaCompile/_XlaRun OpKernel.
72   const XlaDevice::Metadata* xla_device_metadata_;
73 
74   // If the op associated with this XlaPlatformInfo is placed on an XLA device
75   // then device_allocator_ is the xla::Backend's memory allocator.  If the op
76   // is placed on a regular CPU or GPU device then device_allocator_ is null.
77   se::DeviceMemoryAllocator* device_allocator_;
78 
79   TF_DISALLOW_COPY_AND_ASSIGN(XlaPlatformInfo);
80 };
81 
82 // Returns created XLA compilation cache.
83 Status BuildXlaCompilationCache(DeviceBase* dev,
84                                 const XlaPlatformInfo& platform_info,
85                                 XlaCompilationCache** cache);
86 
87 // Returns information about the platform from kernel context.
88 XlaPlatformInfo XlaPlatformInfoFromDevice(DeviceBase* device);
89 
90 // Returns allocator from platform info if non-null, or populate and return a
91 // pointer to the allocator adapter with allocator from context.
92 //
93 // This is necessary because for XLA devices the underlying TF allocator returns
94 // dummy tensors.
95 //
96 // `stream` parameter is nullable when running on host.
97 se::DeviceMemoryAllocator* GetAllocator(
98     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter,
99     DeviceBase* device, se::Stream* stream,
100     const XlaPlatformInfo& platform_info);
101 
102 // Returns created options for the XLA compiler, and writes the used allocator
103 // into `tf_allocator_adapter`.
104 XlaCompiler::Options GenerateCompilerOptions(
105     const XlaCompilationCache& cache,
106     const FunctionLibraryRuntime& function_library, DeviceBase* device,
107     se::Stream* stream, const XlaPlatformInfo& platform_info, bool has_ref_vars,
108     absl::optional<se::TfAllocatorAdapter>* tf_allocator_adapter);
109 
110 }  // namespace tensorflow
111 
112 #endif  // TENSORFLOW_COMPILER_JIT_XLA_PLATFORM_INFO_H_
113