1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_ 19 20 #include <memory> 21 #include <set> 22 #include "utils/hash_map.h" 23 #include "utils/ms_utils.h" 24 #include "include/backend/device_address.h" 25 26 namespace mindspore { 27 namespace runtime { 28 using DeviceTensor = mindspore::device::DeviceAddress; 29 30 // The device tensor mainly includes address ptr, size and reference count, 31 // which represents the basic data structure of kernel launch and transfers between actors. 32 // Some device tensors (such as ref real parameters) need be refreshed in the running, 33 // so they are more suitable for store and can be obtained when they are refreshed copy by actor. 34 class DeviceTensorCopyStore { 35 public: GetInstance()36 static DeviceTensorCopyStore &GetInstance() { 37 static DeviceTensorCopyStore instance; 38 return instance; 39 } 40 Insert(DeviceTensor * const key,DeviceTensor * const value)41 void Insert(DeviceTensor *const key, DeviceTensor *const value) { 42 MS_EXCEPTION_IF_NULL(key); 43 MS_EXCEPTION_IF_NULL(value); 44 (void)copy_device_tensors_[key].insert(value); 45 } 46 Fetch(DeviceTensor * const key)47 std::set<DeviceTensor *> Fetch(DeviceTensor *const key) const { 48 MS_EXCEPTION_IF_NULL(key); 49 const auto &iter = copy_device_tensors_.find(key); 50 if (iter != copy_device_tensors_.end()) { 51 return iter->second; 52 } else { 53 return {}; 54 } 55 } 56 Clear()57 void Clear() { copy_device_tensors_.clear(); } 58 59 private: 60 DeviceTensorCopyStore() = default; 61 ~DeviceTensorCopyStore() = default; 62 DISABLE_COPY_AND_ASSIGN(DeviceTensorCopyStore); 63 64 // The data storage of device tensor which need be back refreshed dynamically. 65 // It is created and removed dynamically in the running. 66 // Key is the dest device tensor, value is the source device tensors which provide copy data to dest device tensor. 67 mindspore::HashMap<DeviceTensor *, std::set<DeviceTensor *>> copy_device_tensors_; 68 }; 69 } // namespace runtime 70 } // namespace mindspore 71 72 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_ 73