• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_
19 
20 #include <set>
21 #include <string>
22 #include <vector>
23 #include <map>
24 #include <memory>
25 
26 #include "transform/graph_ir/types.h"
27 #include "transform/graph_ir/util.h"
28 #include "ir/tensor.h"
29 #include "transform/graph_ir/df_graph_manager.h"
30 
31 namespace mindspore {
32 namespace transform {
// Session configuration expressed as string key/value pairs (kept ordered by key, as std::map does).
using SessionOptions = std::map<std::string, std::string, std::less<std::string>>;
34 
35 struct GraphRunnerOptions {
36   std::string target{"default_graph_runner"};
37   SessionOptions options;
38   // if sess_ptr is nullptr, GraphRunner will create a new ge session
39   std::shared_ptr<ge::Session> sess_ptr{nullptr};
40 };
41 
/// Per-invocation options passed to GraphRunner::RunGraph.
struct RunOptions {
  /// Name of the graph to run.
  std::string name{};
};
46 
47 class GraphRunner {
48  public:
49   explicit GraphRunner(const GraphRunnerOptions &options);
~GraphRunner()50   ~GraphRunner() { sess_ = nullptr; }
51   Status RunGraph(const RunOptions &options, const std::vector<MeTensorPtr> &inputs, std::vector<MeTensorPtr> *outputs);
52   Status RunGraph(const RunOptions &options, const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs);
53   static std::shared_ptr<ge::Session> NewSession(const SessionOptions &sess_options);
54 
55  private:
56   std::shared_ptr<ge::Session> sess_;
57   transform::GraphRunnerOptions options_;
58   DfGraphManager &graph_manager_;
59 };
60 }  // namespace transform
61 }  // namespace mindspore
62 
63 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_
64