• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_LITE_EXTENDRT_SESSION_LITE_GRAPH_EXECUTOR_H_
17 #define MINDSPORE_LITE_EXTENDRT_SESSION_LITE_GRAPH_EXECUTOR_H_
18 
19 #include <vector>
20 #include <string>
21 #include <memory>
22 #include <map>
23 
24 #include "include/api/types.h"
25 #include "runtime/hardware/device_context.h"
26 
27 namespace mindspore {
28 /// \brief Adaptive Graph Executor for cloud Graph Executor to solve interface conflicts.
29 class LiteGraphExecutor : public device::GraphExecutor {
30  public:
31   LiteGraphExecutor() = default;
32   virtual ~LiteGraphExecutor() = default;
33 
CompileGraph(const FuncGraphPtr & graph,const std::map<string,string> & compile_options,uint32_t * graph_id)34   virtual bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options,
35                             uint32_t *graph_id) {
36     return false;
37   }
38 
CompileGraph(const void * model_data,size_t data_size,const std::map<string,string> & compile_options,uint32_t * graph_id)39   virtual bool CompileGraph(const void *model_data, size_t data_size, const std::map<string, string> &compile_options,
40                             uint32_t *graph_id) {
41     return false;
42   }
43 
UpdateWeights(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> & weights)44   virtual bool UpdateWeights(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> &weights) {
45     MS_LOG(ERROR) << "UpdateWeights failed.";
46     return false;
47   }
48 
RunGraph(uint32_t graph_id,const std::vector<tensor::Tensor> & inputs,std::vector<tensor::Tensor> * outputs,const std::map<string,string> & compile_options)49   virtual bool RunGraph(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
50                         std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) {
51     (void)graph_id;
52     (void)inputs;
53     (void)outputs;
54     (void)compile_options;
55     return false;
56   }
57 
Resize(uint32_t graph_id,const std::vector<tensor::Tensor> & inputs,const std::vector<std::vector<int64_t>> & new_shapes)58   virtual bool Resize(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
59                       const std::vector<std::vector<int64_t>> &new_shapes) {
60     (void)graph_id;
61     (void)inputs;
62     (void)new_shapes;
63     return true;
64   }
GetInputInfos(uint32_t graph_id)65   virtual std::vector<tensor::Tensor> GetInputInfos(uint32_t graph_id) {
66     (void)graph_id;
67     return {};
68   }
GetOutputInfos(uint32_t graph_id)69   virtual std::vector<tensor::Tensor> GetOutputInfos(uint32_t graph_id) {
70     (void)graph_id;
71     return {};
72   }
73 
SetBefore(const MSKernelCallBack & before)74   void SetBefore(const MSKernelCallBack &before) { before_ = before; }
75 
SetAfter(const MSKernelCallBack & after)76   void SetAfter(const MSKernelCallBack &after) { after_ = after; }
77 
78  protected:
79   MSKernelCallBack before_;
80   MSKernelCallBack after_;
81 };
82 }  // namespace mindspore
83 
84 #endif  // MINDSPORE_LITE_EXTENDRT_SESSION_LITE_GRAPH_EXECUTOR_H_
85