1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:Suppress("NOTHING_TO_INLINE")
18 
19 package androidx.xr.runtime.math
20 
21 import kotlin.math.abs
22 import kotlin.math.acos
23 import kotlin.math.asin
24 import kotlin.math.atan2
25 import kotlin.math.cos
26 import kotlin.math.sin
27 import kotlin.math.sqrt
28 
29 /**
30  * Represents a rotation component in three-dimensional space. Any vector can be provided and the
31  * resulting quaternion will be normalized at construction time.
32  *
33  * @param x the x value of the quaternion.
34  * @param y the y value of the quaternion.
35  * @param z the z value of the quaternion.
36  * @param w the rotation of the unit vector, in radians.
37  */
38 public class Quaternion
39 @JvmOverloads
40 constructor(x: Float = 0F, y: Float = 0F, z: Float = 0F, w: Float = 1F) {
41     /** The normalized x component of the quaternion. */
42     public val x: Float
43     /** The normalized y component of the quaternion. */
44     public val y: Float
45     /** The normalized z component of the quaternion. */
46     public val z: Float
47     /** The normalized w component of the quaternion. */
48     public val w: Float
49 
50     init {
51         val length = sqrt(x * x + y * y + z * z + w * w)
52         this.x = x / length
53         this.y = y / length
54         this.z = z / length
55         this.w = w / length
56     }
57 
58     /** Returns a new quaternion with the inverse rotation. Assumes unit length. */
59     public inline val inverse: Quaternion
60         get() = Quaternion(-x, -y, -z, w)
61 
62     /**
63      * Returns this quaternion as Euler angles (in degrees) applied in YXZ (yaw, pitch, roll) order.
64      */
65     public val eulerAngles: Vector3
66         get() = toYawPitchRoll()
67 
68     /** Returns this quaternion as an axis/angle (in degrees) pair. */
69     public val axisAngle: Pair<Vector3, Float>
70         get() = toAxisAngle()
71 
72     /** Creates a new quaternion with the same values as the [other] quaternion. */
73     public constructor(other: Quaternion) : this(other.x, other.y, other.z, other.w)
74 
75     /** Creates a new quaternion using the components of a [Vector4]. */
76     internal constructor(vector: Vector4) : this(vector.x, vector.y, vector.z, vector.w)
77 
78     /** Flips the sign of the quaternion, but represents the same rotation. */
unaryMinusnull79     public inline operator fun unaryMinus(): Quaternion = Quaternion(-x, -y, -z, -w)
80 
81     /** Returns a new quaternion with the sum of this quaternion and [other] quaternion. */
82     public inline operator fun plus(other: Quaternion): Quaternion =
83         Quaternion(x + other.x, y + other.y, z + other.z, w + other.w)
84 
85     /**
86      * Returns a new quaternion with the difference of this quaternion and the [other] quaternion.
87      */
88     public inline operator fun minus(other: Quaternion): Quaternion =
89         Quaternion(x - other.x, y - other.y, z - other.z, w - other.w)
90 
91     /** Rotates a [Vector3] by this quaternion. */
92     public inline operator fun times(src: Vector3): Vector3 {
93         val qx = x
94         val qy = y
95         val qz = z
96         val qw = w
97         val vx = src.x
98         val vy = src.y
99         val vz = src.z
100 
101         val rx = qy * vz - qz * vy + qw * vx
102         val ry = qz * vx - qx * vz + qw * vy
103         val rz = qx * vy - qy * vx + qw * vz
104         val sx = qy * rz - qz * ry
105         val sy = qz * rx - qx * rz
106         val sz = qx * ry - qy * rx
107         return Vector3(2 * sx + vx, 2 * sy + vy, 2 * sz + vz)
108     }
109 
110     /**
111      * Returns a new quaternion with the product of this quaternion and the [other] quaternion. The
112      * order of the multiplication is `[this] * [other]`.
113      */
timesnull114     public inline operator fun times(other: Quaternion): Quaternion {
115         val lx = this.x
116         val ly = this.y
117         val lz = this.z
118         val lw = this.w
119         val rx = other.x
120         val ry = other.y
121         val rz = other.z
122         val rw = other.w
123 
124         return Quaternion(
125             lw * rx + lx * rw + ly * rz - lz * ry,
126             lw * ry - lx * rz + ly * rw + lz * rx,
127             lw * rz + lx * ry - ly * rx + lz * rw,
128             lw * rw - lx * rx - ly * ry - lz * rz,
129         )
130     }
131 
132     /** Returns a new quaternion with the product of this quaternion and a scalar amount. */
timesnull133     public operator fun times(c: Float): Quaternion = Quaternion(x * c, y * c, z * c, w * c)
134 
135     /** Returns a new quaternion with this quaternion divided by a scalar amount. */
136     public operator fun div(c: Float): Quaternion = Quaternion(x / c, y / c, z / c, w / c)
137 
138     /** Returns a new quaternion with a normalized rotation. */
139     public fun toNormalized(): Quaternion {
140         val norm = rsqrt(x * x + y * y + z * z + w * w)
141         return this * norm
142     }
143 
144     /** Returns the dot product of this quaternion and the [other] quaternion. */
dotnull145     public inline infix fun dot(other: Quaternion): Float =
146         x * other.x + y * other.y + z * other.z + w * other.w
147 
148     /**
149      * Get a [Vector3] containing the pitch, yaw and roll in degrees, extracted in YXZ (yaw, pitch,
150      * roll) order.
151      */
152     private fun toYawPitchRoll(): Vector3 {
153         val test = w * x - y * z
154         if (test > EULER_THRESHOLD) {
155             // There is a singularity when the pitch is directly up, so calculate the
156             // angles another way.
157             return Vector3(90f, toDegrees(-2 * atan2(z, w)), 0f)
158         }
159         if (test < -EULER_THRESHOLD) {
160             // There is a singularity when the pitch is directly down, so calculate the
161             // angles another way.
162             return Vector3(-90f, toDegrees(2 * atan2(z, w)), 0f)
163         }
164 
165         val pitch = asin(2 * test)
166         val yaw = atan2(2 * (w * y + x * z).toDouble(), 1.0 - 2 * (x * x + y * y)).toFloat()
167         val roll = atan2(2 * (w * z + x * y).toDouble(), 1.0 - 2 * (x * x + z * z)).toFloat()
168 
169         return Vector3(toDegrees(pitch), toDegrees(yaw), toDegrees(roll))
170     }
171 
172     /** Returns a Pair containing the axis of rotation and the angle of rotation in degrees. */
toAxisAnglenull173     private fun toAxisAngle(): Pair<Vector3, Float> {
174         val normalized = this.toNormalized()
175         val angleRadians = 2 * acos(normalized.w)
176         val sinHalfAngle = sin(angleRadians / 2)
177         val axis =
178             if (sinHalfAngle < 0.0001f) {
179                 Vector3.Right // Default axis when angle is 0
180             } else {
181                 Vector3(
182                     normalized.x / sinHalfAngle,
183                     normalized.y / sinHalfAngle,
184                     normalized.z / sinHalfAngle,
185                 )
186             }
187 
188         return Pair(axis, toDegrees(angleRadians))
189     }
190 
191     /** Returns a copy of the quaternion. */
192     @JvmOverloads
copynull193     public fun copy(
194         x: Float = this.x,
195         y: Float = this.y,
196         z: Float = this.z,
197         w: Float = this.w,
198     ): Quaternion = Quaternion(x, y, z, w)
199 
200     /** Returns true if this quaternion is equal to the [other]. */
201     override fun equals(other: Any?): Boolean {
202         if (this === other) return true
203         if (other !is Quaternion) return false
204 
205         return this.x == other.x && this.y == other.y && this.z == other.z && this.w == other.w
206     }
207 
hashCodenull208     override fun hashCode(): Int {
209         var result = x.hashCode()
210         result = 31 * result + y.hashCode()
211         result = 31 * result + z.hashCode()
212         result = 31 * result + w.hashCode()
213         return result
214     }
215 
toStringnull216     override fun toString(): String = "[x=$x, y=$y, z=$z, w=$w]"
217 
218     /** Returns a new quaternion with the identity rotation. */
219     public companion object {
220         private const val EULER_THRESHOLD: Float = 0.49999994f
221         private const val COS_THRESHOLD: Float = 0.9995f
222 
223         @JvmField public val Identity: Quaternion = Quaternion()
224 
225         /** Returns a new quaternion representing the rotation from one vector to another. */
226         @JvmStatic
227         public fun fromRotation(start: Vector3, end: Vector3): Quaternion {
228             val startNorm = start.toNormalized()
229             val endNorm = end.toNormalized()
230 
231             val cosTheta = startNorm.dot(endNorm)
232             if (cosTheta < -COS_THRESHOLD) {
233                 // Special case when vectors in opposite directions: no "ideal" rotation axis
234                 // Guess one; any work as long as perpendicular to start
235                 var rotationAxis = Vector3.Backward.cross(startNorm)
236                 if (rotationAxis.lengthSquared < 0.01f) {
237                     rotationAxis =
238                         Vector3.Right.cross(
239                             startNorm
240                         ) // pick new rotation axis as the original was parallel
241                 }
242                 return Quaternion.Companion.fromAxisAngle(rotationAxis, 180f)
243             }
244 
245             val rotationAxis = startNorm.cross(endNorm)
246 
247             return Quaternion(rotationAxis.x, rotationAxis.y, rotationAxis.z, 1 + cosTheta)
248                 .toNormalized()
249         }
250 
251         /** Returns a new quaternion representing the rotation from one quaternion to another. */
252         @JvmStatic
253         public fun fromRotation(start: Quaternion, end: Quaternion): Quaternion =
254             Quaternion(end * start.inverse).toNormalized()
255 
256         /** Returns a new quaternion with the specified forward and upward directions. */
257         @JvmStatic
258         public fun fromLookTowards(forward: Vector3, up: Vector3): Quaternion {
259             val forwardNormalized = forward.toNormalized()
260             val right = (up cross forwardNormalized).toNormalized()
261             val upNormalized = (forwardNormalized cross right).toNormalized()
262 
263             val m00 = right.x
264             val m01 = right.y
265             val m02 = right.z
266             val m10 = upNormalized.x
267             val m11 = upNormalized.y
268             val m12 = upNormalized.z
269             val m20 = forwardNormalized.x
270             val m21 = forwardNormalized.y
271             val m22 = forwardNormalized.z
272 
273             val trace = m00 + m11 + m22
274             return if (trace > 0) {
275                 val s = 0.5f / sqrt(trace + 1.0f)
276                 Quaternion((m12 - m21) * s, (m20 - m02) * s, (m01 - m10) * s, 0.25f / s)
277             } else {
278                 if (m00 > m11 && m00 > m22) {
279                     val s = sqrt(1.0f + m00 - m11 - m22) * 2.0f
280                     Quaternion(0.25f * s, (m01 + m10) / s, (m02 + m20) / s, (m12 - m21) / s)
281                 } else if (m11 > m22) {
282                     val s = sqrt(1.0f + m11 - m00 - m22) * 2.0f
283                     Quaternion((m01 + m10) / s, 0.25f * s, (m12 + m21) / s, (m20 - m02) / s)
284                 } else {
285                     val s = sqrt(1.0f + m22 - m00 - m11) * 2.0f
286                     Quaternion((m02 + m20) / s, (m12 + m21) / s, 0.25f * s, (m01 - m10) / s)
287                 }
288             }
289         }
290 
291         /** Creates a new quaternion using an axis/angle to define the rotation. */
292         @JvmStatic
293         public fun fromAxisAngle(axis: Vector3, degrees: Float): Quaternion =
294             Quaternion(
295                 sin(0.5f * toRadians(degrees)) * axis.toNormalized().x,
296                 sin(0.5f * toRadians(degrees)) * axis.toNormalized().y,
297                 sin(0.5f * toRadians(degrees)) * axis.toNormalized().z,
298                 cos(0.5f * toRadians(degrees)),
299             )
300 
301         /**
302          * Returns a new quaternion using Euler angles (in degrees) to define the rotation in YXZ
303          * (yaw, pitch, roll) order.
304          */
305         @JvmStatic
306         public fun fromEulerAngles(eulerAngles: Vector3): Quaternion =
307             Quaternion(fromAxisAngle(Vector3.Up, eulerAngles.y)) *
308                 Quaternion(fromAxisAngle(Vector3.Right, eulerAngles.x)) *
309                 Quaternion(fromAxisAngle(Vector3.Backward, eulerAngles.z))
310 
311         /**
312          * Returns a new quaternion using Euler angles (in degrees) to define the rotation in YXZ
313          * (yaw, pitch, roll) order.
314          */
315         @JvmStatic
316         public fun fromEulerAngles(pitch: Float, yaw: Float, roll: Float): Quaternion =
317             Quaternion(fromAxisAngle(Vector3.Up, yaw)) *
318                 Quaternion(fromAxisAngle(Vector3.Right, pitch)) *
319                 Quaternion(fromAxisAngle(Vector3.Backward, roll))
320 
321         /**
322          * Returns a new quaternion that is linearly interpolated between [start] and [end] using
323          * the interpolation amount [ratio].
324          *
325          * If [ratio] is outside of the range `[0, 1]`, the returned quaternion will be
326          * extrapolated.
327          */
328         @JvmStatic
329         public fun lerp(start: Quaternion, end: Quaternion, ratio: Float): Quaternion =
330             Quaternion(
331                 lerp(start.x, end.x, ratio),
332                 lerp(start.y, end.y, ratio),
333                 lerp(start.z, end.z, ratio),
334                 lerp(start.w, end.w, ratio),
335             )
336 
337         /**
338          * Returns a new quaternion that is spherically interpolated between [start] and [end] using
339          * the interpolation amount [ratio]. If [ratio] is 0, this returns [start]. As [ratio]
340          * approaches 1, the result of this function may approach either `+end` or `-end` (whichever
341          * is closest to [start]).
342          *
343          * If [ratio] is outside of the range `[0, 1]`, the returned quaternion will be
344          * extrapolated.
345          */
346         @JvmStatic
347         public fun slerp(start: Quaternion, end: Quaternion, ratio: Float): Quaternion {
348             val orientationStart = start
349             var orientationEnd = end
350 
351             // cosTheta0 provides the angle between the rotations at t = 0
352             var cosTheta0 = orientationStart.dot(orientationEnd)
353 
354             // Flip end rotation to get shortest path between the two rotations
355             if (cosTheta0 < 0.0f) {
356                 orientationEnd = -orientationEnd
357                 cosTheta0 = -cosTheta0
358             }
359 
360             // Small rotations can use linear interpolation
361             if (cosTheta0 > COS_THRESHOLD) {
362                 return lerp(orientationStart, orientationEnd, ratio)
363             }
364 
365             val sinTheta0 = sqrt(1.0 - cosTheta0 * cosTheta0)
366             val theta0 = acos(cosTheta0)
367             val thetaT = theta0 * ratio
368             val sinThetaT = sin(thetaT)
369             val costThetaT = cos(thetaT)
370 
371             val s1 = sinThetaT / sinTheta0
372             val s0 = costThetaT - cosTheta0 * s1
373 
374             // Do component-wise multiplication since s0 and s1 could be near-zero which would cause
375             // precision issues when (quat * 0.0f) is normalized due to division by near-zero
376             // length.
377             return Quaternion(
378                 orientationStart.x * s0.toFloat() + orientationEnd.x * s1.toFloat(),
379                 orientationStart.y * s0.toFloat() + orientationEnd.y * s1.toFloat(),
380                 orientationStart.z * s0.toFloat() + orientationEnd.z * s1.toFloat(),
381                 orientationStart.w * s0.toFloat() + orientationEnd.w * s1.toFloat(),
382             )
383         }
384 
385         /** Returns the angle between [start] and [end] quaternion in degrees. */
386         @JvmStatic
387         public fun angle(start: Quaternion, end: Quaternion): Float =
388             toDegrees(2.0f * acos(abs(clamp(dot(start, end), -1.0f, 1.0f))))
389 
390         /** Returns the dot product of two quaternions. */
391         @JvmStatic
392         public fun dot(lhs: Quaternion, rhs: Quaternion): Float =
393             lhs.x * rhs.x + lhs.y * rhs.y + lhs.z * rhs.z + lhs.w * rhs.w
394     }
395 }
396