1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_ 17 #define TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_ 18 19 #include <functional> 20 #include <map> 21 #include <unordered_map> 22 23 #include "tensorflow/core/common_runtime/device/device_id.h" 24 #include "tensorflow/core/common_runtime/process_state.h" 25 #include "tensorflow/core/common_runtime/shared_counter.h" 26 #include "tensorflow/core/framework/allocator.h" 27 #include "tensorflow/core/platform/mutex.h" 28 #include "tensorflow/core/platform/thread_annotations.h" 29 #include "tensorflow/core/platform/types.h" 30 #include "tensorflow/core/protobuf/config.pb.h" 31 32 namespace tensorflow { 33 34 class Allocator; 35 class PluggableDeviceBFCAllocator; 36 class PluggableDeviceSimpleAllocator; 37 class PoolAllocator; 38 39 // Singleton that manages per-process state when PluggableDevices are present. 40 class PluggableDeviceProcessState { 41 public: 42 // Singleton that manages each platform's per-process state. e.g. allocation 43 // of shared resource. 44 static PluggableDeviceProcessState* singleton(const string& device_type, 45 const string& platform_name); 46 47 // Query whether any PluggableDevice has been created so far. 48 // Disable thread safety analysis since a race is benign here. HasPluggableDevice()49 bool HasPluggableDevice() const TF_NO_THREAD_SAFETY_ANALYSIS { 50 return pluggable_device_enabled_; 51 } 52 53 // Set the flag to indicate a PluggableDevice has been created. 54 // Disable thread safety analysis since a race is benign here. EnablePluggableDevice()55 void EnablePluggableDevice() TF_NO_THREAD_SAFETY_ANALYSIS { 56 pluggable_device_enabled_ = true; 57 } 58 59 // Returns the one PluggableDevice allocator used for the indexed 60 // PluggableDevice. Note that this is a system PluggableDevice index. 61 // 62 // 'total_bytes' is the total number of bytes that should be made 63 // available to the allocator. The first call to this function for 64 // a given tf_device_id creates the allocator, so only the 65 // total_bytes used on that first call is used. 66 // 67 // 'allocator_type' describes the type of algorithm to use for the 68 // underlying allocator. REQUIRES: Must be a valid type (see 69 // config.proto for the list of supported strings.). 70 // 71 // REQUIRES: tf_device_id must be a valid id for a PluggableDevice 72 // available in the current system environment. Otherwise returns nullptr. 73 virtual Allocator* GetPluggableDeviceAllocator(const GPUOptions& options, 74 TfDeviceId tf_device_id, 75 size_t total_bytes); 76 NumPluggableDeviceAllocators()77 int NumPluggableDeviceAllocators() { 78 mutex_lock l(mu_); 79 return pluggable_device_allocators_.size(); 80 } 81 82 virtual Allocator* GetPluggableDeviceHostAllocator(int numa_node); 83 84 // Returns bus_id for the given PluggableDevice id. 85 virtual int BusIdForPluggableDevice(TfDeviceId tf_device_id); 86 87 protected: 88 // PluggableDeviceProcessState is a singleton that should not normally be 89 // deleted except at process shutdown. 90 PluggableDeviceProcessState(const string& device_type, 91 const string& platform_name); ~PluggableDeviceProcessState()92 virtual ~PluggableDeviceProcessState() {} 93 mem_desc_map()94 ProcessState::MDMap* mem_desc_map() { 95 if (process_state_) return &process_state_->mem_desc_map_; 96 return nullptr; 97 } 98 99 static PluggableDeviceProcessState* instance_; 100 ProcessState* process_state_; // Not owned. 101 bool pluggable_device_enabled_; 102 const string device_type_; 103 const string platform_name_; 104 mutex mu_; 105 106 struct AllocatorParts { 107 std::unique_ptr<Allocator> allocator; 108 Allocator* device_allocator; 109 SubAllocator* sub_allocator; // owned by allocator 110 }; 111 112 std::vector<AllocatorParts> pluggable_device_allocators_ TF_GUARDED_BY(mu_); 113 std::vector<std::vector<SubAllocator::Visitor>> pluggable_device_visitors_ 114 TF_GUARDED_BY(mu_); 115 116 std::vector<AllocatorParts> pluggable_device_host_allocators_ 117 TF_GUARDED_BY(mu_); 118 std::vector<std::vector<SubAllocator::Visitor>> 119 pluggable_device_host_alloc_visitors_ TF_GUARDED_BY(mu_); 120 std::vector<std::vector<SubAllocator::Visitor>> 121 pluggable_device_host_free_visitors_ TF_GUARDED_BY(mu_); 122 }; 123 124 } // namespace tensorflow 125 126 #endif // TENSORFLOW_CORE_COMMON_RUNTIME_PLUGGABLE_DEVICE_PLUGGABLE_DEVICE_PROCESS_STATE_H_ 127