1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_HAL_DEVICE_LAUNCH_TRANSDATA_H_ 18 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_HAL_DEVICE_LAUNCH_TRANSDATA_H_ 19 #include <utility> 20 #include <string> 21 #include <vector> 22 #include <memory> 23 #include "kernel/kernel.h" 24 #include "include/backend/kernel_graph.h" 25 26 namespace mindspore::device::ascend { 27 class LaunchTransData { 28 public: LaunchTransData(uint32_t stream_id,TypeId dtype,size_t total_size,std::string src_format,std::string dst_format,ShapeVector host_shape,int64_t groups)29 LaunchTransData(uint32_t stream_id, TypeId dtype, size_t total_size, std::string src_format, std::string dst_format, 30 ShapeVector host_shape, int64_t groups) 31 : stream_id_(stream_id), 32 dtype_(dtype), 33 total_size_(total_size), 34 src_format_(std::move(src_format)), 35 dst_format_(std::move(dst_format)), 36 shape_(std::move(host_shape)), 37 groups_(groups) {} 38 39 ~LaunchTransData() = default; 40 void LaunchOpKernel(); 41 std::vector<uint8_t *> GetKernelOutputAddr(); 42 void SetInputAddr(void *input_addr); 43 void FreeDeviceMem(); 44 45 private: 46 void AclKernelBuild(); 47 void ConstructKernelGraph(); 48 void SetKernelBuildInfo(); 49 uint8_t *AllocDeviceMem(size_t size); 50 void CreateOutputAddr(const std::vector<size_t> &outputs_list, std::vector<kernel::KernelTensorPtr> *kernel_tensors); 51 uint32_t stream_id_; 52 TypeId dtype_; 53 size_t total_size_; 54 std::string src_format_; 55 std::string dst_format_; 56 ShapeVector shape_; 57 int64_t groups_; 58 kernel::KernelModPtr kernel_mod_{nullptr}; 59 std::vector<uint8_t *> outputs_addr_; 60 void *input_addr_{nullptr}; 61 KernelGraphPtr kernel_graph_; 62 }; 63 64 } // namespace mindspore::device::ascend 65 #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_HAL_DEVICE_LAUNCH_TRANSDATA_H_ 66