• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_
18 #define MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_
19 
20 #include <vector>
21 #include <memory>
22 #include <string>
23 
24 #include "transform/graph_ir/transform_util.h"
25 #include "transform/graph_ir/df_graph_manager.h"
26 #include "ir/tensor.h"
27 
28 namespace mindspore {
29 namespace transform {
30 class GraphRunner {
31  public:
32   explicit GraphRunner(const GraphRunnerOptions &options);
~GraphRunner()33   ~GraphRunner() { sess_ = nullptr; }
34   Status AddGraph(const std::string &name);
35   Status RunGraph(const RunOptions &options, const std::vector<MeTensorPtr> &inputs, std::vector<MeTensorPtr> *outputs);
36   Status RunGraph(const RunOptions &options, const std::vector<GeTensorPtr> &inputs, std::vector<GeTensorPtr> *outputs);
37   Status RunGraphAsync(const RunOptions &options, const std::vector<GeTensorPtr> &inputs,
38                        std::vector<GeTensorPtr> *outputs);
39   Status RunGraphWithStreamAsync(const RunOptions &options, void *stream, const std::vector<GeTensor> &inputs,
40                                  std::vector<GeTensor> *outputs);
41   Status CompileGraph(const RunOptions &options);
42   Status CompileGraph(const RunOptions &options, ::ge::CompiledGraphSummaryPtr *graph_summary);
43   Status SetConstMemory(const RunOptions &options, const void *const memory, size_t size);
44   Status UpdateFeatureMemory(const RunOptions &options, const void *const memory, size_t size);
45   static std::shared_ptr<::ge::Session> NewSession(const SessionOptions &sess_options);
46   Status RegisterExternalAllocator(const void *const stream, GeAllocatorPtr allocator);
47   Status UnregisterExternalAllocator(const void *const stream);
IsAllocatorRegistered()48   const bool IsAllocatorRegistered() const { return is_allocator_registered; }
49 
50  private:
51   Status GetWrapper(const std::string &name, DfGraphWrapperPtr *wrapper) const;
52 
53   std::shared_ptr<::ge::Session> sess_;
54   transform::GraphRunnerOptions options_;
55   DfGraphManager &graph_manager_;
56   bool is_allocator_registered = false;
57 };
58 }  // namespace transform
59 }  // namespace mindspore
60 
61 #endif  // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_GRAPH_RUNNER_H_
62