• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
19 
20 #include <memory>
21 #include <vector>
22 #include <shared_mutex>
23 #include "utils/hash_map.h"
24 #include "utils/ms_utils.h"
25 #include "include/backend/device_address.h"
26 
27 namespace mindspore {
28 namespace runtime {
// Runtime-facing name for the device address type declared in mindspore::device.
29 using DeviceTensor = mindspore::device::DeviceAddress;
// Device type tag of a device tensor; DeviceTensorStore::Fetch uses it to pick one tensor among those of a node.
30 using DeviceTensorType = mindspore::device::DeviceType;
// Shared-ownership handle to a device tensor; this is the value type held by DeviceTensorStore.
31 using DeviceTensorPtr = std::shared_ptr<DeviceTensor>;
32 
33 // The device tensor mainly includes address ptr, size and reference count,
34 // which represents the basic data structure of kernel launch and transfers between actors.
35 // Some device tensors (such as weights and value nodes of graph) are fixed addresses and persistent,
36 // so they are more suitable for store and can be obtained when they are used by actor.
37 class DeviceTensorStore {
38  public:
GetInstance()39   static DeviceTensorStore &GetInstance() {
40     static DeviceTensorStore instance;
41     return instance;
42   }
43 
44   //  Support value modifiable.
Insert(AnfNode * key,const DeviceTensorPtr & value)45   void Insert(AnfNode *key, const DeviceTensorPtr &value) {
46     MS_EXCEPTION_IF_NULL(key);
47     MS_EXCEPTION_IF_NULL(value);
48     std::unique_lock<std::shared_mutex> lock(map_mutex_);
49     const auto &iter = device_tensors_.find(key);
50     if (iter == device_tensors_.end()) {
51       device_tensors_[key].emplace_back(value);
52       return;
53     }
54 
55     for (size_t i = 0; i < iter->second.size(); ++i) {
56       MS_EXCEPTION_IF_NULL(iter->second[i]);
57       if (iter->second[i]->GetDeviceType() == value->GetDeviceType()) {
58         if (iter->second[i]->GetSize() != value->GetSize()) {
59           MS_LOG(INFO) << "The update size:" << value->GetSize()
60                        << " is not equal of the old size:" << iter->second[i]->GetSize()
61                        << " for node:" << key->fullname_with_scope()
62                        << ". Please check whether it causes accuracy problem.";
63         }
64         iter->second[i] = value;
65         return;
66       }
67     }
68     iter->second.emplace_back(value);
69   }
70 
Remove(AnfNode * key)71   void Remove(AnfNode *key) {
72     MS_EXCEPTION_IF_NULL(key);
73     std::unique_lock<std::shared_mutex> lock(map_mutex_);
74     const auto &iter = device_tensors_.find(key);
75     if (iter != device_tensors_.end()) {
76       (void)device_tensors_.erase(iter);
77     }
78   }
79 
Fetch(AnfNode * key)80   std::vector<DeviceTensorPtr> Fetch(AnfNode *key) const {
81     MS_EXCEPTION_IF_NULL(key);
82     std::shared_lock<std::shared_mutex> lock(map_mutex_);
83     const auto &iter = device_tensors_.find(key);
84     if (iter != device_tensors_.end()) {
85       return iter->second;
86     } else {
87       std::vector<DeviceTensorPtr> empty_value;
88       return empty_value;
89     }
90   }
91 
Fetch(AnfNode * key,DeviceTensorType value_type)92   DeviceTensorPtr Fetch(AnfNode *key, DeviceTensorType value_type) const {
93     MS_EXCEPTION_IF_NULL(key);
94     std::shared_lock<std::shared_mutex> lock(map_mutex_);
95     const auto &iter = device_tensors_.find(key);
96     if (iter != device_tensors_.end()) {
97       for (const auto &device_tensor : iter->second) {
98         MS_EXCEPTION_IF_NULL(device_tensor);
99         if (device_tensor->GetDeviceType() == value_type) {
100           return device_tensor;
101         }
102       }
103     }
104     return nullptr;
105   }
106 
Clear()107   void Clear() {
108     std::unique_lock<std::shared_mutex> lock(map_mutex_);
109     device_tensors_.clear();
110   }
111 
112  private:
113   DeviceTensorStore() = default;
114   ~DeviceTensorStore() = default;
115   DISABLE_COPY_AND_ASSIGN(DeviceTensorStore);
116 
117   // The data storage of device tensor. Key is the anf node, value is the vector which may contains the device
118   // tensors from different devices.
119   mindspore::HashMap<AnfNode *, std::vector<DeviceTensorPtr>> device_tensors_;
120   // Read/Write lock for map.
121   mutable std::shared_mutex map_mutex_;
122 };
123 }  // namespace runtime
124 }  // namespace mindspore
125 
126 #endif  // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_DEVICE_TENSOR_STORE_H_
127