1 /* 2 * Copyright (C) 2023 The Android Open Source Project 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #pragma once 18 19 #include "PosePredictorVerifier.h" 20 #include <memory> 21 #include <audio_utils/Statistics.h> 22 #include <media/PosePredictorType.h> 23 #include <media/Twist.h> 24 #include <media/VectorRecorder.h> 25 26 namespace android::media { 27 28 // Interface for generic pose predictors 29 class PredictorBase { 30 public: 31 virtual ~PredictorBase() = default; 32 virtual void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) = 0; 33 virtual Pose3f predict(int64_t atNs) const = 0; 34 virtual void reset() = 0; 35 virtual std::string name() const = 0; 36 virtual std::string toString(size_t index) const = 0; 37 }; 38 39 /** 40 * LastPredictor uses the last sample Pose for prediction 41 * 42 * This class is not thread-safe. 43 */ 44 class LastPredictor : public PredictorBase { 45 public: add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)46 void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override { 47 (void)atNs; 48 (void)twist; 49 mLastPose = pose; 50 } 51 predict(int64_t atNs)52 Pose3f predict(int64_t atNs) const override { 53 (void)atNs; 54 return mLastPose; 55 } 56 reset()57 void reset() override { 58 mLastPose = {}; 59 } 60 name()61 std::string name() const override { 62 return "LAST"; 63 } 64 toString(size_t index)65 std::string toString(size_t index) const override { 66 std::string s(index, ' '); 67 s.append("LastPredictor using last pose: ") 68 .append(mLastPose.toString()) 69 .append("\n"); 70 return s; 71 } 72 73 private: 74 Pose3f mLastPose; 75 }; 76 77 /** 78 * TwistPredictor uses the last sample Twist and Pose for prediction 79 * 80 * This class is not thread-safe. 81 */ 82 class TwistPredictor : public PredictorBase { 83 public: add(int64_t atNs,const Pose3f & pose,const Twist3f & twist)84 void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override { 85 mLastAtNs = atNs; 86 mLastPose = pose; 87 mLastTwist = twist; 88 } 89 predict(int64_t atNs)90 Pose3f predict(int64_t atNs) const override { 91 return mLastPose * integrate(mLastTwist, atNs - mLastAtNs); 92 } 93 reset()94 void reset() override { 95 mLastAtNs = {}; 96 mLastPose = {}; 97 mLastTwist = {}; 98 } 99 name()100 std::string name() const override { 101 return "TWIST"; 102 } 103 toString(size_t index)104 std::string toString(size_t index) const override { 105 std::string s(index, ' '); 106 s.append("TwistPredictor using last pose: ") 107 .append(mLastPose.toString()) 108 .append(" last twist: ") 109 .append(mLastTwist.toString()) 110 .append("\n"); 111 return s; 112 } 113 114 private: 115 int64_t mLastAtNs{}; 116 Pose3f mLastPose; 117 Twist3f mLastTwist; 118 }; 119 120 121 /** 122 * LeastSquaresPredictor uses the Pose history for prediction. 123 * 124 * A exponential weighted least squares is used. 125 * 126 * This class is not thread-safe. 127 */ 128 class LeastSquaresPredictor : public PredictorBase { 129 public: 130 // alpha is the exponential decay. 131 LeastSquaresPredictor(double alpha = kDefaultAlphaEstimator) mAlpha(alpha)132 : mAlpha(alpha) 133 , mRw(alpha) 134 , mRx(alpha) 135 , mRy(alpha) 136 , mRz(alpha) 137 {} 138 139 void add(int64_t atNs, const Pose3f& pose, const Twist3f& twist) override; 140 Pose3f predict(int64_t atNs) const override; 141 void reset() override; name()142 std::string name() const override { 143 return "LEAST_SQUARES(" + std::to_string(mAlpha) + ")"; 144 } 145 std::string toString(size_t index) const override; 146 147 private: 148 const double mAlpha; 149 int64_t mLastAtNs{}; 150 Pose3f mLastPose; 151 static constexpr double kDefaultAlphaEstimator = 0.2; 152 static constexpr size_t kMinimumSamplesForPrediction = 4; 153 audio_utils::LinearLeastSquaresFit<double> mRw; 154 audio_utils::LinearLeastSquaresFit<double> mRx; 155 audio_utils::LinearLeastSquaresFit<double> mRy; 156 audio_utils::LinearLeastSquaresFit<double> mRz; 157 }; 158 159 /* 160 * PosePredictor predicts the pose given sensor input at a time in the future. 161 * 162 * This class is not thread safe. 163 */ 164 class PosePredictor { 165 public: 166 PosePredictor(); 167 168 Pose3f predict(int64_t timestampNs, const Pose3f& pose, const Twist3f& twist, 169 float predictionDurationNs); 170 171 void setPosePredictorType(PosePredictorType type); 172 173 // convert predictions to a printable string 174 std::string toString(size_t index) const; 175 176 private: 177 static constexpr int64_t kMaximumSampleIntervalBeforeResetNs = 178 300'000'000; 179 180 // Predictors 181 const std::vector<std::shared_ptr<PredictorBase>> mPredictors; 182 183 // Verifiers, create one for an array of future lookaheads for comparison. 184 const std::vector<int> mLookaheadMs; 185 186 std::vector<PosePredictorVerifier> mVerifiers; 187 188 const std::vector<size_t> mDelimiterIdx; 189 190 // Recorders 191 media::VectorRecorder mPredictionRecorder{ 192 std::size(mVerifiers) /* vectorSize */, std::chrono::seconds(1), 10 /* maxLogLine */, 193 mDelimiterIdx}; 194 media::VectorRecorder mPredictionDurableRecorder{ 195 std::size(mVerifiers) /* vectorSize */, std::chrono::minutes(1), 10 /* maxLogLine */, 196 mDelimiterIdx}; 197 198 // Status 199 200 // SetType is the externally set predictor type. It may include AUTO. 201 PosePredictorType mSetType = PosePredictorType::LEAST_SQUARES; 202 203 // CurrentType is the actual predictor type used by this class. 204 // It does not include AUTO because that metatype means the class 205 // chooses the best predictor type based on sensor statistics. 206 PosePredictorType mCurrentType = PosePredictorType::LEAST_SQUARES; 207 208 int64_t mResets{}; 209 int64_t mLastTimestampNs{}; 210 211 // Returns current predictor 212 std::shared_ptr<PredictorBase> getCurrentPredictor() const; 213 }; 214 215 } // namespace android::media 216