1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_MODEL_STATE_H_
18 #define MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_MODEL_STATE_H_
19 
20 #include <memory>
21 #include <string>
22 #include <vector>
23 #include <future>
24 #include "src/mslite_utils.h"
25 #include "triton/backend/backend_model.h"
26 #include "triton/backend/backend_model_instance.h"
27 #include "triton/core/tritonserver.h"
28 #include "include/api/context.h"
29 #include "include/api/model.h"
30 #include "triton/backend/backend_input_collector.h"
31 #include "triton/backend/backend_output_responder.h"
32 
33 namespace triton {
34 namespace backend {
35 namespace mslite {
// ModelState
//
// State associated with a model that is using this backend. An object
// of this class is created and associated with each
// TRITONBACKEND_Model.
//
42 class ModelState : public BackendModel {
43  public:
44   static TRITONSERVER_Error *Create(TRITONBACKEND_Model *triton_model, ModelState **state);
45   virtual ~ModelState() = default;
46 
47   TRITONSERVER_Error *ParseModelParameterConfig();
48   TRITONSERVER_Error *InitMSContext();
49   // Init mslite model
50   std::shared_ptr<mindspore::Model> BuildMSModel();
51 
PreferredBatchSize()52   std::vector<int64_t> PreferredBatchSize() { return preferred_batch_size_; }
SupportDynamicBatch()53   bool SupportDynamicBatch() { return support_dynamic_batch_; }
54 
55  private:
ModelState(TRITONBACKEND_Model * triton_model)56   explicit ModelState(TRITONBACKEND_Model *triton_model) : BackendModel(triton_model) {}
57 
58   std::shared_ptr<mindspore::Context> ms_context_ = nullptr;
59   bool support_dynamic_batch_ = false;
60   std::vector<int64_t> preferred_batch_size_;
61   std::string model_type_ = "mindir";
62   std::string device_type_ = "";
63   int device_id_ = 0;
64 };
65 
//
// ModelInstanceState
//
// State associated with a model instance. An object of this class is
// created and associated with each TRITONBACKEND_ModelInstance.
//
72 class ModelInstanceState : public BackendModelInstance {
73  public:
74   static TRITONSERVER_Error *Create(ModelState *model_state, TRITONBACKEND_ModelInstance *triton_model_instance,
75                                     ModelInstanceState **state);
76   virtual ~ModelInstanceState() = default;
77 
78   void ProcessRequests(TRITONBACKEND_Request **requests, const uint32_t request_count);
79 
80  private:
ModelInstanceState(ModelState * model_state,TRITONBACKEND_ModelInstance * triton_model_instance)81   ModelInstanceState(ModelState *model_state, TRITONBACKEND_ModelInstance *triton_model_instance)
82       : BackendModelInstance(reinterpret_cast<BackendModel *>(model_state), triton_model_instance),
83         model_state_(model_state) {
84     max_batch_size_ = model_state->MaxBatchSize();
85     preferred_batch_size_ = model_state->PreferredBatchSize();
86     support_dynamic_batch_ = model_state->SupportDynamicBatch();
87   }
88 
89   TRITONSERVER_Error *ProcessInputs(TRITONBACKEND_Request **requests, const uint32_t request_count,
90                                     std::vector<TRITONBACKEND_Response *> *responses);
91 
92   TRITONSERVER_Error *ProcessOutputs(TRITONBACKEND_Request **requests, const uint32_t request_count,
93                                      std::vector<TRITONBACKEND_Response *> *responses);
94 
95   ModelState *model_state_;
96   std::shared_ptr<mindspore::Model> ms_model_;
97 
98   std::vector<mindspore::MSTensor> inputs_;
99   std::vector<mindspore::MSTensor> outputs_;
100 
101   int64_t batched_size_ = 0;
102   int64_t max_batch_size_ = 0;
103   std::vector<int64_t> preferred_batch_size_;
104   bool support_dynamic_batch_ = false;
105   bool need_resize_ = false;
106 };
107 }  // namespace mslite
108 }  // namespace backend
109 }  // namespace triton
110 #endif  // MINDSPORE_LITE_TOOLS_PROVIDERS_TRITON_BACKEND_SRC_MSLITE_MODEL_STATE_H_
111