1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <android-base/chrono_utils.h> 17 18 #include <cstddef> 19 #include <mutex> 20 #include <string> 21 22 #pragma once 23 24 namespace thermal { 25 namespace vtestimator { 26 27 using android::base::boot_clock; 28 29 // Current version only supports single input/output tensors 30 constexpr int kNumInputTensors = 1; 31 constexpr int kNumOutputTensors = 1; 32 33 typedef void *(*tflitewrapper_create)(int num_input_tensors, int num_output_tensors); 34 typedef bool (*tflitewrapper_init)(void *handle, const char *model_path); 35 typedef bool (*tflitewrapper_invoke)(void *handle, float *input_samples, int num_input_samples, 36 float *output_samples, int num_output_samples); 37 typedef void (*tflitewrapper_destroy)(void *handle); 38 typedef bool (*tflitewrapper_get_input_config_size)(void *handle, int *config_size); 39 typedef bool (*tflitewrapper_get_input_config)(void *handle, char *config_buffer, 40 int config_buffer_size); 41 42 struct TFLiteWrapperMethods { 43 tflitewrapper_create create; 44 tflitewrapper_init init; 45 tflitewrapper_invoke invoke; 46 tflitewrapper_destroy destroy; 47 tflitewrapper_get_input_config_size get_input_config_size; 48 tflitewrapper_get_input_config get_input_config; 49 mutable std::mutex mutex; 50 }; 51 52 struct InputRangeInfo { 53 float max_threshold = std::numeric_limits<float>::max(); 54 float min_threshold = std::numeric_limits<float>::min(); 55 }; 56 57 struct VtEstimatorCommonData { VtEstimatorCommonDataVtEstimatorCommonData58 VtEstimatorCommonData(std::string_view name, size_t num_input_sensors) { 59 sensor_name = name; 60 num_linked_sensors = num_input_sensors; 61 prev_samples_order = 1; 62 is_initialized = false; 63 use_prev_samples = false; 64 cur_sample_count = 0; 65 } 66 std::string sensor_name; 67 68 std::vector<float> offset_thresholds; 69 std::vector<float> offset_values; 70 71 size_t num_linked_sensors; 72 size_t prev_samples_order; 73 size_t cur_sample_count; 74 bool use_prev_samples; 75 bool is_initialized; 76 }; 77 78 struct VtEstimatorTFLiteData { VtEstimatorTFLiteDataVtEstimatorTFLiteData79 VtEstimatorTFLiteData() { 80 scratch_buffer = nullptr; 81 input_buffer = nullptr; 82 input_buffer_size = 0; 83 output_label_count = 1; 84 num_hot_spots = 1; 85 output_buffer = nullptr; 86 output_buffer_size = 1; 87 support_under_sampling = false; 88 sample_interval = std::chrono::milliseconds{0}; 89 max_sample_interval = std::chrono::milliseconds{std::numeric_limits<int>::max()}; 90 predict_window_ms = 0; 91 last_update_time = boot_clock::time_point::min(); 92 prev_sample_time = boot_clock::time_point::min(); 93 enable_input_validation = false; 94 95 tflite_wrapper = nullptr; 96 tflite_methods.create = nullptr; 97 tflite_methods.init = nullptr; 98 tflite_methods.get_input_config_size = nullptr; 99 tflite_methods.get_input_config = nullptr; 100 tflite_methods.invoke = nullptr; 101 tflite_methods.destroy = nullptr; 102 } 103 104 void *tflite_wrapper; 105 float *scratch_buffer; 106 float *input_buffer; 107 size_t input_buffer_size; 108 size_t num_hot_spots; 109 size_t output_label_count; 110 float *output_buffer; 111 size_t output_buffer_size; 112 std::string model_path; 113 TFLiteWrapperMethods tflite_methods; 114 std::vector<InputRangeInfo> input_range; 115 bool support_under_sampling; 116 std::chrono::milliseconds sample_interval{}; 117 std::chrono::milliseconds max_sample_interval{}; 118 size_t predict_window_ms; 119 boot_clock::time_point last_update_time; 120 boot_clock::time_point prev_sample_time; 121 bool enable_input_validation; 122 ~VtEstimatorTFLiteDataVtEstimatorTFLiteData123 ~VtEstimatorTFLiteData() { 124 if (tflite_wrapper && tflite_methods.destroy) { 125 tflite_methods.destroy(tflite_wrapper); 126 } 127 128 if (scratch_buffer) { 129 delete scratch_buffer; 130 } 131 132 if (input_buffer) { 133 delete input_buffer; 134 } 135 136 if (output_buffer) { 137 delete output_buffer; 138 } 139 } 140 }; 141 142 struct VtEstimatorLinearModelData { VtEstimatorLinearModelDataVtEstimatorLinearModelData143 VtEstimatorLinearModelData() {} 144 ~VtEstimatorLinearModelDataVtEstimatorLinearModelData145 ~VtEstimatorLinearModelData() {} 146 147 std::vector<std::vector<float>> input_samples; 148 std::vector<std::vector<float>> coefficients; 149 mutable std::mutex mutex; 150 }; 151 152 } // namespace vtestimator 153 } // namespace thermal 154