1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SUBGRAPH_OPENCL_KERNEL_H_ 18 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SUBGRAPH_OPENCL_KERNEL_H_ 19 20 #include <memory> 21 #include <set> 22 #include <vector> 23 #include "src/runtime/kernel/opencl/opencl_kernel.h" 24 #include "src/runtime/gpu/opencl/opencl_allocator.h" 25 #include "src/runtime/gpu/opencl/opencl_executor.h" 26 #include "src/sub_graph_kernel.h" 27 28 namespace mindspore::kernel { 29 class OpenCLSubGraph : public SubGraphKernel { 30 public: OpenCLSubGraph(const std::vector<kernel::LiteKernel * > & inKernels,const std::vector<kernel::LiteKernel * > & outKernels,const std::vector<kernel::LiteKernel * > & nodes,Kernel * kernel)31 OpenCLSubGraph(const std::vector<kernel::LiteKernel *> &inKernels, 32 const std::vector<kernel::LiteKernel *> &outKernels, const std::vector<kernel::LiteKernel *> &nodes, 33 Kernel *kernel) 34 : SubGraphKernel(inKernels, outKernels, nodes, kernel) { 35 ocl_runtime_ = ocl_runtime_wrap_.GetInstance(); 36 if (nodes.front()->desc().data_type == kNumberTypeFloat16) { 37 subgraph_type_ = kGpuFp16SubGraph; 38 } else { 39 subgraph_type_ = kGpuFp32SubGraph; 40 } 41 desc_.arch = kernel::KERNEL_ARCH::kGPU; 42 static std::atomic_int index = 0; 43 this->set_name("GpuSubGraph" + std::to_string(index++)); 44 nodes_set_.insert(nodes.begin(), nodes.end()); 45 all_kernels_infer_done_ = std::all_of(nodes_.begin(), nodes_.end(), [](const kernel::LiteKernel *kernel) { 46 return kernel && kernel->InferShapeDone(); 47 }); 48 } 49 ~OpenCLSubGraph() override; 50 51 int Prepare() override; 52 int Init() override; 53 int ReSize() override; 54 int ReSize(bool interrupt); 55 int Execute() override; 56 int Execute(const KernelCallBack &before, const KernelCallBack &after) override; 57 58 private: 59 void UnInit(); 60 int UpdateTensorDataTypePass(); 61 void ReplaceOutTensorAndKernelToConvert(const lite::Tensor *in_tensor, 62 const std::vector<kernel::LiteKernel *> &in_kernels, lite::Tensor *new_tensor, 63 kernel::LiteKernel *in_convert_op, lite::opencl::MemType mem_type); 64 void GetInOutNodes(); 65 int GenToFormatOp(const std::vector<lite::Tensor *> &in_tensors, 66 const std::vector<std::vector<kernel::LiteKernel *>> &in_kernels, 67 std::vector<lite::Tensor *> *out_tensors, std::vector<OpenCLToFormatParameter *> *out_parameters, 68 std::vector<LiteKernel *> *out_convert_ops, lite::opencl::MemType mem_type); 69 void GetKernelFromToTensor(const std::vector<lite::Tensor *> &in_tensors, 70 const std::vector<kernel::LiteKernel *> &in_kernels, 71 std::vector<std::vector<kernel::LiteKernel *>> *out_kernels, bool is_from); 72 int FusionPass(); 73 74 int InsertOpsPass(); 75 76 public: 77 using PassFunc = int (OpenCLSubGraph::*)(void); 78 79 private: 80 std::shared_ptr<lite::opencl::OpenCLAllocator> allocator_{nullptr}; 81 std::vector<lite::Tensor *> in_convert_tensors_; 82 std::vector<lite::Tensor *> out_convert_tensors_; 83 std::vector<OpenCLToFormatParameter *> in_parameters_; 84 std::vector<OpenCLToFormatParameter *> out_parameters_; 85 std::vector<LiteKernel *> in_convert_ops_; 86 std::vector<LiteKernel *> out_convert_ops_; 87 std::set<LiteKernel *> nodes_set_; 88 lite::opencl::OpenCLRuntimeInnerWrapper ocl_runtime_wrap_; 89 lite::opencl::OpenCLRuntime *ocl_runtime_{nullptr}; 90 bool all_kernels_infer_done_ = false; 91 }; 92 } // namespace mindspore::kernel 93 94 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_OPENCL_KERNEL_SUBGRAPH_OPENCL_KERNEL_H_ 95