1 /** 2 * Copyright 2019-2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_ 19 #include <map> 20 #include <memory> 21 #include <string> 22 #include <vector> 23 #include <tuple> 24 #include "runtime/base.h" 25 #include "runtime/rt_model.h" 26 #include "runtime/device/ascend/ge_runtime/davinci_model.h" 27 28 namespace mindspore::ge::model_runner { 29 using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>; 30 class Task; 31 class RuntimeModel { 32 public: 33 RuntimeModel() = default; 34 ~RuntimeModel(); 35 36 void Load(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model); 37 void DistributeTask(); 38 void LoadComplete(); 39 const std::vector<uint32_t> &GetTaskIdList() const; 40 const std::vector<uint32_t> &GetStreamIdList() const; GetRuntimeInfoMap()41 const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; } GetModelHandle()42 rtModel_t GetModelHandle() const { return rt_model_handle_; } GetModelStream()43 rtStream_t GetModelStream() const { return rt_model_stream_; } 44 void Run(); 45 46 private: 47 void InitResource(const std::shared_ptr<DavinciModel> &davinci_model); 48 void GenerateTask(uint32_t device_id, uint64_t session_id, const std::shared_ptr<DavinciModel> &davinci_model); 49 void InitStream(const std::shared_ptr<DavinciModel> &davinci_model); 50 void InitEvent(uint32_t event_num); 51 void InitLabel(const std::shared_ptr<DavinciModel> &davinci_model); 52 void RtModelUnbindStream() noexcept; 53 void RtStreamDestory() noexcept; 54 void RtModelDestory() noexcept; 55 void RtLabelDestory() noexcept; 56 void RtEventDestory() noexcept; 57 58 rtModel_t rt_model_handle_{}; 59 rtStream_t rt_model_stream_{}; 60 61 std::vector<rtStream_t> stream_list_{}; 62 std::vector<rtLabel_t> label_list_{}; 63 std::vector<rtEvent_t> event_list_{}; 64 65 std::vector<std::shared_ptr<Task>> task_list_{}; 66 67 std::vector<uint32_t> task_id_list_{}; 68 std::vector<uint32_t> stream_id_list_{}; 69 std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_; 70 }; 71 } // namespace mindspore::ge::model_runner 72 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_GE_RUNTIME_RUNTIME_MODEL_H_ 73