• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_
17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_
18 
19 #include <functional>
20 #include <map>
21 #include <unordered_map>
22 
23 #include "tensorflow/core/common_runtime/device/device_id.h"
24 #include "tensorflow/core/common_runtime/process_state.h"
25 #include "tensorflow/core/common_runtime/shared_counter.h"
26 #include "tensorflow/core/framework/allocator.h"
27 #include "tensorflow/core/platform/mutex.h"
28 #include "tensorflow/core/platform/thread_annotations.h"
29 #include "tensorflow/core/platform/types.h"
30 #include "tensorflow/core/protobuf/config.pb.h"
31 
32 namespace tensorflow {
33 
34 class Allocator;
35 class PluggableDeviceBFCAllocator;
36 class PluggableDeviceSimpleAllocator;
37 class PoolAllocator;
38 
39 // Singleton that manages per-process state when PluggableDevices are present.
40 class PluggableDeviceProcessState {
41  public:
42   // Singleton that manages each platform's per-process state. e.g. allocation
43   // of shared resource.
44   static PluggableDeviceProcessState* singleton(const string& device_type,
45                                                 const string& platform_name);
46 
47   // Query whether any PluggableDevice has been created so far.
48   // Disable thread safety analysis since a race is benign here.
HasPluggableDevice()49   bool HasPluggableDevice() const TF_NO_THREAD_SAFETY_ANALYSIS {
50     return pluggable_device_enabled_;
51   }
52 
53   // Set the flag to indicate a PluggableDevice has been created.
54   // Disable thread safety analysis since a race is benign here.
EnablePluggableDevice()55   void EnablePluggableDevice() TF_NO_THREAD_SAFETY_ANALYSIS {
56     pluggable_device_enabled_ = true;
57   }
58 
59   // Returns the one PluggableDevice allocator used for the indexed
60   // PluggableDevice. Note that this is a system PluggableDevice index.
61   //
62   // 'total_bytes' is the total number of bytes that should be made
63   // available to the allocator.  The first call to this function for
64   // a given tf_device_id creates the allocator, so only the
65   // total_bytes used on that first call is used.
66   //
67   // 'allocator_type' describes the type of algorithm to use for the
68   // underlying allocator.  REQUIRES: Must be a valid type (see
69   // config.proto for the list of supported strings.).
70   //
71   // REQUIRES: tf_device_id must be a valid id for a PluggableDevice
72   // available in the current system environment. Otherwise returns nullptr.
73   virtual Allocator* GetPluggableDeviceAllocator(const GPUOptions& options,
74                                                  TfDeviceId tf_device_id,
75                                                  size_t total_bytes);
76 
NumPluggableDeviceAllocators()77   int NumPluggableDeviceAllocators() {
78     mutex_lock l(mu_);
79     return pluggable_device_allocators_.size();
80   }
81 
82   virtual Allocator* GetPluggableDeviceHostAllocator(int numa_node);
83 
84   // Returns bus_id for the given PluggableDevice id.
85   virtual int BusIdForPluggableDevice(TfDeviceId tf_device_id);
86 
87  protected:
88   // PluggableDeviceProcessState is a singleton that should not normally be
89   // deleted except at process shutdown.
90   PluggableDeviceProcessState(const string& device_type,
91                               const string& platform_name);
~PluggableDeviceProcessState()92   virtual ~PluggableDeviceProcessState() {}
93 
mem_desc_map()94   ProcessState::MDMap* mem_desc_map() {
95     if (process_state_) return &process_state_->mem_desc_map_;
96     return nullptr;
97   }
98 
99   static PluggableDeviceProcessState* instance_;
100   ProcessState* process_state_;  // Not owned.
101   bool pluggable_device_enabled_;
102   const string device_type_;
103   const string platform_name_;
104   mutex mu_;
105 
106   struct AllocatorParts {
107     std::unique_ptr<Allocator> allocator;
108     Allocator* device_allocator;
109     SubAllocator* sub_allocator;  // owned by allocator
110   };
111 
112   std::vector<AllocatorParts> pluggable_device_allocators_ TF_GUARDED_BY(mu_);
113   std::vector<std::vector<SubAllocator::Visitor>> pluggable_device_visitors_
114       TF_GUARDED_BY(mu_);
115 
116   std::vector<AllocatorParts> pluggable_device_host_allocators_
117       TF_GUARDED_BY(mu_);
118   std::vector<std::vector<SubAllocator::Visitor>>
119       pluggable_device_host_alloc_visitors_ TF_GUARDED_BY(mu_);
120   std::vector<std::vector<SubAllocator::Visitor>>
121       pluggable_device_host_free_visitors_ TF_GUARDED_BY(mu_);
122 };
123 
124 }  // namespace tensorflow
125 
126 #endif  // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_
127