1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_FORMAT_PASS_H_ 18 #define MINDSPORE_LITE_SRC_EXTENDRT_FORMAT_PASS_H_ 19 20 #include <vector> 21 #include <memory> 22 #include <utility> 23 #include <string> 24 #include "src/executor/kernel_exec.h" 25 #include "src/executor/sub_graph_kernel.h" 26 #include "src/litert/pass/format_pass/pass_utils.h" 27 28 namespace mindspore::lite::pass { 29 class FormatPass { 30 public: FormatPass(mindspore::Format format,std::string name,CreateFormatTransposeFunc create_format_transpose_func)31 explicit FormatPass(mindspore::Format format, std::string name, 32 CreateFormatTransposeFunc create_format_transpose_func) 33 : format_(format), 34 name_(std::move(name)), 35 create_format_transpose_func_(std::move(create_format_transpose_func)) {} 36 virtual ~FormatPass() = default; 37 virtual int RunPass(kernel::SubGraphKernel *graph, std::vector<lite::Tensor *> *tensors) = 0; 38 name()39 std::string name() const { return name_; } 40 41 protected: 42 Format format_ = DEFAULT_FORMAT; 43 std::string name_{}; 44 CreateFormatTransposeFunc create_format_transpose_func_ = nullptr; 45 }; 46 using FormatPassPtr = std::shared_ptr<FormatPass>; 47 48 class FormatOptimize { 49 public: 50 int AddPass(const FormatPassPtr &pass); 51 int RunPass(kernel::SubGraphKernel *graph, std::vector<Tensor *> *tensors); 52 53 private: 54 std::vector<FormatPassPtr> pass_list_; 55 }; 56 using FormatOptimizePtr = std::shared_ptr<FormatOptimize>; 57 58 int DoFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list, 59 std::vector<mindspore::lite::Tensor *> *tensors, mindspore::Format graph_format, 60 const CreateFormatTransposeFunc &create_format_transpose_func); 61 62 int RuntimeFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list, 63 std::vector<mindspore::lite::Tensor *> *tensors, 64 mindspore::Format format = mindspore::Format::NHWC, 65 const CreateFormatTransposeFunc &create_format_transpose_func = nullptr); 66 } // namespace mindspore::lite::pass 67 #endif // MINDSPORE_LITE_SRC_EXTENDRT_FORMAT_PASS_H_ 68