/**
 * Copyright 2019-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifndef MINDSPORE_LITE_EXTENDRT_SESSION_LITE_GRAPH_EXECUTOR_H_
#define MINDSPORE_LITE_EXTENDRT_SESSION_LITE_GRAPH_EXECUTOR_H_

#include <vector>
#include <string>
#include <memory>
#include <map>

#include "include/api/types.h"
#include "runtime/hardware/device_context.h"

namespace mindspore {
/// \brief Adaptive Graph Executor for cloud Graph Executor to solve interface conflicts.
class LiteGraphExecutor : public device::GraphExecutor {
 public:
  LiteGraphExecutor() = default;
  virtual ~LiteGraphExecutor() = default;

  /// Compile a FuncGraph into an executable graph; the base implementation is a stub that reports failure.
  virtual bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options,
                            uint32_t *graph_id) {
    return false;
  }

  /// Compile a serialized model buffer; the base implementation is a stub that reports failure.
  virtual bool CompileGraph(const void *model_data, size_t data_size, const std::map<string, string> &compile_options,
                            uint32_t *graph_id) {
    return false;
  }

  /// Update the weights of a compiled graph; the base implementation logs an error and reports failure.
  virtual bool UpdateWeights(const std::vector<std::vector<std::shared_ptr<tensor::Tensor>>> &weights) {
    MS_LOG(ERROR) << "UpdateWeights failed.";
    return false;
  }

  /// Run the compiled graph identified by graph_id; the base implementation is a stub that reports failure.
  virtual bool RunGraph(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
                        std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) {
    (void)graph_id;
    (void)inputs;
    (void)outputs;
    (void)compile_options;
    return false;
  }

  /// Resize the graph inputs to new_shapes; the base implementation treats the request as a successful no-op.
  virtual bool Resize(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
                      const std::vector<std::vector<int64_t>> &new_shapes) {
    (void)graph_id;
    (void)inputs;
    (void)new_shapes;
    return true;
  }

  /// Query input tensor info of a compiled graph; the base implementation returns an empty list.
  virtual std::vector<tensor::Tensor> GetInputInfos(uint32_t graph_id) {
    (void)graph_id;
    return {};
  }

  /// Query output tensor info of a compiled graph; the base implementation returns an empty list.
  virtual std::vector<tensor::Tensor> GetOutputInfos(uint32_t graph_id) {
    (void)graph_id;
    return {};
  }

  /// Store a callback that derived executors can invoke before each kernel runs.
  void SetBefore(const MSKernelCallBack &before) { before_ = before; }

  /// Store a callback that derived executors can invoke after each kernel runs.
  void SetAfter(const MSKernelCallBack &after) { after_ = after; }

 protected:
  MSKernelCallBack before_;
  MSKernelCallBack after_;
};
}  // namespace mindspore

#endif  // MINDSPORE_LITE_EXTENDRT_SESSION_LITE_GRAPH_EXECUTOR_H_
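
// ---------------------------------------------------------------------------
// Usage sketch (illustrative comment only, not part of the header above):
// a concrete backend derives from LiteGraphExecutor and overrides the hooks it
// actually supports. DemoGraphExecutor, its graph_ member, and the
// single-graph-id convention are hypothetical assumptions for illustration,
// not an API defined by this file.
//
//   namespace mindspore {
//   class DemoGraphExecutor : public LiteGraphExecutor {
//    public:
//     bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options,
//                       uint32_t *graph_id) override {
//       if (graph == nullptr || graph_id == nullptr) {
//         return false;
//       }
//       *graph_id = 0;   // single-graph backend: always graph 0
//       graph_ = graph;  // keep the compiled graph for RunGraph
//       return true;
//     }
//
//     bool RunGraph(uint32_t graph_id, const std::vector<tensor::Tensor> &inputs,
//                   std::vector<tensor::Tensor> *outputs,
//                   const std::map<string, string> &compile_options) override {
//       if (graph_id != 0 || graph_ == nullptr || outputs == nullptr) {
//         return false;
//       }
//       *outputs = inputs;  // placeholder: a real backend executes its kernels here
//       return true;
//     }
//
//    private:
//     FuncGraphPtr graph_;
//   };
//   }  // namespace mindspore
// ---------------------------------------------------------------------------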