1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SET_INFERSHAPE_FUNCTOR_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SET_INFERSHAPE_FUNCTOR_H_ 18 #include <string> 19 #include <vector> 20 21 #include "ir/func_graph.h" 22 #include "include/backend/visible.h" 23 #include "include/backend/optimizer/pass.h" 24 #include "backend/common/graph_kernel/symbol_engine/jit/cpp_visitor.h" 25 #include "backend/common/optimizer/dynamic_shape/dynamic_shape_helper.h" 26 27 namespace mindspore::graphkernel { 28 using opt::dynamic_shape::InferShapeFunctor; 29 using opt::dynamic_shape::kAttrInferShapeFunctor; 30 31 class BACKEND_EXPORT SymbolEngineInfer : public InferShapeFunctor { 32 public: SymbolEngineInfer(const std::string & name,const SymbolEnginePtr & engine,const AbstractBasePtr & out_abstract)33 SymbolEngineInfer(const std::string &name, const SymbolEnginePtr &engine, const AbstractBasePtr &out_abstract) 34 : InferShapeFunctor(name), engine_(engine), out_abstract_(out_abstract) {} 35 ~SymbolEngineInfer() override = default; 36 MS_DECLARE_PARENT(SymbolEngineInfer, InferShapeFunctor) 37 BaseShapePtr InferShape(const AbstractBasePtrList &args) override; 38 39 protected: 40 SymbolEnginePtr engine_; 41 AbstractBasePtr out_abstract_; 42 }; 43 44 class SymbolEngineJitInfer : public InferShapeFunctor { 45 public: SymbolEngineJitInfer(const std::string & name,const std::string & func_name,const symshape::CppVisitorPtr & cpp_visitor,const ListSymbolPtr & output_symbol)46 SymbolEngineJitInfer(const std::string &name, const std::string &func_name, 47 const symshape::CppVisitorPtr &cpp_visitor, const ListSymbolPtr &output_symbol) 48 : InferShapeFunctor(name), func_name_(func_name), cpp_visitor_(cpp_visitor), output_symbol_(output_symbol) { 49 Init(); 50 } 51 MS_DECLARE_PARENT(SymbolEngineJitInfer, InferShapeFunctor) 52 BaseShapePtr InferShape(const AbstractBasePtrList &args) override; 53 54 protected: 55 void Init(); 56 57 private: 58 std::string func_name_; 59 symshape::CppVisitorPtr cpp_visitor_; 60 ListSymbolPtr output_symbol_; 61 symshape::CppVisitor::DynFuncType infer_func_ = nullptr; 62 std::vector<int64_t *> output_parm_; 63 ShapeArray out_shapes_; 64 }; 65 66 class SetInferShapeFunctor : public opt::Pass { 67 public: Pass(pass_name)68 explicit SetInferShapeFunctor(const std::string &pass_name = "set_infershape_funtor") : Pass(pass_name) {} 69 ~SetInferShapeFunctor() override = default; 70 bool Run(const FuncGraphPtr &func_graph) override; 71 }; 72 } // namespace mindspore::graphkernel 73 #endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_SET_INFER_FUNCTOR_H_ 74