• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_
17 #define TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_
18 
#include <memory>
#include <string>
#include <utility>

#include "ir/anf.h"
#include "ir/primitive.h"
#include "ir/manager.h"
#include "ir/func_graph.h"
#include "pipeline/jit/parse/parse_base.h"
#include "pipeline/jit/parse/parse.h"
#include "pipeline/jit/parse/resolve.h"
28 
29 namespace UT {
30 
31 void InitPythonPath();
32 
33 class PyFuncGraphFetcher {
34  public:
35   explicit PyFuncGraphFetcher(std::string model_path, bool doResolve = false)
model_path_(model_path)36       : model_path_(model_path), doResolve_(doResolve) {
37     InitPythonPath();
38   }
39   void SetDoResolve(bool doResolve = true) { doResolve_ = doResolve; }
40 
41   // The return of python function of "func_name" should be py::function.
42   // step 1. Call the function user input
43   // step 2. Parse the return "fn"
44   template <class... T>
CallAndParseRet(std::string func_name,T...args)45   mindspore::FuncGraphPtr CallAndParseRet(std::string func_name, T... args) {
46     try {
47       py::function fn = mindspore::parse::python_adapter::CallPyFn(model_path_.c_str(), func_name.c_str(), args...);
48       mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
49       if (doResolve_) {
50         std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
51         mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
52         mindspore::parse::ResolveAll(manager);
53       }
54       return func_graph;
55     } catch (py::error_already_set& e) {
56       MS_LOG(ERROR) << "Call and parse fn failed!!! error:" << e.what();
57       return nullptr;
58     } catch (...) {
59       MS_LOG(ERROR) << "Call fn failed!!!";
60       return nullptr;
61     }
62   }
63 
64   // Fetch python function then parse to graph
operator()65   mindspore::FuncGraphPtr operator()(std::string func_name, std::string model_path = "") {
66     try {
67       std::string path = model_path_;
68       if ("" != model_path) {
69         path = model_path;
70       }
71       py::function fn = mindspore::parse::python_adapter::GetPyFn(path.c_str(), func_name.c_str());
72       mindspore::FuncGraphPtr func_graph = mindspore::parse::ParsePythonCode(fn);
73       if (doResolve_) {
74         std::shared_ptr<mindspore::FuncGraphManager> manager = mindspore::Manage(func_graph, false);
75         mindspore::parse::python_adapter::set_use_signature_in_resolve(false);
76         mindspore::parse::ResolveAll(manager);
77       }
78       return func_graph;
79     } catch (py::error_already_set& e) {
80       MS_LOG(ERROR) << "get fn failed!!! error:" << e.what();
81       return nullptr;
82     } catch (...) {
83       MS_LOG(ERROR) << "get fn failed!!!";
84       return nullptr;
85     }
86   }
87 
88  private:
89   std::string model_path_;
90   bool doResolve_;
91 };
92 
93 }  // namespace UT
94 #endif  // TESTS_UT_COMMON_PY_FUNC_GRAPH_FETCHER_H_
95