/*
 * Copyright 2020 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package androidx.compose.animation.core

import android.os.Build
import androidx.test.filters.SdkSuppress
import junit.framework.TestCase.assertTrue
import kotlin.math.abs
import kotlin.math.max
import kotlin.math.sqrt
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.Parameterized

/**
 * Checks the spring settling-time estimators against a numerical [SpringSimulation].
 *
 * The class is parameterized over a mass [m] and a spring constant [k] (see [data]). Each
 * parameter pair is expanded by [generateTestCases] into many combinations of damping
 * coefficient, initial velocity and initial displacement.
 */
@RunWith(Parameterized::class)
class SpringEstimationTest(private val m: Double, private val k: Double) {
    companion object {
        /** Two frames at 60 fps, in whole milliseconds. */
        private const val TwoFrames60fpsMillis = 33L

        /**
         * Parameter grid: masses 1², 3², ..., 99² crossed with spring constants
         * 1², 21², ..., 981², plus explicitly listed edge cases.
         */
        @JvmStatic
        @Parameterized.Parameters
        fun data(): List<Array<out Any>> {
            return mutableListOf<Array<out Any>>().apply {
                (1..100 step 2)
                    .map { it * it }
                    .forEach { m ->
                        (1..1000 step 20)
                            .map { it * it }
                            .forEach { k -> add(arrayOf(m.toDouble(), k.toDouble())) }
                    }
                // Edge case: a very heavy mass on a very weak spring.
                add(arrayOf(10_000.0, 1.0))
            }
        }
    }

    /**
     * Runs every generated case for the current (m, k) pair and fails if any case fails.
     * All failing cases (and their reasons, when present) are printed before the assertion.
     */
    @Test
    @SdkSuppress(minSdkVersion = Build.VERSION_CODES.N) // parallelStream() requires API level 24
    fun runTestCases() {
        val failedTestCaseResults = mutableListOf<TestCaseResult>()

        val testCases = generateTestCases()
        println("Generated ${testCases.size} test cases")

        // Cases are evaluated in parallel; the shared failure list is guarded by its own monitor.
        testCases.parallelStream().forEach {
            val res = runTestCase(it)
            if (!res.pass) {
                synchronized(failedTestCaseResults) { failedTestCaseResults.add(res) }
            }
        }

        if (failedTestCaseResults.isNotEmpty()) {
            println("Failed ${failedTestCaseResults.size} test cases")
            failedTestCaseResults.forEach {
                println(it.testCase)
                if (it.reason.isNotBlank()) {
                    println("\treason:${it.reason}")
                }
            }
        }

        assertTrue(
            "Failed ${failedTestCaseResults.size} of ${testCases.size} test cases (m=$m, k=$k)",
            failedTestCaseResults.isEmpty()
        )
    }

    /**
     * Evaluates a single [testCase].
     *
     * Both estimator overloads (mass/springConstant/dampingCoefficient and
     * stiffness/dampingRatio) must agree within 1 ms and must return a finite end time. The
     * estimated end time is then checked against a [SpringSimulation] with final position 0,
     * sampled at the end time and two 60 fps frames either side of it.
     */
    private fun runTestCase(testCase: TestCase): TestCaseResult {
        val springSimulation = SpringSimulation(0.0f)
        springSimulation.dampingRatio = testCase.dampingRatio.toFloat()
        springSimulation.stiffness = testCase.stiffness.toFloat()

        val endTime =
            estimateAnimationDurationMillis(
                mass = testCase.mass,
                springConstant = testCase.springConstant,
                dampingCoefficient = testCase.dampingCoefficient,
                initialDisplacement = testCase.initialDisplacement,
                initialVelocity = testCase.initialVelocity,
                delta = 1.0
            )

        val alternateEndTime =
            estimateAnimationDurationMillis(
                stiffness = testCase.stiffness,
                dampingRatio = testCase.dampingRatio,
                initialDisplacement = testCase.initialDisplacement,
                initialVelocity = testCase.initialVelocity,
                delta = 1.0
            )

        // The two estimator overloads must agree within 1ms.
        if (abs(endTime - alternateEndTime) > 1) {
            return TestCaseResult(
                pass = false,
                testCase = testCase,
                reason =
                    "stiffness/dampingRatio implementation discrepancy " +
                        "($endTime ms vs $alternateEndTime ms)"
            )
        }

        if (endTime == Long.MAX_VALUE) return TestCaseResult(false, testCase, "End time +infinity")

        // |x(t)| from the simulation, always restarted from the test case's initial state.
        fun magnitudeAt(timeMillis: Long) =
            abs(
                springSimulation
                    .updateValues(
                        lastDisplacement = testCase.initialDisplacement.toFloat(),
                        lastVelocity = testCase.initialVelocity.toFloat(),
                        timeElapsed = timeMillis
                    )
                    .value
            )

        val magnitudeTwoFramesAfter = magnitudeAt(endTime + TwoFrames60fpsMillis)
        val magnitudeTwoFramesBefore = magnitudeAt(max(endTime - TwoFrames60fpsMillis, 0L))
        val magnitudeAtEndTime = magnitudeAt(endTime)

        // Thresholds below are relative to delta = 1.0, with a small tolerance.
        val pass =
            if (testCase.dampingRatio >= 1.0) {
                // Primary criterion (over/critically damped): two frames before the settling
                // time |x(t)| is still at or above the threshold, and two frames after it has
                // dropped to (or within tolerance of) the threshold.
                val crossesThreshold =
                    magnitudeTwoFramesBefore >= 0.999 && magnitudeTwoFramesAfter <= 1.001
                // Secondary criterion: the settling time can land on the single peak of |x(t)|
                // (e.g. after a zero crossing). Then both neighbouring samples are smaller than
                // the sample at the settling time, which must itself reach the threshold.
                val peaksAtEndTime =
                    magnitudeAtEndTime >= 0.999 &&
                        magnitudeTwoFramesBefore < magnitudeAtEndTime &&
                        magnitudeTwoFramesAfter < magnitudeAtEndTime
                crossesThreshold || peaksAtEndTime
            } else {
                // Under-damped: x(t) oscillates, so the "before" sample is not meaningful and
                // the criteria above could fail erroneously. Only require that |x| is below
                // the threshold two frames after the settling time.
                magnitudeTwoFramesAfter < 1.00
            }

        return TestCaseResult(pass, testCase)
    }

    /**
     * Builds the cases for the current (m, k) pair.
     *
     * The first group is a coarse grid of damping coefficients, initial velocities and initial
     * displacements. The second group pins the damping coefficient to the critical value.
     */
    private fun generateTestCases(): List<TestCase> {
        val testCases = mutableListOf<TestCase>()

        // General cases broadly covering over- and under-damped springs.
        for (c in 1..10_000 step 500) {
            for (v0 in -200_000..200_000 step 100_000) {
                for (p0 in -10_000..10_000 step 100) {
                    // Skip the trivial case of starting at rest at equilibrium.
                    if (v0 == 0 && p0 == 0) continue
                    testCases.add(
                        TestCase(
                            mass = m,
                            springConstant = k,
                            dampingCoefficient = c.toDouble(),
                            initialVelocity = v0.toDouble(),
                            initialDisplacement = p0.toDouble()
                        )
                    )
                }
            }
        }

        // Critically damped cases: fix c = 2 * sqrt(k * m) (i.e. c^2 = 4mk), the same
        // expression TestCase.dampingRatio divides by, so the damping ratio is 1.0.
        val c = 2.0 * sqrt(k * m)
        for (v0 in -200_000..200_000 step 10_000) {
            for (p0 in -10_000..10_000 step 100) {
                if (v0 == 0 && p0 == 0) continue
                testCases.add(
                    TestCase(
                        mass = m,
                        springConstant = k,
                        dampingCoefficient = c,
                        initialVelocity = v0.toDouble(),
                        initialDisplacement = p0.toDouble()
                    )
                )
            }
        }
        return testCases
    }

    /** One spring configuration plus its initial conditions, in mass/spring-constant form. */
    private data class TestCase(
        val mass: Double,
        val springConstant: Double,
        val dampingCoefficient: Double,
        val initialVelocity: Double,
        val initialDisplacement: Double
    ) {
        /** c / (2 * sqrt(k * m)); >= 1.0 is treated as over/critically damped by the test. */
        val dampingRatio: Double
            get() {
                val criticalDamping = 2.0 * sqrt(springConstant * mass)
                return dampingCoefficient / criticalDamping
            }

        /** k / m, the value fed to SpringSimulation.stiffness and the stiffness-based estimator. */
        val stiffness: Double
            get() {
                return springConstant / mass
            }
    }

    /** Outcome of one [TestCase]; [reason] is empty unless a specific failure cause is known. */
    private data class TestCaseResult(
        val pass: Boolean,
        val testCase: TestCase,
        val reason: String = ""
    )
}