• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
19 
20 #include <memory>
21 #include <unordered_map>
22 #include <vector>
23 #include "utils/ms_utils.h"
24 #include "runtime/device/device_address.h"
25 
26 namespace mindspore {
27 namespace runtime {
// Runtime-side names for the device address types declared in runtime/device/device_address.h.
// DeviceTensor is the address object itself, DeviceTensorType is the value its DeviceType() is compared against,
// and DeviceTensorPtr is the shared-ownership handle to a DeviceTensor.
using DeviceTensor = mindspore::device::DeviceAddress;
using DeviceTensorType = mindspore::device::DeviceAddressType;
using DeviceTensorPtr = std::shared_ptr<DeviceTensor>;
31 
32 // The device tensor mainly includes address ptr, size and reference count,
33 // which represents the basic data structure of kernel launch and transfers between actors.
34 // Some device tensors (such as weights and value nodes of graph) are fixed addresses and persistent,
35 // so they are more suitable for store and can be obtained when they are used by actor.
36 class DeviceTensorStore {
37  public:
GetInstance()38   static DeviceTensorStore &GetInstance() {
39     static DeviceTensorStore instance;
40     return instance;
41   }
42 
43   //  Support value modifiable.
Insert(AnfNode * key,const DeviceTensorPtr & value)44   void Insert(AnfNode *key, const DeviceTensorPtr &value) {
45     MS_EXCEPTION_IF_NULL(key);
46     const auto &iter = device_tensors_.find(key);
47     if (iter == device_tensors_.end()) {
48       device_tensors_[key].emplace_back(value);
49       return;
50     }
51 
52     for (size_t i = 0; i < iter->second.size(); ++i) {
53       if (iter->second[i]->DeviceType() == value->DeviceType()) {
54         iter->second[i] = value;
55         return;
56       }
57     }
58     iter->second.emplace_back(value);
59   }
60 
Remove(AnfNode * key)61   void Remove(AnfNode *key) {
62     MS_EXCEPTION_IF_NULL(key);
63     const auto &iter = device_tensors_.find(key);
64     if (iter != device_tensors_.end()) {
65       (void)device_tensors_.erase(iter);
66     }
67   }
68 
Fetch(AnfNode * key)69   std::vector<DeviceTensorPtr> Fetch(AnfNode *key) const {
70     MS_EXCEPTION_IF_NULL(key);
71     const auto &iter = device_tensors_.find(key);
72     if (iter != device_tensors_.end()) {
73       return iter->second;
74     } else {
75       std::vector<DeviceTensorPtr> empty_value;
76       return empty_value;
77     }
78   }
79 
Fetch(AnfNode * key,DeviceTensorType value_type)80   DeviceTensor *Fetch(AnfNode *key, DeviceTensorType value_type) const {
81     MS_EXCEPTION_IF_NULL(key);
82     const auto &iter = device_tensors_.find(key);
83     if (iter != device_tensors_.end()) {
84       for (const auto &device_tensor : iter->second) {
85         MS_EXCEPTION_IF_NULL(device_tensor);
86         if (device_tensor->DeviceType() == value_type) {
87           return device_tensor.get();
88         }
89       }
90     }
91     return nullptr;
92   }
93 
Clear()94   void Clear() { device_tensors_.clear(); }
95 
96  private:
97   DeviceTensorStore() = default;
98   ~DeviceTensorStore() = default;
99   DISABLE_COPY_AND_ASSIGN(DeviceTensorStore);
100 
101   // The data storage of device tensor. Key is the anf node, value is the vector which may contains the device
102   // tensors from different devices.
103   std::unordered_map<AnfNode *, std::vector<DeviceTensorPtr>> device_tensors_;
104 };
105 }  // namespace runtime
106 }  // namespace mindspore
107 
108 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
109