• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
19 
20 #include <memory>
21 #include <set>
22 #include "utils/hash_map.h"
23 #include "utils/ms_utils.h"
24 #include "include/backend/device_address.h"
25 
26 namespace mindspore {
27 namespace runtime {
28 using DeviceTensor = mindspore::device::DeviceAddress;
29 
30 // A device tensor mainly consists of an address pointer, a size and a reference count;
31 // it is the basic data structure used for kernel launch and for transfers between actors.
32 // Some device tensors (such as ref real parameters) need to be refreshed at runtime,
33 // so they are recorded in a store and can be looked up when an actor refreshes them by copy.
34 class DeviceTensorCopyStore {
35  public:
GetInstance()36   static DeviceTensorCopyStore &GetInstance() {
37     static DeviceTensorCopyStore instance;
38     return instance;
39   }
40 
Insert(DeviceTensor * const key,DeviceTensor * const value)41   void Insert(DeviceTensor *const key, DeviceTensor *const value) {
42     MS_EXCEPTION_IF_NULL(key);
43     MS_EXCEPTION_IF_NULL(value);
44     (void)copy_device_tensors_[key].insert(value);
45   }
46 
Fetch(DeviceTensor * const key)47   std::set<DeviceTensor *> Fetch(DeviceTensor *const key) const {
48     MS_EXCEPTION_IF_NULL(key);
49     const auto &iter = copy_device_tensors_.find(key);
50     if (iter != copy_device_tensors_.end()) {
51       return iter->second;
52     } else {
53       return {};
54     }
55   }
56 
Clear()57   void Clear() { copy_device_tensors_.clear(); }
58 
59  private:
60   DeviceTensorCopyStore() = default;
61   ~DeviceTensorCopyStore() = default;
62   DISABLE_COPY_AND_ASSIGN(DeviceTensorCopyStore);
63 
64   // The data storage of device tensor which need be back refreshed dynamically.
65   // It is created and removed dynamically in the running.
66   // Key is the dest device tensor, value is the source device tensors which provide copy data to dest device tensor.
67   mindspore::HashMap<DeviceTensor *, std::set<DeviceTensor *>> copy_device_tensors_;
68 };
69 }  // namespace runtime
70 }  // namespace mindspore
71 
72 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_COPY_STORE_H_
73