1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_ 19 20 #include <memory> 21 #include <unordered_map> 22 #include <vector> 23 #include "utils/ms_utils.h" 24 #include "runtime/device/device_address.h" 25 26 namespace mindspore { 27 namespace runtime { 28 using DeviceTensor = mindspore::device::DeviceAddress; 29 using DeviceTensorType = mindspore::device::DeviceAddressType; 30 using DeviceTensorPtr = std::shared_ptr<DeviceTensor>; 31 32 // The device tensor mainly includes address ptr, size and reference count, 33 // which represents the basic data structure of kernel launch and transfers between actors. 34 // Some device tensors (such as weights and value nodes of graph) are fixed addresses and persistent, 35 // so they are more suitable for store and can be obtained when they are used by actor. 36 class DeviceTensorStore { 37 public: GetInstance()38 static DeviceTensorStore &GetInstance() { 39 static DeviceTensorStore instance; 40 return instance; 41 } 42 43 // Support value modifiable. Insert(AnfNode * key,const DeviceTensorPtr & value)44 void Insert(AnfNode *key, const DeviceTensorPtr &value) { 45 MS_EXCEPTION_IF_NULL(key); 46 const auto &iter = device_tensors_.find(key); 47 if (iter == device_tensors_.end()) { 48 device_tensors_[key].emplace_back(value); 49 return; 50 } 51 52 for (size_t i = 0; i < iter->second.size(); ++i) { 53 if (iter->second[i]->DeviceType() == value->DeviceType()) { 54 iter->second[i] = value; 55 return; 56 } 57 } 58 iter->second.emplace_back(value); 59 } 60 Remove(AnfNode * key)61 void Remove(AnfNode *key) { 62 MS_EXCEPTION_IF_NULL(key); 63 const auto &iter = device_tensors_.find(key); 64 if (iter != device_tensors_.end()) { 65 (void)device_tensors_.erase(iter); 66 } 67 } 68 Fetch(AnfNode * key)69 std::vector<DeviceTensorPtr> Fetch(AnfNode *key) const { 70 MS_EXCEPTION_IF_NULL(key); 71 const auto &iter = device_tensors_.find(key); 72 if (iter != device_tensors_.end()) { 73 return iter->second; 74 } else { 75 std::vector<DeviceTensorPtr> empty_value; 76 return empty_value; 77 } 78 } 79 Fetch(AnfNode * key,DeviceTensorType value_type)80 DeviceTensor *Fetch(AnfNode *key, DeviceTensorType value_type) const { 81 MS_EXCEPTION_IF_NULL(key); 82 const auto &iter = device_tensors_.find(key); 83 if (iter != device_tensors_.end()) { 84 for (const auto &device_tensor : iter->second) { 85 MS_EXCEPTION_IF_NULL(device_tensor); 86 if (device_tensor->DeviceType() == value_type) { 87 return device_tensor.get(); 88 } 89 } 90 } 91 return nullptr; 92 } 93 Clear()94 void Clear() { device_tensors_.clear(); } 95 96 private: 97 DeviceTensorStore() = default; 98 ~DeviceTensorStore() = default; 99 DISABLE_COPY_AND_ASSIGN(DeviceTensorStore); 100 101 // The data storage of device tensor. Key is the anf node, value is the vector which may contains the device 102 // tensors from different devices. 103 std::unordered_map<AnfNode *, std::vector<DeviceTensorPtr>> device_tensors_; 104 }; 105 } // namespace runtime 106 } // namespace mindspore 107 108 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_ 109