1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_ 19 20 #include <memory> 21 #include <vector> 22 #include <shared_mutex> 23 #include "utils/hash_map.h" 24 #include "utils/ms_utils.h" 25 #include "include/backend/device_address.h" 26 27 namespace mindspore { 28 namespace runtime { 29 using DeviceTensor = mindspore::device::DeviceAddress; 30 using DeviceTensorType = mindspore::device::DeviceType; 31 using DeviceTensorPtr = std::shared_ptr<DeviceTensor>; 32 33 // The device tensor mainly includes address ptr, size and reference count, 34 // which represents the basic data structure of kernel launch and transfers between actors. 35 // Some device tensors (such as weights and value nodes of graph) are fixed addresses and persistent, 36 // so they are more suitable for store and can be obtained when they are used by actor. 37 class DeviceTensorStore { 38 public: GetInstance()39 static DeviceTensorStore &GetInstance() { 40 static DeviceTensorStore instance; 41 return instance; 42 } 43 44 // Support value modifiable. Insert(AnfNode * key,const DeviceTensorPtr & value)45 void Insert(AnfNode *key, const DeviceTensorPtr &value) { 46 MS_EXCEPTION_IF_NULL(key); 47 MS_EXCEPTION_IF_NULL(value); 48 std::unique_lock<std::shared_mutex> lock(map_mutex_); 49 const auto &iter = device_tensors_.find(key); 50 if (iter == device_tensors_.end()) { 51 device_tensors_[key].emplace_back(value); 52 return; 53 } 54 55 for (size_t i = 0; i < iter->second.size(); ++i) { 56 MS_EXCEPTION_IF_NULL(iter->second[i]); 57 if (iter->second[i]->GetDeviceType() == value->GetDeviceType()) { 58 if (iter->second[i]->GetSize() != value->GetSize()) { 59 MS_LOG(INFO) << "The update size:" << value->GetSize() 60 << " is not equal of the old size:" << iter->second[i]->GetSize() 61 << " for node:" << key->fullname_with_scope() 62 << ". Please check whether it causes accuracy problem."; 63 } 64 iter->second[i] = value; 65 return; 66 } 67 } 68 iter->second.emplace_back(value); 69 } 70 Remove(AnfNode * key)71 void Remove(AnfNode *key) { 72 MS_EXCEPTION_IF_NULL(key); 73 std::unique_lock<std::shared_mutex> lock(map_mutex_); 74 const auto &iter = device_tensors_.find(key); 75 if (iter != device_tensors_.end()) { 76 (void)device_tensors_.erase(iter); 77 } 78 } 79 Fetch(AnfNode * key)80 std::vector<DeviceTensorPtr> Fetch(AnfNode *key) const { 81 MS_EXCEPTION_IF_NULL(key); 82 std::shared_lock<std::shared_mutex> lock(map_mutex_); 83 const auto &iter = device_tensors_.find(key); 84 if (iter != device_tensors_.end()) { 85 return iter->second; 86 } else { 87 std::vector<DeviceTensorPtr> empty_value; 88 return empty_value; 89 } 90 } 91 Fetch(AnfNode * key,DeviceTensorType value_type)92 DeviceTensorPtr Fetch(AnfNode *key, DeviceTensorType value_type) const { 93 MS_EXCEPTION_IF_NULL(key); 94 std::shared_lock<std::shared_mutex> lock(map_mutex_); 95 const auto &iter = device_tensors_.find(key); 96 if (iter != device_tensors_.end()) { 97 for (const auto &device_tensor : iter->second) { 98 MS_EXCEPTION_IF_NULL(device_tensor); 99 if (device_tensor->GetDeviceType() == value_type) { 100 return device_tensor; 101 } 102 } 103 } 104 return nullptr; 105 } 106 Clear()107 void Clear() { 108 std::unique_lock<std::shared_mutex> lock(map_mutex_); 109 device_tensors_.clear(); 110 } 111 112 private: 113 DeviceTensorStore() = default; 114 ~DeviceTensorStore() = default; 115 DISABLE_COPY_AND_ASSIGN(DeviceTensorStore); 116 117 // The data storage of device tensor. Key is the anf node, value is the vector which may contains the device 118 // tensors from different devices. 119 mindspore::HashMap<AnfNode *, std::vector<DeviceTensorPtr>> device_tensors_; 120 // Read/Write lock for map. 121 mutable std::shared_mutex map_mutex_; 122 }; 123 } // namespace runtime 124 } // namespace mindspore 125 126 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_ 127