1 /*
2 * Copyright (C) 2023 The Android Open Source Project
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #define ATRACE_TAG (ATRACE_TAG_THERMAL | ATRACE_TAG_HAL)
17
18 #include "virtualtemp_estimator.h"
19
20 #include <android-base/logging.h>
21 #include <android-base/stringprintf.h>
22 #include <dlfcn.h>
23 #include <json/reader.h>
24 #include <utils/Trace.h>
25
26 #include <cmath>
27 #include <sstream>
28 #include <vector>
29
30 namespace thermal {
31 namespace vtestimator {
32 namespace {
getFloatFromValue(const Json::Value & value)33 float getFloatFromValue(const Json::Value &value) {
34 if (value.isString()) {
35 return std::atof(value.asString().c_str());
36 } else {
37 return value.asFloat();
38 }
39 }
40
getInputRangeInfoFromJsonValues(const Json::Value & values,InputRangeInfo * input_range_info)41 bool getInputRangeInfoFromJsonValues(const Json::Value &values, InputRangeInfo *input_range_info) {
42 if (values.size() != 2) {
43 LOG(ERROR) << "Data Range Values size: " << values.size() << "is invalid.";
44 return false;
45 }
46
47 float min_val = getFloatFromValue(values[0]);
48 float max_val = getFloatFromValue(values[1]);
49
50 if (std::isnan(min_val) || std::isnan(max_val)) {
51 LOG(ERROR) << "Illegal data range: thresholds not defined properly " << min_val << " : "
52 << max_val;
53 return false;
54 }
55
56 if (min_val > max_val) {
57 LOG(ERROR) << "Illegal data range: data_min_threshold(" << min_val
58 << ") > data_max_threshold(" << max_val << ")";
59 return false;
60 }
61 input_range_info->min_threshold = min_val;
62 input_range_info->max_threshold = max_val;
63 LOG(INFO) << "Data Range Info: " << input_range_info->min_threshold
64 << " <= val <= " << input_range_info->max_threshold;
65 return true;
66 }
67
// Returns the offset paired with the last threshold (scanning from the end) that is strictly
// below `value`, or 0 when no threshold is below it.
// offset_thresholds[i] is paired with offset_values[i]. Presumably the thresholds are
// ascending, so this picks the highest matching band (confirm against the config).
float CalculateOffset(const std::vector<float> &offset_thresholds,
                      const std::vector<float> &offset_values, const float value) {
    // Only consider indices present in both vectors. A size mismatch then can never read
    // past the end of offset_values.
    const size_t usable = offset_thresholds.size() < offset_values.size()
                                  ? offset_thresholds.size()
                                  : offset_values.size();
    for (size_t i = usable; i > 0; --i) {
        if (offset_thresholds[i - 1] < value) {
            return offset_values[i - 1];
        }
    }

    return 0;
}
78 } // namespace
79
DumpTraces()80 VtEstimatorStatus VirtualTempEstimator::DumpTraces() {
81 if (type != kUseMLModel) {
82 return kVtEstimatorUnSupported;
83 }
84
85 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
86 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during DumpTraces\n";
87 return kVtEstimatorInitFailed;
88 }
89
90 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
91
92 if (!common_instance_->is_initialized) {
93 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
94 return kVtEstimatorInitFailed;
95 }
96
97 // get model input/output buffers
98 float *model_input = tflite_instance_->input_buffer;
99 float *model_output = tflite_instance_->output_buffer;
100 auto input_buffer_size = tflite_instance_->input_buffer_size;
101 auto output_buffer_size = tflite_instance_->output_buffer_size;
102
103 // In Case of use_prev_samples, inputs are available in order in scratch buffer
104 if (common_instance_->use_prev_samples) {
105 model_input = tflite_instance_->scratch_buffer;
106 }
107
108 // Add traces for model input/output buffers
109 std::string sensor_name = common_instance_->sensor_name;
110 for (size_t i = 0; i < input_buffer_size; ++i) {
111 ATRACE_INT((sensor_name + "_input_" + std::to_string(i)).c_str(),
112 static_cast<int>(model_input[i]));
113 }
114
115 for (size_t i = 0; i < output_buffer_size; ++i) {
116 ATRACE_INT((sensor_name + "_output_" + std::to_string(i)).c_str(),
117 static_cast<int>(model_output[i]));
118 }
119
120 // log input data and output data buffers
121 std::string input_data_str = "model_input_buffer: [";
122 for (size_t i = 0; i < input_buffer_size; ++i) {
123 input_data_str += ::android::base::StringPrintf("%0.2f ", model_input[i]);
124 }
125 input_data_str += "]";
126 LOG(INFO) << input_data_str;
127
128 std::string output_data_str = "model_output_buffer: [";
129 for (size_t i = 0; i < output_buffer_size; ++i) {
130 output_data_str += ::android::base::StringPrintf("%0.2f ", model_output[i]);
131 }
132 output_data_str += "]";
133 LOG(INFO) << output_data_str;
134
135 return kVtEstimatorOk;
136 }
137
LoadTFLiteWrapper()138 void VirtualTempEstimator::LoadTFLiteWrapper() {
139 if (!tflite_instance_) {
140 LOG(ERROR) << "tflite_instance_ is nullptr during LoadTFLiteWrapper";
141 return;
142 }
143
144 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
145
146 void *mLibHandle = dlopen("/vendor/lib64/libthermal_tflite_wrapper.so", 0);
147 if (mLibHandle == nullptr) {
148 LOG(ERROR) << "Could not load libthermal_tflite_wrapper library with error: " << dlerror();
149 return;
150 }
151
152 tflite_instance_->tflite_methods.create =
153 reinterpret_cast<tflitewrapper_create>(dlsym(mLibHandle, "ThermalTfliteCreate"));
154 if (!tflite_instance_->tflite_methods.create) {
155 LOG(ERROR) << "Could not link and cast tflitewrapper_create with error: " << dlerror();
156 }
157
158 tflite_instance_->tflite_methods.init =
159 reinterpret_cast<tflitewrapper_init>(dlsym(mLibHandle, "ThermalTfliteInit"));
160 if (!tflite_instance_->tflite_methods.init) {
161 LOG(ERROR) << "Could not link and cast tflitewrapper_init with error: " << dlerror();
162 }
163
164 tflite_instance_->tflite_methods.invoke =
165 reinterpret_cast<tflitewrapper_invoke>(dlsym(mLibHandle, "ThermalTfliteInvoke"));
166 if (!tflite_instance_->tflite_methods.invoke) {
167 LOG(ERROR) << "Could not link and cast tflitewrapper_invoke with error: " << dlerror();
168 }
169
170 tflite_instance_->tflite_methods.destroy =
171 reinterpret_cast<tflitewrapper_destroy>(dlsym(mLibHandle, "ThermalTfliteDestroy"));
172 if (!tflite_instance_->tflite_methods.destroy) {
173 LOG(ERROR) << "Could not link and cast tflitewrapper_destroy with error: " << dlerror();
174 }
175
176 tflite_instance_->tflite_methods.get_input_config_size =
177 reinterpret_cast<tflitewrapper_get_input_config_size>(
178 dlsym(mLibHandle, "ThermalTfliteGetInputConfigSize"));
179 if (!tflite_instance_->tflite_methods.get_input_config_size) {
180 LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config_size with error: "
181 << dlerror();
182 }
183
184 tflite_instance_->tflite_methods.get_input_config =
185 reinterpret_cast<tflitewrapper_get_input_config>(
186 dlsym(mLibHandle, "ThermalTfliteGetInputConfig"));
187 if (!tflite_instance_->tflite_methods.get_input_config) {
188 LOG(ERROR) << "Could not link and cast tflitewrapper_get_input_config with error: "
189 << dlerror();
190 }
191 }
192
VirtualTempEstimator(std::string_view sensor_name,VtEstimationType estimationType,size_t num_linked_sensors)193 VirtualTempEstimator::VirtualTempEstimator(std::string_view sensor_name,
194 VtEstimationType estimationType,
195 size_t num_linked_sensors) {
196 type = estimationType;
197
198 common_instance_ = std::make_unique<VtEstimatorCommonData>(sensor_name, num_linked_sensors);
199 if (estimationType == kUseMLModel) {
200 tflite_instance_ = std::make_unique<VtEstimatorTFLiteData>();
201 LoadTFLiteWrapper();
202 } else if (estimationType == kUseLinearModel) {
203 linear_model_instance_ = std::make_unique<VtEstimatorLinearModelData>();
204 } else {
205 LOG(ERROR) << "Unsupported estimationType [" << estimationType << "]";
206 }
207 }
208
~VirtualTempEstimator()209 VirtualTempEstimator::~VirtualTempEstimator() {
210 LOG(INFO) << "VirtualTempEstimator destructor";
211 }
212
LinearModelInitialize(LinearModelInitData data)213 VtEstimatorStatus VirtualTempEstimator::LinearModelInitialize(LinearModelInitData data) {
214 if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
215 LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
216 return kVtEstimatorInitFailed;
217 }
218
219 size_t num_linked_sensors = common_instance_->num_linked_sensors;
220 std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
221
222 if ((num_linked_sensors == 0) || (data.coefficients.size() == 0) ||
223 (data.prev_samples_order == 0)) {
224 LOG(ERROR) << "Invalid num_linked_sensors [" << num_linked_sensors
225 << "] or coefficients.size() [" << data.coefficients.size()
226 << "] or prev_samples_order [" << data.prev_samples_order << "]";
227 return kVtEstimatorInitFailed;
228 }
229
230 if (data.coefficients.size() != (num_linked_sensors * data.prev_samples_order)) {
231 LOG(ERROR) << "In valid args coefficients.size()[" << data.coefficients.size()
232 << "] num_linked_sensors [" << num_linked_sensors << "] prev_samples_order["
233 << data.prev_samples_order << "]";
234 return kVtEstimatorInvalidArgs;
235 }
236
237 common_instance_->use_prev_samples = data.use_prev_samples;
238 common_instance_->prev_samples_order = data.prev_samples_order;
239
240 linear_model_instance_->input_samples.reserve(common_instance_->prev_samples_order);
241 linear_model_instance_->coefficients.reserve(common_instance_->prev_samples_order);
242
243 // Store coefficients
244 for (size_t i = 0; i < data.prev_samples_order; ++i) {
245 std::vector<float> single_order_coefficients;
246 for (size_t j = 0; j < num_linked_sensors; ++j) {
247 single_order_coefficients.emplace_back(data.coefficients[i * num_linked_sensors + j]);
248 }
249 linear_model_instance_->coefficients.emplace_back(single_order_coefficients);
250 }
251
252 common_instance_->offset_thresholds = data.offset_thresholds;
253 common_instance_->offset_values = data.offset_values;
254 common_instance_->is_initialized = true;
255
256 return kVtEstimatorOk;
257 }
258
TFliteInitialize(MLModelInitData data)259 VtEstimatorStatus VirtualTempEstimator::TFliteInitialize(MLModelInitData data) {
260 if (!tflite_instance_ || !common_instance_) {
261 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Initialize\n";
262 return kVtEstimatorInitFailed;
263 }
264
265 std::string model_path = data.model_path;
266 size_t num_linked_sensors = common_instance_->num_linked_sensors;
267 bool use_prev_samples = data.use_prev_samples;
268 size_t prev_samples_order = data.prev_samples_order;
269 size_t num_hot_spots = data.num_hot_spots;
270 size_t output_label_count = data.output_label_count;
271
272 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
273
274 if (model_path.empty()) {
275 LOG(ERROR) << "Invalid model_path:" << model_path;
276 return kVtEstimatorInvalidArgs;
277 }
278
279 if (num_linked_sensors == 0 || prev_samples_order < 1 ||
280 (!use_prev_samples && prev_samples_order > 1)) {
281 LOG(ERROR) << "Invalid tflite_instance_ config: "
282 << "number of linked sensor: " << num_linked_sensors
283 << " use previous: " << use_prev_samples
284 << " previous sample order: " << prev_samples_order;
285 return kVtEstimatorInitFailed;
286 }
287
288 common_instance_->use_prev_samples = data.use_prev_samples;
289 common_instance_->prev_samples_order = prev_samples_order;
290 tflite_instance_->support_under_sampling = data.support_under_sampling;
291 tflite_instance_->enable_input_validation = data.enable_input_validation;
292 tflite_instance_->input_buffer_size = num_linked_sensors * prev_samples_order;
293 tflite_instance_->input_buffer = new float[tflite_instance_->input_buffer_size];
294 if (common_instance_->use_prev_samples) {
295 tflite_instance_->scratch_buffer = new float[tflite_instance_->input_buffer_size];
296 }
297
298 if (output_label_count < 1 || num_hot_spots < 1) {
299 LOG(ERROR) << "Invalid tflite_instance_ config:"
300 << "number of hot spots: " << num_hot_spots
301 << " predicted sample order: " << output_label_count;
302 return kVtEstimatorInitFailed;
303 }
304
305 tflite_instance_->output_label_count = output_label_count;
306 tflite_instance_->num_hot_spots = num_hot_spots;
307 tflite_instance_->output_buffer_size = output_label_count * num_hot_spots;
308 tflite_instance_->output_buffer = new float[tflite_instance_->output_buffer_size];
309
310 if (!tflite_instance_->tflite_methods.create || !tflite_instance_->tflite_methods.init ||
311 !tflite_instance_->tflite_methods.invoke || !tflite_instance_->tflite_methods.destroy ||
312 !tflite_instance_->tflite_methods.get_input_config_size ||
313 !tflite_instance_->tflite_methods.get_input_config) {
314 LOG(ERROR) << "Invalid tflite methods";
315 return kVtEstimatorInitFailed;
316 }
317
318 tflite_instance_->tflite_wrapper =
319 tflite_instance_->tflite_methods.create(kNumInputTensors, kNumOutputTensors);
320 if (!tflite_instance_->tflite_wrapper) {
321 LOG(ERROR) << "Failed to create tflite wrapper";
322 return kVtEstimatorInitFailed;
323 }
324
325 int ret = tflite_instance_->tflite_methods.init(tflite_instance_->tflite_wrapper,
326 model_path.c_str());
327 if (ret) {
328 LOG(ERROR) << "Failed to Init tflite_wrapper for " << model_path << " (ret: )" << ret
329 << ")";
330 return kVtEstimatorInitFailed;
331 }
332
333 Json::Value input_config;
334 if (!GetInputConfig(&input_config)) {
335 LOG(ERROR) << "Get Input Config failed for " << model_path;
336 return kVtEstimatorInitFailed;
337 }
338
339 if (!ParseInputConfig(input_config)) {
340 LOG(ERROR) << "Parse Input Config failed for " << model_path;
341 return kVtEstimatorInitFailed;
342 }
343
344 if (tflite_instance_->enable_input_validation && !tflite_instance_->input_range.size()) {
345 LOG(ERROR) << "Input ranges missing when input data validation is enabled for "
346 << common_instance_->sensor_name;
347 return kVtEstimatorInitFailed;
348 }
349
350 common_instance_->offset_thresholds = data.offset_thresholds;
351 common_instance_->offset_values = data.offset_values;
352 tflite_instance_->model_path = model_path;
353
354 common_instance_->is_initialized = true;
355 LOG(INFO) << "Successfully initialized VirtualTempEstimator for " << model_path;
356 return kVtEstimatorOk;
357 }
358
LinearModelEstimate(const std::vector<float> & thermistors,std::vector<float> * output)359 VtEstimatorStatus VirtualTempEstimator::LinearModelEstimate(const std::vector<float> &thermistors,
360 std::vector<float> *output) {
361 if (linear_model_instance_ == nullptr || common_instance_ == nullptr) {
362 LOG(ERROR) << "linear_model_instance_ or common_instance_ is nullptr during Initialize";
363 return kVtEstimatorInitFailed;
364 }
365
366 size_t prev_samples_order = common_instance_->prev_samples_order;
367 size_t num_linked_sensors = common_instance_->num_linked_sensors;
368
369 std::unique_lock<std::mutex> lock(linear_model_instance_->mutex);
370
371 if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
372 LOG(ERROR) << "Invalid args Thermistors size[" << thermistors.size()
373 << "] num_linked_sensors[" << num_linked_sensors << "] output[" << output << "]";
374 return kVtEstimatorInvalidArgs;
375 }
376
377 if (common_instance_->is_initialized == false) {
378 LOG(ERROR) << "VirtualTempEstimator not initialized to estimate";
379 return kVtEstimatorInitFailed;
380 }
381
382 // For the first iteration copy current inputs to all previous inputs
383 // This would allow the estimator to have previous samples from the first iteration itself
384 // and provide a valid predicted value
385 if (common_instance_->cur_sample_count == 0) {
386 for (size_t i = 0; i < prev_samples_order; ++i) {
387 linear_model_instance_->input_samples[i] = thermistors;
388 }
389 }
390
391 size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
392 linear_model_instance_->input_samples[cur_sample_index] = thermistors;
393
394 // Calculate Weighted Average Value
395 int input_level = cur_sample_index;
396 float estimated_value = 0;
397 for (size_t i = 0; i < prev_samples_order; ++i) {
398 for (size_t j = 0; j < num_linked_sensors; ++j) {
399 estimated_value += linear_model_instance_->coefficients[i][j] *
400 linear_model_instance_->input_samples[input_level][j];
401 }
402 input_level--; // go to previous samples
403 input_level = (input_level >= 0) ? input_level : (prev_samples_order - 1);
404 }
405
406 // Update sample count
407 common_instance_->cur_sample_count++;
408
409 // add offset to estimated value if applicable
410 estimated_value += CalculateOffset(common_instance_->offset_thresholds,
411 common_instance_->offset_values, estimated_value);
412
413 std::vector<float> data = {estimated_value};
414 *output = data;
415 return kVtEstimatorOk;
416 }
417
TFliteEstimate(const std::vector<float> & thermistors,std::vector<float> * output)418 VtEstimatorStatus VirtualTempEstimator::TFliteEstimate(const std::vector<float> &thermistors,
419 std::vector<float> *output) {
420 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
421 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr during Estimate\n";
422 return kVtEstimatorInitFailed;
423 }
424
425 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
426
427 if (!common_instance_->is_initialized) {
428 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
429 return kVtEstimatorInitFailed;
430 }
431
432 size_t num_linked_sensors = common_instance_->num_linked_sensors;
433 if ((thermistors.size() != num_linked_sensors) || (output == nullptr)) {
434 LOG(ERROR) << "Invalid args for " << tflite_instance_->model_path
435 << " thermistors.size(): " << thermistors.size()
436 << " num_linked_sensors: " << num_linked_sensors << " output: " << output;
437 return kVtEstimatorInvalidArgs;
438 }
439
440 // log input data
441 std::string input_data_str = "model_input: [";
442 for (size_t i = 0; i < num_linked_sensors; ++i) {
443 input_data_str += ::android::base::StringPrintf("%0.2f ", thermistors[i]);
444 }
445 input_data_str += "]";
446 LOG(INFO) << input_data_str;
447
448 // check time gap between samples and ignore stale previous samples
449 if (std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
450 tflite_instance_->prev_sample_time) >=
451 tflite_instance_->max_sample_interval) {
452 LOG(INFO) << "Ignoring stale previous samples for " << common_instance_->sensor_name;
453 common_instance_->cur_sample_count = 0;
454 }
455
456 // copy input data into input tensors
457 size_t prev_samples_order = common_instance_->prev_samples_order;
458 size_t cur_sample_index = common_instance_->cur_sample_count % prev_samples_order;
459 size_t sample_start_index = cur_sample_index * num_linked_sensors;
460 for (size_t i = 0; i < num_linked_sensors; ++i) {
461 if (tflite_instance_->enable_input_validation) {
462 if (thermistors[i] < tflite_instance_->input_range[i].min_threshold ||
463 thermistors[i] > tflite_instance_->input_range[i].max_threshold) {
464 LOG(INFO) << "thermistors[" << i << "] value: " << thermistors[i]
465 << " not in range: " << tflite_instance_->input_range[i].min_threshold
466 << " <= val <= " << tflite_instance_->input_range[i].max_threshold;
467 common_instance_->cur_sample_count = 0;
468 return kVtEstimatorLowConfidence;
469 }
470 }
471 tflite_instance_->input_buffer[sample_start_index + i] = thermistors[i];
472 if (cur_sample_index == 0 && tflite_instance_->support_under_sampling) {
473 // fill previous samples if support under sampling
474 for (size_t j = 1; j < prev_samples_order; ++j) {
475 size_t copy_start_index = j * num_linked_sensors;
476 tflite_instance_->input_buffer[copy_start_index + i] = thermistors[i];
477 }
478 }
479 }
480
481 // Update sample count
482 common_instance_->cur_sample_count++;
483 tflite_instance_->prev_sample_time = boot_clock::now();
484 if ((common_instance_->cur_sample_count < prev_samples_order) &&
485 !(tflite_instance_->support_under_sampling)) {
486 return kVtEstimatorUnderSampling;
487 }
488
489 // prepare model input
490 float *model_input;
491 size_t input_buffer_size = tflite_instance_->input_buffer_size;
492 size_t output_buffer_size = tflite_instance_->output_buffer_size;
493 if (!common_instance_->use_prev_samples) {
494 model_input = tflite_instance_->input_buffer;
495 } else {
496 sample_start_index = ((cur_sample_index + 1) * num_linked_sensors) % input_buffer_size;
497 for (size_t i = 0; i < input_buffer_size; ++i) {
498 size_t input_index = (sample_start_index + i) % input_buffer_size;
499 tflite_instance_->scratch_buffer[i] = tflite_instance_->input_buffer[input_index];
500 }
501 model_input = tflite_instance_->scratch_buffer;
502 }
503
504 int ret = tflite_instance_->tflite_methods.invoke(
505 tflite_instance_->tflite_wrapper, model_input, input_buffer_size,
506 tflite_instance_->output_buffer, output_buffer_size);
507 if (ret) {
508 LOG(ERROR) << "Failed to Invoke for " << tflite_instance_->model_path << " (ret: " << ret
509 << ")";
510 return kVtEstimatorInvokeFailed;
511 }
512 tflite_instance_->last_update_time = boot_clock::now();
513
514 // prepare output
515 std::vector<float> data;
516 std::ostringstream model_out_log, predict_log;
517 data.reserve(output_buffer_size);
518 for (size_t i = 0; i < output_buffer_size; ++i) {
519 // add offset to predicted value
520 float predicted_value = tflite_instance_->output_buffer[i];
521 model_out_log << predicted_value << " ";
522 predicted_value += CalculateOffset(common_instance_->offset_thresholds,
523 common_instance_->offset_values, predicted_value);
524 predict_log << predicted_value << " ";
525 data.emplace_back(predicted_value);
526 }
527 LOG(INFO) << "model_output: [" << model_out_log.str() << "]";
528 LOG(INFO) << "predicted_value: [" << predict_log.str() << "]";
529 *output = data;
530
531 return kVtEstimatorOk;
532 }
533
Estimate(const std::vector<float> & thermistors,std::vector<float> * output)534 VtEstimatorStatus VirtualTempEstimator::Estimate(const std::vector<float> &thermistors,
535 std::vector<float> *output) {
536 if (type == kUseMLModel) {
537 return TFliteEstimate(thermistors, output);
538 } else if (type == kUseLinearModel) {
539 return LinearModelEstimate(thermistors, output);
540 }
541
542 LOG(ERROR) << "Unsupported estimationType [" << type << "]";
543 return kVtEstimatorUnSupported;
544 }
545
TFliteGetMaxPredictWindowMs(size_t * predict_window_ms)546 VtEstimatorStatus VirtualTempEstimator::TFliteGetMaxPredictWindowMs(size_t *predict_window_ms) {
547 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
548 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
549 return kVtEstimatorInitFailed;
550 }
551
552 if (!common_instance_->is_initialized) {
553 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
554 return kVtEstimatorInitFailed;
555 }
556
557 size_t window = tflite_instance_->predict_window_ms;
558 if (window == 0) {
559 return kVtEstimatorUnSupported;
560 }
561 *predict_window_ms = window;
562 return kVtEstimatorOk;
563 }
564
GetMaxPredictWindowMs(size_t * predict_window_ms)565 VtEstimatorStatus VirtualTempEstimator::GetMaxPredictWindowMs(size_t *predict_window_ms) {
566 if (type == kUseMLModel) {
567 return TFliteGetMaxPredictWindowMs(predict_window_ms);
568 }
569
570 LOG(ERROR) << "Unsupported estimationType [" << type << "]";
571 return kVtEstimatorUnSupported;
572 }
573
TFlitePredictAfterTimeMs(const size_t time_ms,float * output)574 VtEstimatorStatus VirtualTempEstimator::TFlitePredictAfterTimeMs(const size_t time_ms,
575 float *output) {
576 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
577 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
578 return kVtEstimatorInitFailed;
579 }
580
581 if (!common_instance_->is_initialized) {
582 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
583 return kVtEstimatorInitFailed;
584 }
585
586 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
587
588 size_t window = tflite_instance_->predict_window_ms;
589 auto sample_interval = tflite_instance_->sample_interval;
590 auto last_update_time = tflite_instance_->last_update_time;
591 auto request_time_ms = std::chrono::duration_cast<std::chrono::milliseconds>(boot_clock::now() -
592 last_update_time);
593 // check for under sampling
594 if ((common_instance_->cur_sample_count < common_instance_->prev_samples_order) &&
595 !(tflite_instance_->support_under_sampling)) {
596 LOG(INFO) << tflite_instance_->model_path
597 << " cannot provide prediction while under sampling";
598 return kVtEstimatorUnderSampling;
599 }
600
601 // calculate requested time since last update
602 request_time_ms = request_time_ms + std::chrono::milliseconds{time_ms};
603 if (sample_interval.count() == 0 || window == 0 ||
604 window < static_cast<size_t>(request_time_ms.count())) {
605 LOG(INFO) << tflite_instance_->model_path << " cannot predict temperature after ("
606 << time_ms << " + " << request_time_ms.count() - time_ms
607 << ") ms since last update with sample interval [" << sample_interval.count()
608 << "] ms and predict window [" << window << "] ms";
609 return kVtEstimatorUnSupported;
610 }
611
612 size_t request_step = request_time_ms / sample_interval;
613 size_t output_label_count = tflite_instance_->output_label_count;
614 float *output_buffer = tflite_instance_->output_buffer;
615 float prediction;
616 if (request_step == output_label_count - 1) {
617 // request prediction is on the right boundary of the window
618 prediction = output_buffer[output_label_count - 1];
619 } else {
620 float left = output_buffer[request_step], right = output_buffer[request_step + 1];
621 prediction = left;
622 if (left != right) {
623 prediction += (request_time_ms - sample_interval * request_step) * (right - left) /
624 sample_interval;
625 }
626 }
627
628 *output = prediction;
629
630 return kVtEstimatorOk;
631 }
632
PredictAfterTimeMs(const size_t time_ms,float * output)633 VtEstimatorStatus VirtualTempEstimator::PredictAfterTimeMs(const size_t time_ms, float *output) {
634 if (type == kUseMLModel) {
635 return TFlitePredictAfterTimeMs(time_ms, output);
636 }
637
638 LOG(ERROR) << "PredictAfterTimeMs not supported for type [" << type << "]";
639 return kVtEstimatorUnSupported;
640 }
641
TFliteGetAllPredictions(std::vector<float> * output)642 VtEstimatorStatus VirtualTempEstimator::TFliteGetAllPredictions(std::vector<float> *output) {
643 if (tflite_instance_ == nullptr || common_instance_ == nullptr) {
644 LOG(ERROR) << "tflite_instance_ or common_instance_ is nullptr for predict window\n";
645 return kVtEstimatorInitFailed;
646 }
647
648 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
649
650 if (!common_instance_->is_initialized) {
651 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
652 return kVtEstimatorInitFailed;
653 }
654
655 if (output == nullptr) {
656 LOG(ERROR) << "output is nullptr";
657 return kVtEstimatorInvalidArgs;
658 }
659
660 std::vector<float> tflite_output;
661 size_t output_buffer_size = tflite_instance_->output_buffer_size;
662 tflite_output.reserve(output_buffer_size);
663 for (size_t i = 0; i < output_buffer_size; ++i) {
664 tflite_output.emplace_back(tflite_instance_->output_buffer[i]);
665 }
666 *output = tflite_output;
667
668 return kVtEstimatorOk;
669 }
670
GetAllPredictions(std::vector<float> * output)671 VtEstimatorStatus VirtualTempEstimator::GetAllPredictions(std::vector<float> *output) {
672 if (type == kUseMLModel) {
673 return TFliteGetAllPredictions(output);
674 }
675
676 LOG(INFO) << "GetAllPredicts not supported by estimationType [" << type << "]";
677 return kVtEstimatorUnSupported;
678 }
679
TFLiteDumpStatus(std::string_view sensor_name,std::ostringstream * dump_buf)680 VtEstimatorStatus VirtualTempEstimator::TFLiteDumpStatus(std::string_view sensor_name,
681 std::ostringstream *dump_buf) {
682 if (dump_buf == nullptr) {
683 LOG(ERROR) << "dump_buf is nullptr for " << sensor_name;
684 return kVtEstimatorInvalidArgs;
685 }
686
687 if (!common_instance_->is_initialized) {
688 LOG(ERROR) << "tflite_instance_ not initialized for " << tflite_instance_->model_path;
689 return kVtEstimatorInitFailed;
690 }
691
692 std::unique_lock<std::mutex> lock(tflite_instance_->tflite_methods.mutex);
693
694 *dump_buf << " Sensor Name: " << sensor_name << std::endl;
695 *dump_buf << " Current Values: ";
696 size_t output_buffer_size = tflite_instance_->output_buffer_size;
697 for (size_t i = 0; i < output_buffer_size; ++i) {
698 // add offset to predicted value
699 float predicted_value = tflite_instance_->output_buffer[i];
700 predicted_value += CalculateOffset(common_instance_->offset_thresholds,
701 common_instance_->offset_values, predicted_value);
702 *dump_buf << predicted_value << ", ";
703 }
704 *dump_buf << std::endl;
705
706 *dump_buf << " Model Path: \"" << tflite_instance_->model_path << "\"" << std::endl;
707
708 return kVtEstimatorOk;
709 }
710
DumpStatus(std::string_view sensor_name,std::ostringstream * dump_buff)711 VtEstimatorStatus VirtualTempEstimator::DumpStatus(std::string_view sensor_name,
712 std::ostringstream *dump_buff) {
713 if (type == kUseMLModel) {
714 return TFLiteDumpStatus(sensor_name, dump_buff);
715 }
716
717 LOG(INFO) << "DumpStatus not supported by estimationType [" << type << "]";
718 return kVtEstimatorUnSupported;
719 }
720
Initialize(const VtEstimationInitData & data)721 VtEstimatorStatus VirtualTempEstimator::Initialize(const VtEstimationInitData &data) {
722 LOG(INFO) << "Initialize VirtualTempEstimator for " << type;
723
724 if (type == kUseMLModel) {
725 return TFliteInitialize(data.ml_model_init_data);
726 } else if (type == kUseLinearModel) {
727 return LinearModelInitialize(data.linear_model_init_data);
728 }
729
730 LOG(ERROR) << "Unsupported estimationType [" << type << "]";
731 return kVtEstimatorUnSupported;
732 }
733
ParseInputConfig(const Json::Value & input_config)734 bool VirtualTempEstimator::ParseInputConfig(const Json::Value &input_config) {
735 if (!input_config["ModelConfig"].empty()) {
736 if (!input_config["ModelConfig"]["sample_interval_ms"].empty()) {
737 // read input sample interval
738 int sample_interval_ms = input_config["ModelConfig"]["sample_interval_ms"].asInt();
739 if (sample_interval_ms <= 0) {
740 LOG(ERROR) << "Invalid sample_interval_ms: " << sample_interval_ms;
741 return false;
742 }
743
744 tflite_instance_->sample_interval = std::chrono::milliseconds{sample_interval_ms};
745 LOG(INFO) << "Parsed tflite model input sample_interval: " << sample_interval_ms
746 << " for " << common_instance_->sensor_name;
747
748 // determine predict window
749 tflite_instance_->predict_window_ms =
750 sample_interval_ms * (tflite_instance_->output_label_count - 1);
751 LOG(INFO) << "Max prediction window size: " << tflite_instance_->predict_window_ms
752 << " ms for " << common_instance_->sensor_name;
753 }
754
755 if (!input_config["ModelConfig"]["max_sample_interval_ms"].empty()) {
756 // read input max sample interval
757 int max_sample_interval_ms =
758 input_config["ModelConfig"]["max_sample_interval_ms"].asInt();
759 if (max_sample_interval_ms <= 0) {
760 LOG(ERROR) << "Invalid max_sample_interval_ms " << max_sample_interval_ms;
761 return false;
762 }
763
764 tflite_instance_->max_sample_interval =
765 std::chrono::milliseconds{max_sample_interval_ms};
766 LOG(INFO) << "Parsed tflite model max_sample_interval: " << max_sample_interval_ms
767 << " for " << common_instance_->sensor_name;
768 }
769 }
770
771 if (!input_config["InputData"].empty()) {
772 Json::Value input_data = input_config["InputData"];
773 if (input_data.size() != common_instance_->num_linked_sensors) {
774 LOG(ERROR) << "Input ranges size: " << input_data.size()
775 << " does not match num_linked_sensors: "
776 << common_instance_->num_linked_sensors;
777 return false;
778 }
779
780 LOG(INFO) << "Start to parse tflite model input config for "
781 << common_instance_->num_linked_sensors;
782 tflite_instance_->input_range.assign(input_data.size(), InputRangeInfo());
783 for (Json::Value::ArrayIndex i = 0; i < input_data.size(); ++i) {
784 const std::string &name = input_data[i]["Name"].asString();
785 LOG(INFO) << "Sensor[" << i << "] Name: " << name;
786 if (!getInputRangeInfoFromJsonValues(input_data[i]["Range"],
787 &tflite_instance_->input_range[i])) {
788 LOG(ERROR) << "Failed to parse tflite model temp range for sensor: [" << name
789 << "]";
790 return false;
791 }
792 }
793 }
794
795 return true;
796 }
797
GetInputConfig(Json::Value * config)798 bool VirtualTempEstimator::GetInputConfig(Json::Value *config) {
799 int config_size = 0;
800 int ret = tflite_instance_->tflite_methods.get_input_config_size(
801 tflite_instance_->tflite_wrapper, &config_size);
802 if (ret || config_size <= 0) {
803 LOG(ERROR) << "Failed to get tflite input config size (ret: " << ret
804 << ") with size: " << config_size;
805 return false;
806 }
807
808 LOG(INFO) << "Model input config_size: " << config_size << " for "
809 << common_instance_->sensor_name;
810
811 char *config_str = new char[config_size];
812 ret = tflite_instance_->tflite_methods.get_input_config(tflite_instance_->tflite_wrapper,
813 config_str, config_size);
814 if (ret) {
815 LOG(ERROR) << "Failed to get tflite input config (ret: " << ret << ")";
816 delete[] config_str;
817 return false;
818 }
819
820 Json::CharReaderBuilder builder;
821 std::unique_ptr<Json::CharReader> reader(builder.newCharReader());
822 std::string errorMessage;
823
824 bool success = true;
825 if (!reader->parse(config_str, config_str + config_size, config, &errorMessage)) {
826 LOG(ERROR) << "Failed to parse tflite JSON input config: " << errorMessage;
827 success = false;
828 }
829 delete[] config_str;
830 return success;
831 }
832
833 } // namespace vtestimator
834 } // namespace thermal
835