• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SET_INFERSHAPE_FUNCTOR_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SET_INFERSHAPE_FUNCTOR_H_
18 #include <string>
19 #include <vector>
20 
21 #include "ir/func_graph.h"
22 #include "include/backend/visible.h"
23 #include "include/backend/optimizer/pass.h"
24 #include "backend/common/graph_kernel/symbol_engine/jit/cpp_visitor.h"
25 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h"
26 
27 namespace mindspore::graphkernel {
28 using opt::dynamic_shape::InferShapeFunctor;
29 using opt::dynamic_shape::kAttrInferShapeFunctor;
30 
31 class BACKEND_EXPORT SymbolEngineInfer : public InferShapeFunctor {
32  public:
SymbolEngineInfer(const std::string & name,const SymbolEnginePtr & engine,const AbstractBasePtr & out_abstract)33   SymbolEngineInfer(const std::string &name, const SymbolEnginePtr &engine, const AbstractBasePtr &out_abstract)
34       : InferShapeFunctor(name), engine_(engine), out_abstract_(out_abstract) {}
35   ~SymbolEngineInfer() override = default;
36   MS_DECLARE_PARENT(SymbolEngineInfer, InferShapeFunctor)
37   BaseShapePtr InferShape(const AbstractBasePtrList &args) override;
38 
39  protected:
40   SymbolEnginePtr engine_;
41   AbstractBasePtr out_abstract_;
42 };
43 
44 class SymbolEngineJitInfer : public InferShapeFunctor {
45  public:
SymbolEngineJitInfer(const std::string & name,const std::string & func_name,const symshape::CppVisitorPtr & cpp_visitor,const ListSymbolPtr & output_symbol)46   SymbolEngineJitInfer(const std::string &name, const std::string &func_name,
47                        const symshape::CppVisitorPtr &cpp_visitor, const ListSymbolPtr &output_symbol)
48       : InferShapeFunctor(name), func_name_(func_name), cpp_visitor_(cpp_visitor), output_symbol_(output_symbol) {
49     Init();
50   }
51   MS_DECLARE_PARENT(SymbolEngineJitInfer, InferShapeFunctor)
52   BaseShapePtr InferShape(const AbstractBasePtrList &args) override;
53 
54  protected:
55   void Init();
56 
57  private:
58   std::string func_name_;
59   symshape::CppVisitorPtr cpp_visitor_;
60   ListSymbolPtr output_symbol_;
61   symshape::CppVisitor::DynFuncType infer_func_ = nullptr;
62   std::vector<int64_t *> output_parm_;
63   ShapeArray out_shapes_;
64 };
65 
66 class SetInferShapeFunctor : public opt::Pass {
67  public:
Pass(pass_name)68   explicit SetInferShapeFunctor(const std::string &pass_name = "set_infershape_funtor") : Pass(pass_name) {}
69   ~SetInferShapeFunctor() override = default;
70   bool Run(const FuncGraphPtr &func_graph) override;
71 };
72 }  // namespace mindspore::graphkernel
#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SET_INFERSHAPE_FUNCTOR_H_
74