1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "coder/generator/component/const_blocks/msession.h" 18 19 namespace mindspore::lite::micro { 20 const char session_header[] = R"RAW( 21 /** 22 * Copyright 2021 Huawei Technologies Co., Ltd 23 * 24 * Licensed under the Apache License, Version 2.0 (the "License"); 25 * you may not use this file except in compliance with the License. 26 * You may obtain a copy of the License at 27 * 28 * http://www.apache.org/licenses/LICENSE-2.0 29 * 30 * Unless required by applicable law or agreed to in writing, software 31 * distributed under the License is distributed on an "AS IS" BASIS, 32 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 33 * See the License for the specific language governing permissions and 34 * limitations under the License. 35 */ 36 37 #ifndef MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_SESSION_H_ 38 #define MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_SESSION_H_ 39 40 #include "include/errorcode.h" 41 #include "include/lite_session.h" 42 43 #include "tensor.h" 44 45 namespace mindspore { 46 namespace lite { 47 48 #define MS_ERROR_IF_NULL(ptr) \ 49 do { \ 50 if ((ptr) == nullptr) { \ 51 return mindspore::lite::RET_ERROR; \ 52 } \ 53 } while (0) 54 55 #define MS_NULLPTR_IF_NULL(ptr) \ 56 do { \ 57 if ((ptr) == nullptr) { \ 58 return nullptr; \ 59 } \ 60 } while (0) 61 62 #define MS_NULLPTR_IF_ERROR(ptr) \ 63 do { \ 64 if ((ptr) != mindspore::lite::RET_OK) { \ 65 return nullptr; \ 66 } \ 67 } while (0) 68 69 class LiteSession : public session::LiteSession { 70 public: 71 LiteSession() = default; 72 73 ~LiteSession() override; 74 75 void BindThread(bool if_bind) override {} 76 77 int CompileGraph(lite::Model *model) override; 78 79 Vector<tensor::MSTensor *> GetInputs() const override; 80 81 mindspore::tensor::MSTensor *GetInputsByTensorName(const String &tensor_name) const override { return nullptr; } 82 83 int RunGraph(const KernelCallBack &before = nullptr, const KernelCallBack &after = nullptr) override; 84 85 Vector<tensor::MSTensor *> GetOutputsByNodeName(const String &node_name) const override; 86 87 Vector<String> GetOutputTensorNames() const override; 88 89 mindspore::tensor::MSTensor *GetOutputByTensorName(const String &tensor_name) const override; 90 91 int Resize(const Vector<tensor::MSTensor *> &inputs, const Vector<Vector<int>> &dims) override { return RET_ERROR; } 92 93 int InitRuntimeBuffer(); 94 95 private: 96 Vector<MTensor *> inputs_; 97 Vector<MTensor *> outputs_; 98 void *runtime_buffer_; 99 }; 100 101 } // namespace lite 102 } // namespace mindspore 103 104 #endif // MINDSPORE_LITE_MICRO_LIBRARY_SOURCE_SESSION_H_ 105 106 )RAW"; 107 108 const char session_source[] = R"RAW( 109 int LiteSession::RunGraph(const KernelCallBack &before, const KernelCallBack &after) { 110 const void *inputs_data[inputs_.size()]; 111 for (size_t i = 0; i < inputs_.size(); ++i) { 112 inputs_data[i] = inputs_[i]->MutableData(); 113 } 114 SetInputs(inputs_data, inputs_.size()); 115 116 Inference(); 117 118 void *outputs_data[outputs_.size()]; 119 for (size_t i = 0; i < outputs_.size(); ++i) { 120 outputs_data[i] = outputs_[i]->MutableData(); 121 } 122 CopyOutputsData(outputs_data, outputs_.size()); 123 124 return RET_OK; 125 } 126 127 int LiteSession::InitRuntimeBuffer() { 128 int buffer_size = GetBufferSize(); 129 runtime_buffer_ = malloc(buffer_size); 130 if (runtime_buffer_ == nullptr) { 131 return RET_ERROR; 132 } 133 int ret = SetBuffer(runtime_buffer_); 134 if (ret != RET_OK) { 135 return RET_ERROR; 136 } 137 return RET_OK; 138 } 139 140 Vector<tensor::MSTensor *> LiteSession::GetInputs() const { 141 Vector<tensor::MSTensor *> inputs; 142 for (const auto &input : inputs_) { 143 inputs.push_back(input); 144 } 145 return inputs; 146 } 147 148 Vector<tensor::MSTensor *> LiteSession::GetOutputsByNodeName(const String &node_name) const { 149 Vector<tensor::MSTensor *> outputs; 150 return outputs; 151 } 152 153 Vector<String> LiteSession::GetOutputTensorNames() const { 154 Vector<String> output_names; 155 for (const auto &output : outputs_) { 156 output_names.push_back(output->tensor_name()); 157 } 158 return output_names; 159 } 160 161 mindspore::tensor::MSTensor *LiteSession::GetOutputByTensorName(const String &tensor_name) const { 162 for (const auto &output : outputs_) { 163 if (output->tensor_name() == tensor_name) { 164 return output; 165 } 166 } 167 return nullptr; 168 } 169 } // namespace lite 170 171 )RAW"; 172 } // namespace mindspore::lite::micro 173