1 /** 2 * Copyright 2022-2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H 18 #define TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H 19 20 #include <memory> 21 22 #include "common/common_test.h" 23 #define private public 24 #define protected public 25 #include "abstract/abstract_function.h" 26 #include "runtime/graph_scheduler/control_node_parser.h" 27 #include "include/backend/optimizer/graph_optimizer.h" 28 #include "backend/common/pass/communication_op_fusion.h" 29 #include "backend/graph_compiler/backend.h" 30 #include "runtime/hardware/device_context.h" 31 #include "runtime/hardware/device_context_manager.h" 32 #include "kernel/ops_utils.h" 33 #include "kernel/common_utils.h" 34 #include "kernel/framework_utils.h" 35 #define private public 36 #define protected public 37 38 namespace mindspore { 39 namespace runtime { 40 namespace test { 41 using abstract::AbstractFuncUnion; 42 using abstract::AbstractTensor; 43 using abstract::AbstractTensorPtr; 44 using abstract::AnalysisContext; 45 using abstract::FuncGraphAbstractClosure; 46 using device::DeviceAddress; 47 using device::DeviceAddressPtr; 48 using device::DeviceContextKey; 49 using device::DeviceContextRegister; 50 using device::DeviceType; 51 using kernel::AddressPtr; 52 using session::KernelGraph; 53 54 class TestDeviceAddress : public DeviceAddress { 55 public: TestDeviceAddress(const KernelTensorPtr & kernel_tensor)56 TestDeviceAddress(const KernelTensorPtr &kernel_tensor) : DeviceAddress(kernel_tensor) {} TestDeviceAddress(void * ptr,size_t size)57 TestDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} TestDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const std::string & device_name,uint32_t device_id)58 TestDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, const std::string &device_name, 59 uint32_t device_id) 60 : DeviceAddress(ptr, size, format, type_id, device_name, device_id) {} ~TestDeviceAddress()61 ~TestDeviceAddress() {} SyncDeviceToHost(const ShapeVector & shape,size_t size,TypeId type,void * host_ptr)62 virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const { 63 return true; 64 } SyncHostToDevice(const ShapeVector & shape,size_t size,TypeId type,const void * host_ptr,const std::string & format)65 virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr, 66 const std::string &format) const { 67 return true; 68 } GetMutablePtr()69 virtual void *GetMutablePtr() const { return nullptr; } ClearDeviceMemory()70 virtual void ClearDeviceMemory() {} 71 }; 72 73 class TestKernelMod : public kernel::KernelMod { 74 public: 75 TestKernelMod() = default; 76 ~TestKernelMod() override = default; Launch(const std::vector<AddressPtr> & inputs,const std::vector<AddressPtr> & workspace,const std::vector<AddressPtr> & outputs,void * stream_ptr)77 virtual bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &workspace, 78 const std::vector<AddressPtr> &outputs, void *stream_ptr) { 79 return true; 80 } GetOpSupport()81 std::vector<kernel::KernelAttr> GetOpSupport() override { return {}; } 82 }; 83 84 class TestDeviceResManager : public device::DeviceResManager { 85 public: 86 TestDeviceResManager() = default; 87 ~TestDeviceResManager() override = default; 88 89 virtual bool AllocateMemory(DeviceAddress *const &address, uint32_t stream_id = UINT32_MAX) const { return true; } FreeMemory(DeviceAddress * const & address)90 virtual void FreeMemory(DeviceAddress *const &address) const {} 91 virtual void *AllocateMemory(size_t size, const uint32_t stream_id = UINT32_MAX) const { return nullptr; } FreeMemory(void * const ptr)92 virtual void FreeMemory(void *const ptr) const {} FreePartMemorys(const std::vector<void * > & free_addrs,const std::vector<void * > & keep_addrs,const std::vector<size_t> & keep_addr_sizes)93 virtual void FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs, 94 const std::vector<size_t> &keep_addr_sizes) const {} 95 virtual DeviceAddressPtr CreateDeviceAddress(void *const device_ptr, size_t device_size, const string &format, 96 TypeId type_id, const ShapeVector &shape, 97 const UserDataPtr &user_data = nullptr) const { 98 return std::make_shared<TestDeviceAddress>(device_ptr, device_size, format, type_id, "CPU", 0); 99 } 100 CreateDeviceAddress(const KernelTensorPtr & kernel_tensor)101 DeviceAddressPtr CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const { 102 MS_EXCEPTION_IF_NULL(kernel_tensor); 103 if (kernel_tensor->device_name().empty()) { 104 kernel_tensor->set_device_name(device_context_->device_context_key().device_name_); 105 kernel_tensor->set_device_id(device_context_->device_context_key().device_id_); 106 } 107 return std::make_shared<TestDeviceAddress>(kernel_tensor); 108 } 109 }; 110 111 class TestKernelExecutor : public device::KernelExecutor { 112 public: 113 TestKernelExecutor() = default; 114 ~TestKernelExecutor() override = default; 115 OptimizeGraph(const FuncGraphPtr & graph)116 virtual void OptimizeGraph(const FuncGraphPtr &graph) const { 117 MS_EXCEPTION_IF_NULL(graph); 118 auto kernel_graph = graph->cast<KernelGraphPtr>(); 119 MS_EXCEPTION_IF_NULL(kernel_graph); 120 auto &nodes = kernel_graph->execution_order(); 121 for (const auto node : nodes) { 122 MS_EXCEPTION_IF_NULL(node); 123 SetKernelInfo(node); 124 } 125 auto optimizer = std::make_shared<opt::GraphOptimizer>(); 126 auto pm = std::make_shared<opt::PassManager>(); 127 pm->AddPass(std::make_shared<opt::AllReduceFusion>()); 128 optimizer->AddPassManager(pm); 129 (void)optimizer->Optimize(kernel_graph); 130 kernel_graph->SetExecOrderByDefault(); 131 } 132 CreateKernel(const std::vector<CNodePtr> & nodes)133 virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const { 134 for (const auto node : nodes) { 135 MS_EXCEPTION_IF_NULL(node); 136 SetKernelInfo(node); 137 138 std::vector<size_t> input_size_list; 139 std::vector<size_t> output_size_list; 140 size_t input_num = common::AnfAlgo::GetInputTensorNum(node); 141 for (size_t input_index = 0; input_index < input_num; ++input_index) { 142 auto [input_node, index] = common::AnfAlgo::GetPrevNodeOutput(node, input_index, true); 143 size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(input_node, index); 144 (void)input_size_list.emplace_back(tensor_size); 145 if (AnfAlgo::OutputAddrExist(input_node, index)) { 146 continue; 147 } 148 AnfAlgo::SetOutputAddr(std::make_shared<TestDeviceAddress>(nullptr, tensor_size), index, input_node.get()); 149 } 150 size_t output_num = AnfAlgo::GetOutputTensorNum(node); 151 for (size_t output_index = 0; output_index < output_num; ++output_index) { 152 size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(node, output_index); 153 (void)output_size_list.emplace_back(tensor_size); 154 AnfAlgo::SetOutputAddr(std::make_shared<TestDeviceAddress>(nullptr, tensor_size), output_index, node.get()); 155 } 156 157 const size_t kDefaultWorkSpaceSize = 4; 158 auto kernel_mod_ptr = std::make_shared<TestKernelMod>(); 159 kernel_mod_ptr->SetInputSizeList(input_size_list); 160 kernel_mod_ptr->SetOutputSizeList(output_size_list); 161 kernel_mod_ptr->SetWorkspaceSizeList({kDefaultWorkSpaceSize}); 162 AnfAlgo::SetKernelMod(kernel_mod_ptr, node.get()); 163 AnfAlgo::SetWorkspaceAddr(std::make_shared<TestDeviceAddress>(nullptr, kDefaultWorkSpaceSize), 0, node.get()); 164 } 165 } 166 167 private: SetKernelInfo(const CNodePtr & node)168 void SetKernelInfo(const CNodePtr &node) const { 169 MS_EXCEPTION_IF_NULL(node); 170 if (node->kernel_info() == nullptr) { 171 auto kernel_info = std::make_shared<device::KernelInfo>(); 172 node->set_kernel_info(kernel_info); 173 } 174 175 const auto &kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info()); 176 MS_EXCEPTION_IF_NULL(kernel_info); 177 if (kernel_info->select_kernel_build_info() != nullptr) { 178 return; 179 } 180 181 std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>(); 182 std::vector<std::string> inputs_format; 183 std::vector<TypeId> inputs_type; 184 size_t input_num = common::AnfAlgo::GetInputTensorNum(node); 185 for (size_t input_index = 0; input_index < input_num; ++input_index) { 186 (void)inputs_format.emplace_back(kOpFormat_DEFAULT); 187 (void)inputs_type.emplace_back(common::AnfAlgo::GetPrevNodeOutputInferDataType(node, input_index)); 188 } 189 190 std::vector<std::string> outputs_format; 191 std::vector<TypeId> outputs_type; 192 size_t output_num = AnfAlgo::GetOutputElementNum(node); 193 for (size_t output_index = 0; output_index < output_num; ++output_index) { 194 (void)outputs_format.emplace_back(kOpFormat_DEFAULT); 195 (void)outputs_type.emplace_back(common::AnfAlgo::GetOutputInferDataType(node, output_index)); 196 } 197 198 builder->SetOriginDataFormat(kOpFormat_DEFAULT); 199 builder->SetInputsFormat(inputs_format); 200 builder->SetInputsDeviceType(inputs_type); 201 builder->SetOutputsFormat(outputs_format); 202 builder->SetOutputsDeviceType(outputs_type); 203 kernel_info->set_select_kernel_build_info(builder->Build()); 204 } 205 }; 206 207 class TestDeviceContext : public device::DeviceInterface<TestKernelExecutor, TestDeviceResManager> { 208 public: TestDeviceContext(const DeviceContextKey & device_context_key)209 explicit TestDeviceContext(const DeviceContextKey &device_context_key) : DeviceInterface(device_context_key) {} 210 ~TestDeviceContext() override = default; 211 Initialize()212 virtual void Initialize() {} GetDeviceType()213 virtual DeviceType GetDeviceType() const { return DeviceType::kCPU; } GetRunMode(const FuncGraphPtr & func_graph)214 device::RunMode GetRunMode(const FuncGraphPtr &func_graph) const override { return device::RunMode::kKernelMode; } 215 }; 216 } // namespace test 217 } // namespace runtime 218 } // namespace mindspore 219 #endif // TESTS_UT_CPP_COMMON_DEVICE_COMMON_TEST_H 220