• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "plugin/device/cpu/hal/device/cpu_simple_mem_plan.h"

#include <unordered_set>

#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
19 
20 namespace mindspore {
21 namespace device {
22 namespace cpu {
MemPlan(const session::KernelGraph * graph) const23 size_t CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) const {
24   MS_EXCEPTION_IF_NULL(graph);
25   size_t total_mem_size = 32;
26   auto kernels = graph->execution_order();
27   for (const auto &kernel : kernels) {
28     MS_EXCEPTION_IF_NULL(kernel);
29     size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
30     for (size_t i = 0; i < input_num; ++i) {
31       auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, i);
32       MS_EXCEPTION_IF_NULL(kernel_with_index.first);
33       if (kernel_with_index.first->isa<Parameter>()) {
34         continue;
35       }
36       auto address = AnfAlgo::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, true);
37       MS_EXCEPTION_IF_NULL(address);
38       if (address->GetDevicePtr() == nullptr) {
39         total_mem_size += address->GetSize();
40       }
41     }
42 
43     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
44     for (size_t i = 0; i < output_num; ++i) {
45       auto address = AnfAlgo::GetOutputAddr(kernel, i);
46       MS_EXCEPTION_IF_NULL(address);
47       if (address->GetDevicePtr() == nullptr) {
48         total_mem_size += address->GetSize();
49       }
50     }
51 
52     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
53     MS_EXCEPTION_IF_NULL(kernel_mod);
54     for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
55       auto address = AnfAlgo::GetWorkspaceAddr(kernel, i);
56       MS_EXCEPTION_IF_NULL(address);
57       if (address->GetDevicePtr() == nullptr) {
58         total_mem_size += address->GetSize();
59       }
60     }
61   }
62 
63   return total_mem_size;
64 }
65 
MemAssign(const session::KernelGraph * graph,uint8_t * base_ptr) const66 void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) const {
67   MS_EXCEPTION_IF_NULL(graph);
68   MS_EXCEPTION_IF_NULL(base_ptr);
69   uint8_t *mem_ptr = base_ptr;
70   auto kernels = graph->execution_order();
71   for (const auto &kernel : kernels) {
72     MS_EXCEPTION_IF_NULL(kernel);
73     size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
74     for (size_t i = 0; i < input_num; ++i) {
75       auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, i);
76       MS_EXCEPTION_IF_NULL(kernel_with_index.first);
77       if (kernel_with_index.first->isa<Parameter>()) {
78         continue;
79       }
80       auto address = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, true);
81       MS_EXCEPTION_IF_NULL(address);
82       if (address->GetDevicePtr() == nullptr) {
83         address->SetDevicePtr(mem_ptr);
84         mem_ptr = mem_ptr + address->GetSize();
85       }
86     }
87 
88     size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
89     for (size_t i = 0; i < output_num; ++i) {
90       auto address = AnfAlgo::GetMutableOutputAddr(kernel, i);
91       MS_EXCEPTION_IF_NULL(address);
92       if (address->GetDevicePtr() == nullptr) {
93         address->SetDevicePtr(mem_ptr);
94         mem_ptr = mem_ptr + address->GetSize();
95       }
96     }
97 
98     auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
99     MS_EXCEPTION_IF_NULL(kernel_mod);
100     for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
101       auto address = AnfAlgo::GetWorkspaceAddr(kernel, i);
102       MS_EXCEPTION_IF_NULL(address);
103       if (address->GetDevicePtr() == nullptr) {
104         address->SetDevicePtr(mem_ptr);
105         mem_ptr = mem_ptr + address->GetSize();
106       }
107     }
108   }
109 }
110 }  // namespace cpu
111 }  // namespace device
112 }  // namespace mindspore
113