/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/mslite_model_state.h"
#include <unistd.h>
#include <algorithm>
#include <cstring>  // std::memcpy, std::memset
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "triton/backend/backend_input_collector.h"
#include "triton/backend/backend_output_responder.h"

namespace triton {
namespace backend {
namespace mslite {
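// Create and initialize the ModelState for a Triton model: construct the state, parse the model
// configuration and the backend-specific parameters, then set up the MindSpore Lite context.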
TRITONSERVER_Error *ModelState::Create(TRITONBACKEND_Model *triton_model, ModelState **state) {
  try {
    *state = new ModelState(triton_model);
  } catch (const BackendModelException &ex) {
    RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                         std::string("unexpected nullptr in BackendModelException"));
    RETURN_IF_ERROR(ex.err_);
  }

  RETURN_IF_ERROR((*state)->ParseModelConfig());
  RETURN_IF_ERROR((*state)->ParseModelParameterConfig());
  RETURN_IF_ERROR((*state)->InitMSContext());
  return nullptr;  // success
}

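// Initialize the MindSpore Lite context. When the configured device_type is "ascend", an
// AscendDeviceInfo with the configured device_id is added to the device list; a CPUDeviceInfo
// is always appended as well.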
TRITONSERVER_Error *ModelState::InitMSContext() {
  // Init ms context.
  ms_context_ = std::make_shared<mindspore::Context>();
  RETURN_ERROR_IF_TRUE(ms_context_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                       std::string("New mindspore-lite context failed."));
  auto &device_list = ms_context_->MutableDeviceInfo();

  if (device_type_ == "ascend") {
    auto ascend_device_info = std::make_shared<mindspore::AscendDeviceInfo>();
    RETURN_ERROR_IF_TRUE(ascend_device_info == nullptr, TRITONSERVER_ERROR_INTERNAL,
                         std::string("New AscendDeviceInfo failed for mindspore-lite context."));
    ascend_device_info->SetDeviceID(device_id_);
    device_list.push_back(ascend_device_info);
  }
  auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
  RETURN_ERROR_IF_TRUE(device_info == nullptr, TRITONSERVER_ERROR_INTERNAL,
                       std::string("New CPUDeviceInfo failed for mindspore-lite context."));
  device_list.push_back(device_info);
  return nullptr;
}

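// Build the MindSpore Lite model from the file stored in the Triton model repository. The file is
// expected at "<repository>/<version>/<name>.<model_type>"; a model_type of "mindir" selects the
// MindIR format, any other value is loaded as MindIR-Lite. Illustrative layout only (the path
// below is an assumption, not prescribed by this backend): /models/resnet/1/resnet.mindir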
std::shared_ptr<mindspore::Model> ModelState::BuildMSModel() {
  auto ms_model = std::make_shared<mindspore::Model>();
  if (ms_model == nullptr) {
    LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "New mslite model failed.");
    return nullptr;
  }
  auto model_path = JoinPath({this->RepositoryPath(), std::to_string(this->Version())});
  auto model_name = this->Name() + "." + model_type_;
  auto model_file = JoinPath({model_path, model_name});
  auto model_type = model_type_ == "mindir" ? mindspore::kMindIR : mindspore::kMindIR_Lite;
  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Begin to init mslite model file: ") + model_file).c_str());
  auto build_ret = ms_model->Build(model_file, model_type, ms_context_);
  if (build_ret != mindspore::kSuccess) {
    LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Build mslite model failed.");
    return nullptr;
  }
  return ms_model;
}

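// Parse the backend-specific parts of the model configuration: max_batch_size (read earlier in
// ParseModelConfig), the optional dynamic_batching section with its preferred_batch_size list,
// and the "parameters" entries model_type, device_type and device_id. A minimal config.pbtxt
// sketch, for illustration only (all values below are assumptions, not defaults):
//
//   max_batch_size: 8
//   dynamic_batching { preferred_batch_size: [ 2, 4, 8 ] }
//   parameters: { key: "model_type"  value: { string_value: "mindir" } }
//   parameters: { key: "device_type" value: { string_value: "ascend" } }
//   parameters: { key: "device_id"   value: { string_value: "0" } }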
TRITONSERVER_Error *ModelState::ParseModelParameterConfig() {
  // Parse the dynamic batching config. The max_batch_size itself is parsed in ParseModelConfig;
  // a non-zero max_batch_size only indicates that batching along the first dimension is possible.
  RETURN_IF_ERROR(this->SupportsFirstDimBatching(&support_dynamic_batch_));

  // For Ascend, additionally require the "dynamic_batching" config section to decide whether
  // dynamic batching is really enabled.
  common::TritonJson::Value dynamic_batching;
  support_dynamic_batch_ &= model_config_.Find("dynamic_batching", &dynamic_batching);
  if (support_dynamic_batch_) {
    RETURN_IF_ERROR(backend::ParseShape(dynamic_batching, "preferred_batch_size", &preferred_batch_size_));
    std::sort(preferred_batch_size_.begin(), preferred_batch_size_.end());
    RETURN_ERROR_IF_FALSE(std::all_of(preferred_batch_size_.begin(), preferred_batch_size_.end(),
                                      [&](int64_t dim) { return dim <= max_batch_size_; }),
                          TRITONSERVER_ERROR_INVALID_ARG,
                          std::string("the preferred batch size should not be larger than the max batch size."));
  }

  // Parse the backend parameters.
  common::TritonJson::Value parameters;
  if (model_config_.Find("parameters", &parameters)) {
    (void)GetParameterValue(parameters, "model_type", &model_type_);
    (void)GetParameterValue(parameters, "device_type", &device_type_);
    std::string device_id;
    (void)GetParameterValue(parameters, "device_id", &device_id);
    if (!device_id.empty()) {
      try {
        device_id_ = std::stoi(device_id);
      } catch (const std::exception &) {
        // std::stoi throws std::invalid_argument or std::out_of_range on bad input.
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
                                     (std::string("The device_id is not a valid integer: ") + device_id).c_str());
      }
    }
  }
  return nullptr;  // success
}

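// Create the per-instance state and build a MindSpore Lite model for this model instance.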
TRITONSERVER_Error *ModelInstanceState::Create(ModelState *model_state,
                                               TRITONBACKEND_ModelInstance *triton_model_instance,
                                               ModelInstanceState **state) {
  try {
    *state = new ModelInstanceState(model_state, triton_model_instance);
  } catch (const BackendModelInstanceException &ex) {
    RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                         std::string("unexpected nullptr in BackendModelInstanceException"));
    RETURN_IF_ERROR(ex.err_);
  }

  (*state)->ms_model_ = model_state->BuildMSModel();
  RETURN_ERROR_IF_TRUE((*state)->ms_model_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                       std::string("Init mslite model failed."));
  return nullptr;  // success
}

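// Gather the inputs of all batched requests into contiguous CPU buffers and wrap them as
// MindSpore Lite MSTensors. With dynamic batching the batch dimension is padded up to the
// smallest preferred batch size that fits (or max_batch_size), and need_resize_ records whether
// the model must be resized before Predict.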
TRITONSERVER_Error *ModelInstanceState::ProcessInputs(TRITONBACKEND_Request **requests, const uint32_t request_count,
                                                      std::vector<TRITONBACKEND_Response *> *responses) {
  // To instruct ProcessTensor to "gather" the entire batch of input
  // tensors into a single contiguous buffer in CPU memory, set the
  // "allowed input types" to be the CPU ones (see tritonserver.h in
  // the triton-inference-server/core repo for allowed memory types).
  BackendInputCollector collector(requests, request_count, responses, model_state_->TritonMemoryManager(),
                                  model_state_->EnablePinnedInput(), nullptr);

  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> allowed_input_types = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
                                                                                  {TRITONSERVER_MEMORY_CPU, 0}};

  uint32_t input_count = 0;
  RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));

  auto model_inputs = ms_model_->GetInputs();
  RETURN_ERROR_IF_FALSE(model_inputs.size() == input_count, TRITONSERVER_ERROR_INVALID_ARG,
                        std::string("The input count is not equal to model inputs: ") +
                          std::to_string(model_inputs.size()) + std::string(", while the input count of request is: ") +
                          std::to_string(input_count));

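  // For each model input: accumulate the batch dimension over all requests, check the data type
  // against the model input, and collect the per-request buffers into one batched tensor.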
  inputs_.clear();
  for (uint32_t idx = 0; idx < input_count; idx++) {
    const char *input_name;
    TRITONSERVER_DataType input_datatype;
    const int64_t *input_shape;
    uint32_t input_dims_count;
    batched_size_ = 0;
    // Sum the first (batch) dimension of this input across all requests in the batch.
    for (uint32_t r = 0; r < request_count; r++) {
      TRITONBACKEND_Input *input;
      RETURN_IF_ERROR(TRITONBACKEND_RequestInputByIndex(requests[r], idx, &input));
      RETURN_IF_ERROR(TRITONBACKEND_InputProperties(input, &input_name, &input_datatype, &input_shape,
                                                    &input_dims_count, nullptr, nullptr));
      batched_size_ += *input_shape;
    }
    RETURN_ERROR_IF_TRUE(model_state_->MaxBatchSize() > 0 && batched_size_ > model_state_->MaxBatchSize(),
                         TRITONSERVER_ERROR_INVALID_ARG,
                         std::string("The input batch size is larger than the max batch size."));
    auto data_type = GetMSDataTypeFromTritonServerDataType(input_datatype);
    RETURN_ERROR_IF_FALSE(data_type == model_inputs.at(idx).DataType(), TRITONSERVER_ERROR_INVALID_ARG,
                          std::string("The input data type is not equal to the model input data type."));

    std::vector<int64_t> batched_shape(input_dims_count);
    std::memcpy(batched_shape.data(), input_shape, input_dims_count * sizeof(int64_t));
    batched_shape.at(0) = batched_size_;
    if (support_dynamic_batch_) {
      // Pad the batch dimension up to the smallest preferred batch size that can hold it,
      // falling back to max_batch_size when it exceeds every preferred size.
      auto pad_batch_itr = std::lower_bound(preferred_batch_size_.begin(), preferred_batch_size_.end(), batched_size_);
      batched_shape.at(0) = pad_batch_itr != preferred_batch_size_.end() ? *pad_batch_itr : max_batch_size_;
      LOG_MESSAGE(TRITONSERVER_LOG_INFO,
                  std::string("The batched size will be padded to " + std::to_string(batched_shape.at(0))).c_str());
      need_resize_ = model_inputs.at(idx).Shape() != batched_shape;
    }

    const char *input_buffer = nullptr;
    size_t input_buffer_byte_size;
    TRITONSERVER_MemoryType input_buffer_memory_type;
    int64_t input_buffer_memory_type_id;
    RETURN_IF_ERROR(collector.ProcessTensor(
      input_name, nullptr /* existing_buffer */, 0 /* existing_buffer_byte_size */, allowed_input_types, &input_buffer,
      &input_buffer_byte_size, &input_buffer_memory_type, &input_buffer_memory_type_id));
    RETURN_ERROR_IF_TRUE(input_buffer == nullptr || input_buffer_byte_size == 0, TRITONSERVER_ERROR_INTERNAL,
                         std::string("Process input tensor data failed."));

    auto input_tensor = mindspore::MSTensor(input_name, data_type, {}, nullptr, 0);
    input_tensor.SetShape(batched_shape);
    if (!support_dynamic_batch_ && input_tensor.DataSize() == input_buffer_byte_size) {
      // Fast path for static shapes: reuse the collected buffer directly without a copy.
      // Data batched together from several requests may otherwise cause a host-to-device
      // memcpy error, so the copy path below is used whenever the sizes do not match or
      // dynamic batching is enabled.
      LOG_MESSAGE(TRITONSERVER_LOG_INFO, "Use the collected input buffer directly without copy.");
      input_tensor.SetData(reinterpret_cast<void *>(const_cast<char *>(input_buffer)), false);
    } else {
      // The padded tensor may be larger than the collected data, so allocate the tensor's own
      // memory, zero it, and copy the gathered bytes in.
      LOG_MESSAGE(TRITONSERVER_LOG_INFO, "Allocate tensor data and copy because the buffer size differs.");
      auto input_data = input_tensor.MutableData();
      auto data_size = input_tensor.DataSize();
      RETURN_ERROR_IF_TRUE(input_data == nullptr || input_buffer_byte_size > data_size, TRITONSERVER_ERROR_INTERNAL,
                           std::string("Process input tensor data failed."));

      std::memset(input_data, 0, data_size);
      std::memcpy(input_data, input_buffer, input_buffer_byte_size);
    }
    inputs_.push_back(input_tensor);
  }
  collector.Finalize();
  LOG_MESSAGE(TRITONSERVER_LOG_INFO, "Processing input tensors finished.");
  return nullptr;  // success
}

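// Convert the model outputs into Triton responses: each output tensor is handed to the
// BackendOutputResponder, which distributes the batched data back to the individual requests.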
TRITONSERVER_Error *ModelInstanceState::ProcessOutputs(TRITONBACKEND_Request **requests, const uint32_t request_count,
                                                       std::vector<TRITONBACKEND_Response *> *responses) {
  BackendOutputResponder responder(requests, request_count, responses, model_state_->TritonMemoryManager(),
                                   model_state_->MaxBatchSize() > 1, model_state_->EnablePinnedOutput(), nullptr);

  for (auto output : outputs_) {
    auto shape = output.Shape();
    auto data_type = GetTritonServerDataTypeFromMSDataType(output.DataType());
    RETURN_ERROR_IF_TRUE(data_type == TRITONSERVER_TYPE_INVALID, TRITONSERVER_ERROR_INTERNAL,
                         std::string("The output data type is invalid."));
    auto data = output.MutableData();
    RETURN_ERROR_IF_TRUE(data == nullptr, TRITONSERVER_ERROR_INTERNAL, std::string("The output data is nullptr."));
    // Use the converted data type of the output tensor rather than assuming FP32.
    responder.ProcessTensor(output.Name(), data_type, shape, reinterpret_cast<char *>(data),
                            TRITONSERVER_MEMORY_CPU, 0);
  }
  responder.Finalize();
  return nullptr;  // success
}

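// Execute one batch of requests: validate the requests, create one response per request, gather
// the inputs, resize the model if the padded batch shape changed, run Predict, scatter the
// outputs, send the responses, and report per-request and batch statistics.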
void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request **requests, const uint32_t request_count) {
  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
              (std::string("Begin to process ") + std::to_string(request_count) + std::string(" requests.")).c_str());
  uint64_t exec_start_ns = 0;
  SET_TIMESTAMP(exec_start_ns);

  for (size_t i = 0; i < request_count; i++) {
    // If we get a nullptr request then something is badly wrong. Fail
    // and release all requests.
    if (requests[i] == nullptr) {
      RequestsRespondWithError(
        requests, request_count,
        TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                              std::string("null request given to MSLite backend for '" + Name() + "'").c_str()));
      return;
    }
  }

  // At this point we accept ownership of 'requests', which means that
  // even if something goes wrong we must still return success from
  // this function. If something does go wrong in processing a
  // particular request then we send an error response just for the
  // specific request.
  std::vector<TRITONBACKEND_Response *> responses;
  responses.reserve(request_count);
  for (size_t i = 0; i < request_count; i++) {
    TRITONBACKEND_Response *response;
    auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
    if (err == nullptr) {
      responses.emplace_back(response);
    } else {
      responses.emplace_back(nullptr);
      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Fail to create response");
      TRITONSERVER_ErrorDelete(err);
    }
  }

  RESPOND_ALL_AND_SET_NULL_IF_ERROR(responses, request_count, ProcessInputs(requests, request_count, &responses));

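  // When dynamic batching padded the batch dimension in ProcessInputs, the batched input shapes
  // may no longer match the model's current input shapes, so resize the model before Predict.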
  if (need_resize_) {
    std::vector<std::vector<int64_t>> shapes;
    std::transform(inputs_.begin(), inputs_.end(), std::back_inserter(shapes),
                   [](const mindspore::MSTensor &input) { return input.Shape(); });
    std::stringstream oss;
    auto print_shape = [&oss](const std::vector<int64_t> &shape) {
      oss << "[";
      std::for_each(shape.begin(), shape.end(), [&oss](int64_t dim) { oss << (std::to_string(dim) + ", "); });
      oss << "], ";
    };
    oss << "The inputs need to be resized to [";
    std::for_each(shapes.begin(), shapes.end(), print_shape);
    oss << "]";
    LOG_MESSAGE(TRITONSERVER_LOG_INFO, oss.str().c_str());

    auto ret = ms_model_->Resize(ms_model_->GetInputs(), shapes);
    if (ret != mindspore::kSuccess) {
      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
        responses, request_count, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, "Fail to resize mslite model."));
      return;
    }
  }

  uint64_t compute_start_ns = 0;
  SET_TIMESTAMP(compute_start_ns);
  outputs_ = !outputs_.empty() ? outputs_ : ms_model_->GetOutputs();
  auto ret = ms_model_->Predict(inputs_, &outputs_);
  if (ret != mindspore::kSuccess) {
    RESPOND_ALL_AND_SET_NULL_IF_ERROR(
      responses, request_count, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, "Fail to predict mslite model."));
    return;
  }

  uint64_t compute_end_ns = 0;
  SET_TIMESTAMP(compute_end_ns);
  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Compute model ") + Name() + " cost " +
                                      std::to_string((compute_end_ns - compute_start_ns) / 1000) + " us")
                                       .c_str());

  RESPOND_ALL_AND_SET_NULL_IF_ERROR(responses, request_count, ProcessOutputs(requests, request_count, &responses));

  uint64_t exec_end_ns = 0;
  SET_TIMESTAMP(exec_end_ns);

  // Send all the responses that haven't already been sent because of
  // an earlier error. Note that the responses are not set to nullptr
  // here as we need that indication below to determine whether the
  // request was successful or not.
  for (auto &response : responses) {
    if (response != nullptr) {
      LOG_IF_ERROR(TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
                   "failed to send MSLite backend response");
    }
  }

  // Report statistics for each request.
  for (uint32_t r = 0; r < request_count; ++r) {
    auto &request = requests[r];
    LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(TritonModelInstance(), request,
                                                             (responses[r] != nullptr) /* success */, exec_start_ns,
                                                             compute_start_ns, compute_end_ns, exec_end_ns),
                 "failed reporting request statistics");

    LOG_IF_ERROR(TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), "failed releasing request");
  }

  // Report the entire batch statistics.
  LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(TritonModelInstance(), batched_size_, exec_start_ns,
                                                                compute_start_ns, compute_end_ns, exec_end_ns),
               "failed reporting batch request statistics");

  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("TRITONBACKEND_ModelExecute: model ") + Name() + " released " +
                                      std::to_string(request_count) + " requests")
                                       .c_str());

  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Execute model ") + Name() + " cost " +
                                      std::to_string((exec_end_ns - exec_start_ns) / 1000) + " us")
                                       .c_str());
}
}  // namespace mslite
}  // namespace backend
}  // namespace triton