1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #pragma once 17 18 #include <json/value.h> 19 20 #include <sstream> 21 #include <vector> 22 23 #include "virtualtemp_estimator_data.h" 24 25 namespace thermal { 26 namespace vtestimator { 27 28 enum VtEstimatorStatus { 29 kVtEstimatorOk = 0, 30 kVtEstimatorInvalidArgs = 1, 31 kVtEstimatorInitFailed = 2, 32 kVtEstimatorInvokeFailed = 3, 33 kVtEstimatorUnSupported = 4, 34 kVtEstimatorLowConfidence = 5, 35 kVtEstimatorUnderSampling = 6, 36 }; 37 38 enum VtEstimationType { kUseMLModel = 0, kUseLinearModel = 1, kInvalidEstimationType = 2 }; 39 40 struct MLModelInitData { 41 std::string model_path; 42 bool use_prev_samples; 43 size_t prev_samples_order; 44 size_t output_label_count; 45 size_t num_hot_spots; 46 bool enable_input_validation; 47 std::vector<float> offset_thresholds; 48 std::vector<float> offset_values; 49 bool support_under_sampling; 50 }; 51 52 struct LinearModelInitData { 53 bool use_prev_samples; 54 size_t prev_samples_order; 55 std::vector<float> coefficients; 56 std::vector<float> offset_thresholds; 57 std::vector<float> offset_values; 58 }; 59 60 union VtEstimationInitData { VtEstimationInitData(VtEstimationType type)61 VtEstimationInitData(VtEstimationType type) { 62 if (type == kUseMLModel) { 63 ml_model_init_data.model_path = ""; 64 ml_model_init_data.use_prev_samples = false; 65 ml_model_init_data.prev_samples_order = 1; 66 ml_model_init_data.output_label_count = 1; 67 ml_model_init_data.num_hot_spots = 1; 68 ml_model_init_data.enable_input_validation = false; 69 ml_model_init_data.support_under_sampling = false; 70 } else if (type == kUseLinearModel) { 71 linear_model_init_data.use_prev_samples = false; 72 linear_model_init_data.prev_samples_order = 1; 73 } 74 } ~VtEstimationInitData()75 ~VtEstimationInitData() {} 76 77 MLModelInitData ml_model_init_data; 78 LinearModelInitData linear_model_init_data; 79 }; 80 81 // Class to estimate virtual temperature 82 class VirtualTempEstimator { 83 public: 84 // Implicit copy-move headers. 85 VirtualTempEstimator(const VirtualTempEstimator &) = delete; 86 VirtualTempEstimator(VirtualTempEstimator &&) = default; 87 VirtualTempEstimator &operator=(const VirtualTempEstimator &) = delete; 88 VirtualTempEstimator &operator=(VirtualTempEstimator &&) = default; 89 90 VirtualTempEstimator(std::string_view sensor_name, VtEstimationType type, 91 size_t num_linked_sensors); 92 ~VirtualTempEstimator(); 93 94 // Initializes the estimator based on init_data 95 VtEstimatorStatus Initialize(const VtEstimationInitData &init_data); 96 97 // Performs the prediction and returns estimated value in output 98 VtEstimatorStatus Estimate(const std::vector<float> &thermistors, std::vector<float> *output); 99 100 // Dump estimator status 101 VtEstimatorStatus DumpStatus(std::string_view sensor_name, std::ostringstream *dump_buf); 102 // Get predict window width in milliseconds 103 VtEstimatorStatus GetMaxPredictWindowMs(size_t *predict_window_ms); 104 // Predict temperature after desired milliseconds 105 VtEstimatorStatus PredictAfterTimeMs(const size_t time_ms, float *output); 106 // Get entire output buffer of the estimator 107 VtEstimatorStatus GetAllPredictions(std::vector<float> *output); 108 109 // Adds traces to help debug 110 VtEstimatorStatus DumpTraces(); 111 112 private: 113 void LoadTFLiteWrapper(); 114 VtEstimationType type; 115 std::unique_ptr<VtEstimatorCommonData> common_instance_; 116 std::unique_ptr<VtEstimatorTFLiteData> tflite_instance_; 117 std::unique_ptr<VtEstimatorLinearModelData> linear_model_instance_; 118 119 VtEstimatorStatus LinearModelInitialize(LinearModelInitData data); 120 VtEstimatorStatus TFliteInitialize(MLModelInitData data); 121 122 VtEstimatorStatus LinearModelEstimate(const std::vector<float> &thermistors, 123 std::vector<float> *output); 124 VtEstimatorStatus TFliteEstimate(const std::vector<float> &thermistors, 125 std::vector<float> *output); 126 VtEstimatorStatus TFliteGetMaxPredictWindowMs(size_t *predict_window_ms); 127 VtEstimatorStatus TFlitePredictAfterTimeMs(const size_t time_ms, float *output); 128 VtEstimatorStatus TFliteGetAllPredictions(std::vector<float> *output); 129 130 VtEstimatorStatus TFLiteDumpStatus(std::string_view sensor_name, std::ostringstream *dump_buf); 131 bool GetInputConfig(Json::Value *config); 132 bool ParseInputConfig(const Json::Value &config); 133 }; 134 135 } // namespace vtestimator 136 } // namespace thermal 137