/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/mslite_model_state.h"
#include <unistd.h>
#include <algorithm>
#include <cstring>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
#include "triton/backend/backend_input_collector.h"
#include "triton/backend/backend_output_responder.h"

namespace triton {
namespace backend {
namespace mslite {
TRITONSERVER_Error *ModelState::Create(TRITONBACKEND_Model *triton_model, ModelState **state) {
  try {
    *state = new ModelState(triton_model);
  } catch (const BackendModelException &ex) {
    RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                         std::string("unexpected nullptr in BackendModelException"));
    RETURN_IF_ERROR(ex.err_);
  }

  RETURN_IF_ERROR((*state)->ParseModelConfig());
  RETURN_IF_ERROR((*state)->ParseModelParameterConfig());
  RETURN_IF_ERROR((*state)->InitMSContext());
  return nullptr; // success
}

TRITONSERVER_Error *ModelState::InitMSContext() {
  // Init ms context.
  ms_context_ = std::make_shared<mindspore::Context>();
  RETURN_ERROR_IF_TRUE(ms_context_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                       std::string("New mindspore-lite context failed."));
  auto &device_list = ms_context_->MutableDeviceInfo();

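  // When device_type is "ascend", the Ascend device info is registered first; the CPU device info
  // below is always appended after it, so the context always carries a CPU device entry.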
  if (device_type_ == "ascend") {
    auto ascend_device_info = std::make_shared<mindspore::AscendDeviceInfo>();
    RETURN_ERROR_IF_TRUE(ascend_device_info == nullptr, TRITONSERVER_ERROR_INTERNAL,
                         std::string("New AscendDeviceInfo failed for mindspore-lite context."));
    ascend_device_info->SetDeviceID(device_id_);
    device_list.push_back(ascend_device_info);
  }
  auto device_info = std::make_shared<mindspore::CPUDeviceInfo>();
  RETURN_ERROR_IF_TRUE(device_info == nullptr, TRITONSERVER_ERROR_INTERNAL,
                       std::string("New CPUDeviceInfo failed for mindspore-lite context."));
  device_list.push_back(device_info);
  return nullptr;
}

std::shared_ptr<mindspore::Model> ModelState::BuildMSModel() {
  auto ms_model = std::make_shared<mindspore::Model>();
  if (ms_model == nullptr) {
    LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "New mslite model failed.");
    return nullptr;
  }
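  // The model file is expected at "<model_repository>/<version>/<model_name>.<model_type>".
  // A "mindir" model is loaded as MindIR, any other model_type as MindIR_Lite.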
  auto model_path = JoinPath({this->RepositoryPath(), std::to_string(this->Version())});
  auto model_name = this->Name() + "." + model_type_;
  auto model_file = JoinPath({model_path, model_name});
  auto model_type = model_type_ == "mindir" ? mindspore::kMindIR : mindspore::kMindIR_Lite;
  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Begin to init mslite model file: ") + model_file).c_str());
  auto build_ret = ms_model->Build(model_file, model_type, ms_context_);
  if (build_ret != mindspore::kSuccess) {
    LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Build mslite model failed.");
    return nullptr;
  }
  return ms_model;
}

TRITONSERVER_Error *ModelState::ParseModelParameterConfig() {
  // Parse the dynamic batching config. The max_batch_size is parsed in ParseModelConfig;
  // a non-zero max_batch_size only means dynamic batching may be supported.
  RETURN_IF_ERROR(this->SupportsFirstDimBatching(&support_dynamic_batch_));

  // For Ascend, the presence of the dynamic_batching config decides whether dynamic batching is really supported.
  common::TritonJson::Value dynamic_batching;
  support_dynamic_batch_ &= model_config_.Find("dynamic_batching", &dynamic_batching);
  if (support_dynamic_batch_) {
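    // preferred_batch_size_ is kept sorted so ProcessInputs can use std::lower_bound to find the
    // smallest preferred batch size that can hold the accumulated batch.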
    RETURN_IF_ERROR(backend::ParseShape(dynamic_batching, "preferred_batch_size", &preferred_batch_size_));
    std::sort(preferred_batch_size_.begin(), preferred_batch_size_.end());
    RETURN_ERROR_IF_FALSE(std::all_of(preferred_batch_size_.begin(), preferred_batch_size_.end(),
                                      [&](int64_t dim) { return dim <= max_batch_size_; }),
                          TRITONSERVER_ERROR_INVALID_ARG,
                          std::string("the preferred batch size should not be larger than the max batch size."));
  }

  // parse parameters.
  common::TritonJson::Value parameters;
  if (model_config_.Find("parameters", &parameters)) {
    (void)GetParameterValue(parameters, "model_type", &model_type_);
    (void)GetParameterValue(parameters, "device_type", &device_type_);
    std::string device_id;
    (void)GetParameterValue(parameters, "device_id", &device_id);
    if (!device_id.empty()) {
      try {
        device_id_ = std::stoi(device_id);
      } catch (const std::exception &) {
        // std::stoi throws std::invalid_argument or std::out_of_range for malformed values.
        return TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INVALID_ARG,
                                     (std::string("The device_id is not valid: ") + device_id).c_str());
      }
    }
  }
  return nullptr; // success
}

TRITONSERVER_Error *ModelInstanceState::Create(ModelState *model_state,
                                               TRITONBACKEND_ModelInstance *triton_model_instance,
                                               ModelInstanceState **state) {
  try {
    *state = new ModelInstanceState(model_state, triton_model_instance);
  } catch (const BackendModelInstanceException &ex) {
    RETURN_ERROR_IF_TRUE(ex.err_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                         std::string("unexpected nullptr in BackendModelInstanceException"));
    RETURN_IF_ERROR(ex.err_);
  }

  (*state)->ms_model_ = model_state->BuildMSModel();
  RETURN_ERROR_IF_TRUE((*state)->ms_model_ == nullptr, TRITONSERVER_ERROR_INTERNAL,
                       std::string("Init mslite model failed."));
  return nullptr; // success
}

TRITONSERVER_Error *ModelInstanceState::ProcessInputs(TRITONBACKEND_Request **requests, const uint32_t request_count,
                                                      std::vector<TRITONBACKEND_Response *> *responses) {
  // To instruct ProcessTensor to "gather" the entire batch of input
  // tensors into a single contiguous buffer in CPU memory, set the
  // "allowed input types" to be the CPU ones (see tritonserver.h in
  // the triton-inference-server/core repo for allowed memory types).
  BackendInputCollector collector(requests, request_count, responses, model_state_->TritonMemoryManager(),
                                  model_state_->EnablePinnedInput(), nullptr);

  std::vector<std::pair<TRITONSERVER_MemoryType, int64_t>> allowed_input_types = {{TRITONSERVER_MEMORY_CPU_PINNED, 0},
                                                                                  {TRITONSERVER_MEMORY_CPU, 0}};

  uint32_t input_count = 0;
  RETURN_IF_ERROR(TRITONBACKEND_RequestInputCount(requests[0], &input_count));

  auto model_inputs = ms_model_->GetInputs();
  RETURN_ERROR_IF_FALSE(model_inputs.size() == input_count, TRITONSERVER_ERROR_INVALID_ARG,
                        std::string("The input count is not equal to model inputs: ") +
                          std::to_string(model_inputs.size()) + std::string(", while the input count of request is: ") +
                          std::to_string(input_count));

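  // For each model input: accumulate the batch dimension across all requests, validate the data type
  // against the model, gather the request buffers into one contiguous CPU buffer, and wrap it in an MSTensor.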
  inputs_.clear();
  for (uint32_t idx = 0; idx < input_count; idx++) {
    const char *input_name;
    TRITONSERVER_DataType input_datatype;
    const int64_t *input_shape;
    uint32_t input_dims_count;
    batched_size_ = 0;
    for (uint32_t r = 0; r < request_count; r++) {
      // Query this input from every request so the first (batch) dimension of each request is accumulated.
      TRITONBACKEND_Input *input;
      RETURN_IF_ERROR(TRITONBACKEND_RequestInputByIndex(requests[r], idx, &input));
      RETURN_IF_ERROR(TRITONBACKEND_InputProperties(input, &input_name, &input_datatype, &input_shape,
                                                    &input_dims_count, nullptr, nullptr));
      batched_size_ += *input_shape;
    }
    RETURN_ERROR_IF_TRUE(model_state_->MaxBatchSize() > 0 && batched_size_ > model_state_->MaxBatchSize(),
                         TRITONSERVER_ERROR_INVALID_ARG,
                         std::string("The input batch size is larger than the max batch size."));
    auto data_type = GetMSDataTypeFromTritonServerDataType(input_datatype);
    RETURN_ERROR_IF_FALSE(data_type == model_inputs.at(idx).DataType(), TRITONSERVER_ERROR_INVALID_ARG,
                          std::string("The input data type is not equal to model input."));

    std::vector<int64_t> batched_shape(input_dims_count);
    std::memcpy(batched_shape.data(), input_shape, input_dims_count * sizeof(int64_t));
    batched_shape.at(0) = batched_size_;
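    // With dynamic batching, pad the batch dimension up to the smallest preferred batch size that fits
    // (or to max_batch_size if none does), and flag a resize when the model's current input shape differs.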
    if (support_dynamic_batch_) {
      auto pad_batch_itr = std::lower_bound(preferred_batch_size_.begin(), preferred_batch_size_.end(), batched_size_);
      batched_shape.at(0) = pad_batch_itr != preferred_batch_size_.end() ? *pad_batch_itr : max_batch_size_;
      LOG_MESSAGE(TRITONSERVER_LOG_INFO,
                  std::string("The batched size will be padded to " + std::to_string(batched_shape.at(0))).c_str());
      need_resize_ = model_inputs.at(idx).Shape() != batched_shape;
    }

    const char *input_buffer = nullptr;
    size_t input_buffer_byte_size;
    TRITONSERVER_MemoryType input_buffer_memory_type;
    int64_t input_buffer_memory_type_id;
    RETURN_IF_ERROR(collector.ProcessTensor(
      input_name, nullptr /* existing_buffer */, 0 /* existing_buffer_byte_size */, allowed_input_types, &input_buffer,
      &input_buffer_byte_size, &input_buffer_memory_type, &input_buffer_memory_type_id));
    RETURN_ERROR_IF_TRUE(input_buffer == nullptr || input_buffer_byte_size == 0, TRITONSERVER_ERROR_INTERNAL,
                         std::string("Process input tensor data failed."));

    auto input_tensor = mindspore::MSTensor(input_name, data_type, {}, nullptr, 0);
    input_tensor.SetShape(batched_shape);
    if (!support_dynamic_batch_ && input_tensor.DataSize() == input_buffer_byte_size) {
      // Speed up with a zero-copy path for static shapes; for padded batches gathered from several
      // requests this could cause a host-to-device memcpy error, so it is only used here.
      LOG_MESSAGE(TRITONSERVER_LOG_INFO, std::string("Use the original data without copy.").c_str());
      input_tensor.SetData(reinterpret_cast<void *>(const_cast<char *>(input_buffer)), false);
    } else {
      LOG_MESSAGE(TRITONSERVER_LOG_INFO,
                  std::string("Allocate tensor data because the data size is not equal to the buffer size.").c_str());
      auto input_data = input_tensor.MutableData();
      auto data_size = input_tensor.DataSize();
      RETURN_ERROR_IF_TRUE(input_data == nullptr || input_buffer_byte_size > data_size, TRITONSERVER_ERROR_INTERNAL,
                           std::string("Process input tensor data failed."));

      std::memset(input_data, 0, input_tensor.DataSize());
      std::memcpy(input_data, input_buffer, input_buffer_byte_size);
    }
    inputs_.push_back(input_tensor);
  }
  collector.Finalize();
  LOG_MESSAGE(TRITONSERVER_LOG_INFO, std::string("Processing input tensors finished.").c_str());
  return nullptr; // success
}

TRITONSERVER_Error *ModelInstanceState::ProcessOutputs(TRITONBACKEND_Request **requests, const uint32_t request_count,
                                                       std::vector<TRITONBACKEND_Response *> *responses) {
  BackendOutputResponder responder(requests, request_count, responses, model_state_->TritonMemoryManager(),
                                   model_state_->MaxBatchSize() > 1, model_state_->EnablePinnedOutput(), nullptr);

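  // The responder scatters each model output back into the per-request responses, splitting along the
  // first (batch) dimension when batching is enabled.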
  for (auto output : outputs_) {
    auto shape = output.Shape();
    auto data_type = GetTritonServerDataTypeFromMSDataType(output.DataType());
    RETURN_ERROR_IF_TRUE(data_type == TRITONSERVER_TYPE_INVALID, TRITONSERVER_ERROR_INTERNAL,
                         std::string("The output data type is invalid."));
    auto data = output.MutableData();
    RETURN_ERROR_IF_TRUE(data == nullptr, TRITONSERVER_ERROR_INTERNAL, std::string("The output data is nullptr."));
    responder.ProcessTensor(output.Name(), data_type, shape, reinterpret_cast<char *>(data),
                            TRITONSERVER_MEMORY_CPU, 0);
  }
  responder.Finalize();
  return nullptr; // success
}

void ModelInstanceState::ProcessRequests(TRITONBACKEND_Request **requests, const uint32_t request_count) {
  LOG_MESSAGE(TRITONSERVER_LOG_INFO,
              (std::string("Begin to process ") + std::to_string(request_count) + std::string(" requests.")).c_str());
  uint64_t exec_start_ns = 0;
  SET_TIMESTAMP(exec_start_ns);

  for (size_t i = 0; i < request_count; i++) {
    // If we get a nullptr request then something is badly wrong. Fail
    // and release all requests.
    if (requests[i] == nullptr) {
      RequestsRespondWithError(
        requests, request_count,
        TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL,
                              std::string("null request given to MSLite backend for '" + Name() + "'").c_str()));
      return;
    }
  }

  // At this point we accept ownership of 'requests', which means that
  // even if something goes wrong we must still return success from
  // this function. If something does go wrong in processing a
  // particular request then we send an error response just for the
  // specific request.
  std::vector<TRITONBACKEND_Response *> responses;
  responses.reserve(request_count);
  for (size_t i = 0; i < request_count; i++) {
    TRITONBACKEND_Response *response;
    auto err = TRITONBACKEND_ResponseNew(&response, requests[i]);
    if (err == nullptr) {
      responses.emplace_back(response);
    } else {
      responses.emplace_back(nullptr);
      LOG_MESSAGE(TRITONSERVER_LOG_ERROR, "Failed to create response");
      TRITONSERVER_ErrorDelete(err);
    }
  }

  RESPOND_ALL_AND_SET_NULL_IF_ERROR(responses, request_count, ProcessInputs(requests, request_count, &responses));

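  // If the padded batch shape no longer matches the model's current input shapes, resize the model
  // before running prediction.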
  if (need_resize_) {
    std::vector<std::vector<int64_t>> shapes;
    std::transform(inputs_.begin(), inputs_.end(), std::back_inserter(shapes),
                   [](const mindspore::MSTensor &input) { return input.Shape(); });
    std::stringstream oss;
    auto print_shape = [&oss](const std::vector<int64_t> &shape) {
      oss << "[";
      std::for_each(shape.begin(), shape.end(), [&oss](int64_t dim) { oss << (std::to_string(dim) + ", "); });
      oss << "], ";
    };
    oss << "The inputs need to be resized to [";
    std::for_each(shapes.begin(), shapes.end(), print_shape);
    oss << "]";
    LOG_MESSAGE(TRITONSERVER_LOG_INFO, oss.str().c_str());
    std::cout << oss.str() << std::endl;

    auto ret = ms_model_->Resize(ms_model_->GetInputs(), shapes);
    if (ret != mindspore::kSuccess) {
      RESPOND_ALL_AND_SET_NULL_IF_ERROR(
        responses, request_count, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, "Failed to resize mslite model."));
      return;
    }
  }

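  // Reuse the output tensors from the previous invocation when available; otherwise fetch them from the
  // model so Predict can fill them in place.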
  uint64_t compute_start_ns = 0;
  SET_TIMESTAMP(compute_start_ns);
  outputs_ = !outputs_.empty() ? outputs_ : ms_model_->GetOutputs();
  auto ret = ms_model_->Predict(inputs_, &outputs_);
  if (ret != mindspore::kSuccess) {
    RESPOND_ALL_AND_SET_NULL_IF_ERROR(
      responses, request_count, TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_INTERNAL, "Failed to predict mslite model."));
    return;
  }

  uint64_t compute_end_ns = 0;
  SET_TIMESTAMP(compute_end_ns);
  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Compute model ") + Name() + " cost " +
                                      std::to_string((compute_end_ns - compute_start_ns) / 1000) + " us")
                                       .c_str());

  RESPOND_ALL_AND_SET_NULL_IF_ERROR(responses, request_count, ProcessOutputs(requests, request_count, &responses));

  uint64_t exec_end_ns = 0;
  SET_TIMESTAMP(exec_end_ns);

  // Send all the responses that haven't already been sent because of
  // an earlier error. Note that the responses are not set to nullptr
  // here as we need that indication below to determine if the request
  // was successful or not.
  for (auto &response : responses) {
    if (response != nullptr) {
      LOG_IF_ERROR(TRITONBACKEND_ResponseSend(response, TRITONSERVER_RESPONSE_COMPLETE_FINAL, nullptr),
                   "failed to send MSLite backend response");
    }
  }

  // Report statistics for each request.
  for (uint32_t r = 0; r < request_count; ++r) {
    auto &request = requests[r];
    LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportStatistics(TritonModelInstance(), request,
                                                             (responses[r] != nullptr) /* success */, exec_start_ns,
                                                             compute_start_ns, compute_end_ns, exec_end_ns),
                 "failed reporting request statistics");

    LOG_IF_ERROR(TRITONBACKEND_RequestRelease(request, TRITONSERVER_REQUEST_RELEASE_ALL), "failed releasing request");
  }

  // Report the entire batch statistics.
  LOG_IF_ERROR(TRITONBACKEND_ModelInstanceReportBatchStatistics(TritonModelInstance(), batched_size_, exec_start_ns,
                                                                compute_start_ns, compute_end_ns, exec_end_ns),
               "failed reporting batch request statistics");

  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("TRITONBACKEND_ModelExecute: model ") + Name() + " released " +
                                      std::to_string(request_count) + " requests")
                                       .c_str());

  LOG_MESSAGE(TRITONSERVER_LOG_INFO, (std::string("Execute model ") + Name() + " cost " +
                                      std::to_string((exec_end_ns - exec_start_ns) / 1000) + " us")
                                       .c_str());
}
} // namespace mslite
} // namespace backend
} // namespace triton