1 /**
2 * Copyright 2019 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
#include "plugin/device/cpu/hal/device/cpu_simple_mem_plan.h"

#include <unordered_set>

#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
19
20 namespace mindspore {
21 namespace device {
22 namespace cpu {
MemPlan(const session::KernelGraph * graph) const23 size_t CPUSimpleMemPlan::MemPlan(const session::KernelGraph *graph) const {
24 MS_EXCEPTION_IF_NULL(graph);
25 size_t total_mem_size = 32;
26 auto kernels = graph->execution_order();
27 for (const auto &kernel : kernels) {
28 MS_EXCEPTION_IF_NULL(kernel);
29 size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
30 for (size_t i = 0; i < input_num; ++i) {
31 auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, i);
32 MS_EXCEPTION_IF_NULL(kernel_with_index.first);
33 if (kernel_with_index.first->isa<Parameter>()) {
34 continue;
35 }
36 auto address = AnfAlgo::GetOutputAddr(kernel_with_index.first, kernel_with_index.second, true);
37 MS_EXCEPTION_IF_NULL(address);
38 if (address->GetDevicePtr() == nullptr) {
39 total_mem_size += address->GetSize();
40 }
41 }
42
43 size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
44 for (size_t i = 0; i < output_num; ++i) {
45 auto address = AnfAlgo::GetOutputAddr(kernel, i);
46 MS_EXCEPTION_IF_NULL(address);
47 if (address->GetDevicePtr() == nullptr) {
48 total_mem_size += address->GetSize();
49 }
50 }
51
52 auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
53 MS_EXCEPTION_IF_NULL(kernel_mod);
54 for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
55 auto address = AnfAlgo::GetWorkspaceAddr(kernel, i);
56 MS_EXCEPTION_IF_NULL(address);
57 if (address->GetDevicePtr() == nullptr) {
58 total_mem_size += address->GetSize();
59 }
60 }
61 }
62
63 return total_mem_size;
64 }
65
MemAssign(const session::KernelGraph * graph,uint8_t * base_ptr) const66 void CPUSimpleMemPlan::MemAssign(const session::KernelGraph *graph, uint8_t *base_ptr) const {
67 MS_EXCEPTION_IF_NULL(graph);
68 MS_EXCEPTION_IF_NULL(base_ptr);
69 uint8_t *mem_ptr = base_ptr;
70 auto kernels = graph->execution_order();
71 for (const auto &kernel : kernels) {
72 MS_EXCEPTION_IF_NULL(kernel);
73 size_t input_num = common::AnfAlgo::GetInputTensorNum(kernel);
74 for (size_t i = 0; i < input_num; ++i) {
75 auto kernel_with_index = common::AnfAlgo::GetPrevNodeOutput(kernel, i);
76 MS_EXCEPTION_IF_NULL(kernel_with_index.first);
77 if (kernel_with_index.first->isa<Parameter>()) {
78 continue;
79 }
80 auto address = AnfAlgo::GetMutableOutputAddr(kernel_with_index.first, kernel_with_index.second, true);
81 MS_EXCEPTION_IF_NULL(address);
82 if (address->GetDevicePtr() == nullptr) {
83 address->SetDevicePtr(mem_ptr);
84 mem_ptr = mem_ptr + address->GetSize();
85 }
86 }
87
88 size_t output_num = AnfAlgo::GetOutputTensorNum(kernel);
89 for (size_t i = 0; i < output_num; ++i) {
90 auto address = AnfAlgo::GetMutableOutputAddr(kernel, i);
91 MS_EXCEPTION_IF_NULL(address);
92 if (address->GetDevicePtr() == nullptr) {
93 address->SetDevicePtr(mem_ptr);
94 mem_ptr = mem_ptr + address->GetSize();
95 }
96 }
97
98 auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
99 MS_EXCEPTION_IF_NULL(kernel_mod);
100 for (size_t i = 0; i < kernel_mod->GetWorkspaceSizeList().size(); ++i) {
101 auto address = AnfAlgo::GetWorkspaceAddr(kernel, i);
102 MS_EXCEPTION_IF_NULL(address);
103 if (address->GetDevicePtr() == nullptr) {
104 address->SetDevicePtr(mem_ptr);
105 mem_ptr = mem_ptr + address->GetSize();
106 }
107 }
108 }
109 }
110 } // namespace cpu
111 } // namespace device
112 } // namespace mindspore
113