/**
 * Copyright 2022-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H
#define TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H

#include <memory>

#include "common/common_test.h"
#define private public
#define protected public
#include "abstract/abstract_function.h"
#include "runtime/graph_scheduler/control_node_parser.h"
#include "include/backend/optimizer/graph_optimizer.h"
#include "backend/common/pass/communication_op_fusion.h"
#include "backend/graph_compiler/backend.h"
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
#include "kernel/ops_utils.h"
#include "kernel/common_utils.h"
#include "kernel/framework_utils.h"
namespace mindspore {
namespace runtime {
namespace test {
using abstract::AbstractFuncUnion;
using abstract::AbstractTensor;
using abstract::AbstractTensorPtr;
using abstract::AnalysisContext;
using abstract::FuncGraphAbstractClosure;
using device::DeviceAddress;
using device::DeviceAddressPtr;
using device::DeviceContextKey;
using device::DeviceContextRegister;
using device::DeviceType;
using kernel::AddressPtr;
using session::KernelGraph;

// Stub device address used by these UT cases: host/device synchronization is a no-op that always
// reports success, and no real device memory is ever held.
class TestDeviceAddress : public DeviceAddress {
 public:
  TestDeviceAddress(const KernelTensorPtr &kernel_tensor) : DeviceAddress(kernel_tensor) {}
  TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
  TestDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const std::string &device_name,
                    uint32_t device_id)
      : DeviceAddress(ptr, size, format, type_id, device_name, device_id) {}
  ~TestDeviceAddress() {}
  virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const {
    return true;
  }
  virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
                                const std::string &format) const {
    return true;
  }
  virtual void *GetMutablePtr() const { return nullptr; }
  virtual void ClearDeviceMemory() {}
};

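// A minimal usage sketch (illustrative only, not part of the original header): construct the stub
// through its (ptr, size) overload, as the executor below does, and observe that the sync hooks simply
// report success. It assumes the gtest macros pulled in via "common/common_test.h" are available in
// the enclosing test body.
//   auto device_address = std::make_shared<TestDeviceAddress>(nullptr, 16);
//   EXPECT_TRUE(device_address->SyncDeviceToHost(ShapeVector{}, 16, kNumberTypeFloat32, nullptr));
//   EXPECT_EQ(device_address->GetMutablePtr(), nullptr);
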
// Stub kernel mod whose Launch always succeeds and which advertises no particular op support.
class TestKernelMod : public kernel::KernelMod {
 public:
  TestKernelMod() = default;
  ~TestKernelMod() override = default;
  virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace,
                      const std::vector<AddressPtr> &outputs, void *stream_ptr) {
    return true;
  }
  std::vector<kernel::KernelAttr> GetOpSupport() override { return {}; }
};

// Stub device resource manager: memory allocation requests succeed without allocating anything, and
// device addresses are materialized as TestDeviceAddress instances.
class TestDeviceResManager : public device::DeviceResManager {
 public:
  TestDeviceResManager() = default;
  ~TestDeviceResManager() override = default;

  virtual bool AllocateMemory(DeviceAddress *const &address, uint32_t stream_id = UINT32_MAX) const { return true; }
  virtual void FreeMemory(DeviceAddress *const &address) const {}
  virtual void *AllocateMemory(size_t size, const uint32_t stream_id = UINT32_MAX) const { return nullptr; }
  virtual void FreeMemory(void *const ptr) const {}
  virtual void FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs,
                               const std::vector<size_t> &keep_addr_sizes) const {}
  virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format,
                                               TypeId type_id, const ShapeVector &shape,
                                               const UserDataPtr &user_data = nullptr) const {
    return std::make_shared<TestDeviceAddress>(device_ptr, device_size, format, type_id, "CPU", 0);
  }

  DeviceAddressPtr CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const {
    MS_EXCEPTION_IF_NULL(kernel_tensor);
    if (kernel_tensor->device_name().empty()) {
      kernel_tensor->set_device_name(device_context_->device_context_key().device_name_);
      kernel_tensor->set_device_id(device_context_->device_context_key().device_id_);
    }
    return std::make_shared<TestDeviceAddress>(kernel_tensor);
  }
};

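// A minimal usage sketch (illustrative only, not taken from the UT cases): the pointer-based
// CreateDeviceAddress overload can be exercised on a default-constructed manager, since it does not
// touch the device context. The size, format, type, and shape values below are arbitrary placeholders.
//   TestDeviceResManager res_manager;
//   auto address = res_manager.CreateDeviceAddress(nullptr, 16, kOpFormat_DEFAULT, kNumberTypeFloat32,
//                                                  ShapeVector{4});
//   EXPECT_NE(address, nullptr);
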
// Stub kernel executor for UT: OptimizeGraph runs the AllReduce fusion pass over the kernel graph, and
// CreateKernel attaches default kernel build info, dummy TestDeviceAddress buffers, and a TestKernelMod
// to every node so graphs can be scheduled without a real backend.
class TestKernelExecutor : public device::KernelExecutor {
 public:
  TestKernelExecutor() = default;
  ~TestKernelExecutor() override = default;

  virtual void OptimizeGraph(const FuncGraphPtr &graph) const {
    MS_EXCEPTION_IF_NULL(graph);
    auto kernel_graph = graph->cast<KernelGraphPtr>();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    auto &nodes = kernel_graph->execution_order();
    for (const auto node : nodes) {
      MS_EXCEPTION_IF_NULL(node);
      SetKernelInfo(node);
    }
    auto optimizer = std::make_shared<opt::GraphOptimizer>();
    auto pm = std::make_shared<opt::PassManager>();
    pm->AddPass(std::make_shared<opt::AllReduceFusion>());
    optimizer->AddPassManager(pm);
    (void)optimizer->Optimize(kernel_graph);
    kernel_graph->SetExecOrderByDefault();
  }

  virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {
    for (const auto node : nodes) {
      MS_EXCEPTION_IF_NULL(node);
      SetKernelInfo(node);

      std::vector<size_t> input_size_list;
      std::vector<size_t> output_size_list;
      size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
      for (size_t input_index = 0; input_index < input_num; ++input_index) {
        auto [input_node, index] = common::AnfAlgo::GetPrevNodeOutput(node, input_index, true);
        size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index);
        (void)input_size_list.emplace_back(tensor_size);
        if (AnfAlgo::OutputAddrExist(input_node, index)) {
          continue;
        }
        AnfAlgo::SetOutputAddr(std::make_shared<TestDeviceAddress>(nullptr, tensor_size), index, input_node.get());
      }
      size_t output_num = AnfAlgo::GetOutputTensorNum(node);
      for (size_t output_index = 0; output_index < output_num; ++output_index) {
        size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(node, output_index);
        (void)output_size_list.emplace_back(tensor_size);
        AnfAlgo::SetOutputAddr(std::make_shared<TestDeviceAddress>(nullptr, tensor_size), output_index, node.get());
      }

      const size_t kDefaultWorkSpaceSize = 4;
      auto kernel_mod_ptr = std::make_shared<TestKernelMod>();
      kernel_mod_ptr->SetInputSizeList(input_size_list);
      kernel_mod_ptr->SetOutputSizeList(output_size_list);
      kernel_mod_ptr->SetWorkspaceSizeList({kDefaultWorkSpaceSize});
      AnfAlgo::SetKernelMod(kernel_mod_ptr, node.get());
      AnfAlgo::SetWorkspaceAddr(std::make_shared<TestDeviceAddress>(nullptr, kDefaultWorkSpaceSize), 0, node.get());
    }
  }

 private:
  void SetKernelInfo(const CNodePtr &node) const {
    MS_EXCEPTION_IF_NULL(node);
    if (node->kernel_info() == nullptr) {
      auto kernel_info = std::make_shared<device::KernelInfo>();
      node->set_kernel_info(kernel_info);
    }

    const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
    MS_EXCEPTION_IF_NULL(kernel_info);
    if (kernel_info->select_kernel_build_info() != nullptr) {
      return;
    }

    std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
    std::vector<std::string> inputs_format;
    std::vector<TypeId> inputs_type;
    size_t input_num = common::AnfAlgo::GetInputTensorNum(node);
    for (size_t input_index = 0; input_index < input_num; ++input_index) {
      (void)inputs_format.emplace_back(kOpFormat_DEFAULT);
      (void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index));
    }

    std::vector<std::string> outputs_format;
    std::vector<TypeId> outputs_type;
    size_t output_num = AnfAlgo::GetOutputElementNum(node);
    for (size_t output_index = 0; output_index < output_num; ++output_index) {
      (void)outputs_format.emplace_back(kOpFormat_DEFAULT);
      (void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(node, output_index));
    }

    builder->SetOriginDataFormat(kOpFormat_DEFAULT);
    builder->SetInputsFormat(inputs_format);
    builder->SetInputsDeviceType(inputs_type);
    builder->SetOutputsFormat(outputs_format);
    builder->SetOutputsDeviceType(outputs_type);
    kernel_info->set_select_kernel_build_info(builder->Build());
  }
};

// Test device context that wires the stub kernel executor and resource manager together; it reports
// itself as a CPU device and runs graphs in kernel-by-kernel mode.
class TestDeviceContext : public device::DeviceInterface<TestKernelExecutor, TestDeviceResManager> {
 public:
  explicit TestDeviceContext(const DeviceContextKey &device_context_key) : DeviceInterface(device_context_key) {}
  ~TestDeviceContext() override = default;

  virtual void Initialize() {}
  virtual DeviceType GetDeviceType() const { return DeviceType::kCPU; }
  device::RunMode GetRunMode(const FuncGraphPtr &func_graph) const override { return device::RunMode::kKernelMode; }
};
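
// A minimal usage sketch (illustrative only; the UT cases that include this header may register the
// context differently). DeviceContextKey is aggregate-initialized with a device name and id here, which
// is an assumption about how a test would key the context.
//   DeviceContextKey device_context_key{"CPU", 0};
//   TestDeviceContext device_context(device_context_key);
//   device_context.Initialize();
//   EXPECT_EQ(device_context.GetDeviceType(), DeviceType::kCPU);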
}  // namespace test
}  // namespace runtime
}  // namespace mindspore
#endif  // TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H