• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_EXTENDRT_FORMAT_PASS_H_
18 #define MINDSPORE_LITE_SRC_EXTENDRT_FORMAT_PASS_H_
19 
20 #include <vector>
21 #include <memory>
22 #include <utility>
23 #include <string>
24 #include "src/executor/kernel_exec.h"
25 #include "src/executor/sub_graph_kernel.h"
26 #include "src/litert/pass/format_pass/pass_utils.h"
27 
28 namespace mindspore::lite::pass {
29 class FormatPass {
30  public:
FormatPass(mindspore::Format format,std::string name,CreateFormatTransposeFunc create_format_transpose_func)31   explicit FormatPass(mindspore::Format format, std::string name,
32                       CreateFormatTransposeFunc create_format_transpose_func)
33       : format_(format),
34         name_(std::move(name)),
35         create_format_transpose_func_(std::move(create_format_transpose_func)) {}
36   virtual ~FormatPass() = default;
37   virtual int RunPass(kernel::SubGraphKernel *graph, std::vector<lite::Tensor *> *tensors) = 0;
38 
name()39   std::string name() const { return name_; }
40 
41  protected:
42   Format format_ = DEFAULT_FORMAT;
43   std::string name_{};
44   CreateFormatTransposeFunc create_format_transpose_func_ = nullptr;
45 };
46 using FormatPassPtr = std::shared_ptr<FormatPass>;
47 
48 class FormatOptimize {
49  public:
50   int AddPass(const FormatPassPtr &pass);
51   int RunPass(kernel::SubGraphKernel *graph, std::vector<Tensor *> *tensors);
52 
53  private:
54   std::vector<FormatPassPtr> pass_list_;
55 };
56 using FormatOptimizePtr = std::shared_ptr<FormatOptimize>;
57 
58 int DoFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list,
59                  std::vector<mindspore::lite::Tensor *> *tensors, mindspore::Format graph_format,
60                  const CreateFormatTransposeFunc &create_format_transpose_func);
61 
62 int RuntimeFormatPass(std::vector<mindspore::kernel::KernelExec *> *subgraph_list,
63                       std::vector<mindspore::lite::Tensor *> *tensors,
64                       mindspore::Format format = mindspore::Format::NHWC,
65                       const CreateFormatTransposeFunc &create_format_transpose_func = nullptr);
66 }  // namespace mindspore::lite::pass
67 #endif  // MINDSPORE_LITE_SRC_EXTENDRT_FORMAT_PASS_H_
68