1 /*
 * Copyright 2020 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.compose.animation.core
18 
19 import android.os.Build
20 import androidx.test.filters.SdkSuppress
21 import junit.framework.TestCase.assertTrue
22 import kotlin.math.abs
23 import kotlin.math.max
24 import kotlin.math.sqrt
25 import org.junit.Test
26 import org.junit.runner.RunWith
27 import org.junit.runners.Parameterized
28 
@RunWith(Parameterized::class)
class SpringEstimationTest(private val m: Double, private val k: Double) {
    companion object {
        /** Two frames at 60fps; the window checked before and after each estimated end time. */
        private const val TwoFrames60fpsMillis = 33L

        /**
         * Parameter grid of (mass, springConstant) pairs: the squares of 1, 3, ..., 99 for the
         * mass crossed with the squares of 1, 21, ..., 981 for the spring constant, plus a
         * hand-picked edge case (mass = 10_000, springConstant = 1).
         */
        @JvmStatic
        @Parameterized.Parameters
        fun data(): List<Array<out Any>> {
            return mutableListOf<Array<out Any>>().apply {
                (1..100 step 2)
                    .map { it * it }
                    .forEach { m ->
                        (1..1000 step 20)
                            .map { it * it }
                            .forEach { k -> add(arrayOf(m.toDouble(), k.toDouble())) }
                    }
                // Additional edge cases to test for
                add(arrayOf(10_000.0, 1.0))
            }
        }
    }

    /**
     * Runs every generated case for this (m, k) pair in parallel, prints each failing case with
     * its reason, and fails with a summary message if any case did not pass.
     */
    @Test
    @SdkSuppress(minSdkVersion = Build.VERSION_CODES.N) // parallelStream() requires API level 24
    fun runTestCases() {
        val failedTestCaseResults = mutableListOf<TestCaseResult>()

        val testCases = generateTestCases()
        println("Generated ${testCases.size} test cases")

        testCases.parallelStream().forEach {
            val res = runTestCase(it)
            if (!res.pass) {
                // The stream body runs on several threads; guard the shared, unsynchronized list.
                synchronized(failedTestCaseResults) { failedTestCaseResults.add(res) }
            }
        }

        if (failedTestCaseResults.isNotEmpty()) {
            println("Failed ${failedTestCaseResults.size} test cases")
            failedTestCaseResults.forEach {
                println(it.testCase)
                if (it.reason.isNotBlank()) {
                    println("\treason:${it.reason}")
                }
            }
        }

        assertTrue(
            "${failedTestCaseResults.size} of ${testCases.size} test cases failed (m=$m, k=$k)",
            failedTestCaseResults.isEmpty()
        )
    }

    /**
     * Checks a single [testCase]:
     * 1. the (mass, springConstant, dampingCoefficient) and (stiffness, dampingRatio) variants of
     *    `estimateAnimationDurationMillis` agree within 1ms;
     * 2. the estimate is finite (not `Long.MAX_VALUE`);
     * 3. the displacement reported by [SpringSimulation] around the estimated end time satisfies
     *    the damping-dependent criteria below, using thresholds near the `delta = 1.0` passed to
     *    the estimator.
     *
     * Failing results always carry a non-blank reason so the failure log is self-explanatory.
     */
    private fun runTestCase(testCase: TestCase): TestCaseResult {
        val springSimulation = SpringSimulation(0.0f)
        springSimulation.dampingRatio = testCase.dampingRatio.toFloat()
        springSimulation.stiffness = testCase.stiffness.toFloat()

        val endTime =
            estimateAnimationDurationMillis(
                mass = testCase.mass,
                springConstant = testCase.springConstant,
                dampingCoefficient = testCase.dampingCoefficient,
                initialDisplacement = testCase.initialDisplacement,
                initialVelocity = testCase.initialVelocity,
                delta = 1.0
            )

        val alternateEndTime =
            estimateAnimationDurationMillis(
                stiffness = testCase.stiffness,
                dampingRatio = testCase.dampingRatio,
                initialDisplacement = testCase.initialDisplacement,
                initialVelocity = testCase.initialVelocity,
                delta = 1.0
            )

        // Test that the alternate implementation gives the same answer within 1ms.
        if (abs(endTime - alternateEndTime) > 1) {
            return TestCaseResult(
                pass = false,
                testCase = testCase,
                reason =
                    "stiffness/dampingRatio implementation discrepancy " +
                        "(endTime=$endTime, alternateEndTime=$alternateEndTime)"
            )
        }

        if (endTime == Long.MAX_VALUE) return TestCaseResult(false, testCase, "End time +infinity")

        // |x(t)| reported by the simulation `timeMillis` after starting from the test case's
        // initial displacement and velocity.
        fun displacementMagnitudeAt(timeMillis: Long): Float =
            abs(
                springSimulation
                    .updateValues(
                        lastDisplacement = testCase.initialDisplacement.toFloat(),
                        lastVelocity = testCase.initialVelocity.toFloat(),
                        timeElapsed = timeMillis
                    )
                    .value
            )

        val twoFramesAfter = displacementMagnitudeAt(endTime + TwoFrames60fpsMillis)
        val twoFramesBefore = displacementMagnitudeAt(max(endTime - TwoFrames60fpsMillis, 0L))
        val atEndTime = displacementMagnitudeAt(endTime)

        val pass =
            if (testCase.dampingRatio >= 1.0) {
                // Primary criterion (over/critically damped): two frames before the estimated
                // settling time |x(t)| is still at or above the threshold, and two frames after
                // it is at or below the threshold.
                //
                // Secondary criterion: when the settling time is close to the inflection point of
                // over/critically-damped motion, the samples two frames before and after can both
                // fall below the threshold. Accept the estimate if |x(t)| at the estimate is at or
                // above the threshold and exceeds both neighbouring samples.
                (twoFramesBefore >= 0.999 && twoFramesAfter <= 1.001) ||
                    (atEndTime >= 0.999 &&
                        twoFramesBefore < atEndTime &&
                        twoFramesAfter < atEndTime)
            } else {
                // In the under-damped scenario x(t) oscillates heavily, so the criteria above may
                // fail erroneously; only require |x(t)| to be below the threshold two frames after.
                twoFramesAfter < 1.00
            }

        return if (pass) {
            TestCaseResult(pass = true, testCase = testCase)
        } else {
            TestCaseResult(
                pass = false,
                testCase = testCase,
                reason =
                    "settling criteria not met (endTime=$endTime, |x| before=$twoFramesBefore, " +
                        "at=$atEndTime, after=$twoFramesAfter)"
            )
        }
    }

    /**
     * Builds the cases for this (m, k) pair. The trivial case (zero velocity and zero
     * displacement) is always skipped. Generation is sequential, so the list needs no locking.
     * The cases come from two grids:
     * - a coarse grid over damping coefficient, initial velocity and initial displacement,
     *   broadly covering over- and under-damped springs;
     * - a finer initial-velocity grid with the damping coefficient fixed at critical damping.
     */
    private fun generateTestCases(): List<TestCase> {
        val testCases = mutableListOf<TestCase>()

        // General test cases that broadly cover the over- and under-damped scenarios.
        for (c in 1..10_000 step 500) {
            for (v0 in -200_000..200_000 step 100_000) {
                for (p0 in -10_000..10_000 step 100) {
                    if (v0 == 0 && p0 == 0) continue
                    testCases.add(
                        TestCase(
                            mass = m,
                            springConstant = k,
                            dampingCoefficient = c.toDouble(),
                            initialVelocity = v0.toDouble(),
                            initialDisplacement = p0.toDouble()
                        )
                    )
                }
            }
        }

        // Critically damped test cases: choosing c such that c^2 = 4mk makes
        // TestCase.dampingRatio evaluate to exactly 1.0 (same operands, same arithmetic).
        val criticalDampingCoefficient = 2.0 * sqrt(k * m)
        for (v0 in -200_000..200_000 step 10_000) {
            for (p0 in -10_000..10_000 step 100) {
                if (v0 == 0 && p0 == 0) continue
                testCases.add(
                    TestCase(
                        mass = m,
                        springConstant = k,
                        dampingCoefficient = criticalDampingCoefficient,
                        initialVelocity = v0.toDouble(),
                        initialDisplacement = p0.toDouble()
                    )
                )
            }
        }
        return testCases
    }

    /** One spring configuration plus initial conditions, in mass/spring-constant/damping form. */
    private data class TestCase(
        val mass: Double,
        val springConstant: Double,
        val dampingCoefficient: Double,
        val initialVelocity: Double,
        val initialDisplacement: Double
    ) {
        /** Damping coefficient divided by the critical damping `2 * sqrt(k * m)`. */
        val dampingRatio: Double
            get() = dampingCoefficient / (2.0 * sqrt(springConstant * mass))

        /** Spring constant divided by mass; the value fed to `SpringSimulation.stiffness`. */
        val stiffness: Double
            get() = springConstant / mass
    }

    /** Outcome of [runTestCase]; [reason] is non-blank for every failing result. */
    private data class TestCaseResult(
        val pass: Boolean,
        val testCase: TestCase,
        val reason: String = ""
    )
}
227