1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.xr.runtime.math
18 
19 import kotlin.math.sign
20 import kotlin.math.sqrt
21 
/**
 * An immutable 4x4 matrix that represents translation, scale, and rotation. The matrix is column
 * major and right handed. The indexes of [dataToCopy] represent the following matrix layout:
 * ```
 *  [0, 4,  8, 12]
 *  [1, 5,  9, 13]
 *  [2, 6, 10, 14]
 *  [3, 7, 11, 15]
 * ```
 *
 * The derived properties ([inverse], [transpose], [translation], [scale], [rotation], [pose] and
 * [isTrs]) are computed on first access and then cached without any synchronization.
 *
 * @param dataToCopy the array with 16 elements that will be copied over.
 * @throws IllegalArgumentException if [dataToCopy] does not contain exactly 16 elements.
 */
public class Matrix4(dataToCopy: FloatArray) {
    init {
        // TODO: Consider using contracts to avoid the exception being inlined.
        require(dataToCopy.size == 16) {
            "Input array must have exactly 16 elements for a 4x4 matrix"
        }
    }

    /**
     * Returns an array of the components of this matrix.
     *
     * This is the matrix's own backing array (a copy of the constructor argument), not a fresh
     * copy per access. Callers must treat it as read-only: writing to it would change this matrix
     * and leave any already-cached derived property stale.
     */
    public val data: FloatArray = dataToCopy.copyOf()

    /** Returns a matrix that performs the opposite transformation. */
    public val inverse: Matrix4 by lazy(LazyThreadSafetyMode.NONE) { inverse() }

    /** Returns a matrix that is the transpose of this matrix. */
    public val transpose: Matrix4 by lazy(LazyThreadSafetyMode.NONE) { transpose() }

    /** Returns the translation component of this matrix (elements 12, 13 and 14). */
    public val translation: Vector3 by
        lazy(LazyThreadSafetyMode.NONE) { Vector3(data[12], data[13], data[14]) }

    /** Returns the scale component of this matrix. */
    public val scale: Vector3 by lazy(LazyThreadSafetyMode.NONE) { scale() }

    /** Returns the rotation component of this matrix. */
    public val rotation: Quaternion by lazy(LazyThreadSafetyMode.NONE) { rotation() }

    /** Returns the pose (i.e. rotation and translation) of this matrix. */
    public val pose: Pose by lazy(LazyThreadSafetyMode.NONE) { Pose(translation, rotation) }

    /**
     * Returns true if this matrix is a valid transformation matrix that can be decomposed into
     * translation, rotation and scale using determinant properties.
     *
     * The check performed is that the determinant is not exactly `0.0f`.
     */
    public val isTrs: Boolean by lazy(LazyThreadSafetyMode.NONE) { determinant() != 0.0f }

    /** Creates a new matrix with a deep copy of the data from the [other] [Matrix4]. */
    // The primary constructor already copies its argument, so no extra copy is needed here.
    public constructor(other: Matrix4) : this(other.data)

    /**
     * Returns a new matrix with the matrix multiplication product of this matrix and the [other]
     * matrix. This matrix is the left-hand operand and [other] is the right-hand operand.
     */
    public operator fun times(other: Matrix4): Matrix4 {
        // Write into a fresh buffer. Using Matrix4.Zero.data as the destination would overwrite
        // the shared Zero constant with the product.
        val product = FloatArray(16)
        android.opengl.Matrix.multiplyMM(
            /* result= */ product,
            /* resultOffset= */ 0,
            /* lhs= */ this.data,
            /* lhsOffset= */ 0,
            /* rhs= */ other.data,
            /* rhsOffset= */ 0,
        )

        return Matrix4(product)
    }

    /** Computes the inverse of this matrix; backs the [inverse] property. */
    private fun inverse(): Matrix4 {
        // Fresh destination buffer so the shared Zero constant is never written to.
        val inverted = FloatArray(16)
        // NOTE(review): the boolean returned by invertM is ignored. Presumably a singular matrix
        // leaves the zero-filled buffer untouched, yielding a zero matrix -- confirm.
        android.opengl.Matrix.invertM(
            /* mInv= */ inverted,
            /* mInvOffset= */ 0,
            /* m= */ this.data,
            /* mOffset= */ 0,
        )

        return Matrix4(inverted)
    }

    /** Computes the transpose of this matrix; backs the [transpose] property. */
    private fun transpose(): Matrix4 {
        // Fresh destination buffer so the shared Zero constant is never written to.
        val transposed = FloatArray(16)
        android.opengl.Matrix.transposeM(
            /* mTrans= */ transposed,
            /* mTransOffset= */ 0,
            /* m= */ this.data,
            /* mOffset= */ 0,
        )

        return Matrix4(transposed)
    }

    /**
     * Extracts a quaternion from the upper-left 3x3 block of this matrix; backs the [rotation]
     * property.
     *
     * The branch is selected from the diagonal so that the value passed to [sqrt] comes from the
     * dominant diagonal combination.
     */
    private fun rotation(): Quaternion {
        // NOTE(review): the 3x3 block is used as-is; scale is not divided out of the columns
        // before the conversion. Confirm this is intended for matrices with non-unit scale.
        // mRC = element at row R, column C of the column-major array.
        val m00 = data[0]
        val m01 = data[4]
        val m02 = data[8]
        val m10 = data[1]
        val m11 = data[5]
        val m12 = data[9]
        val m20 = data[2]
        val m21 = data[6]
        val m22 = data[10]

        val trace = m00 + m11 + m22 + 1.0f

        return when {
            trace > 0 -> {
                val s = 0.5f / sqrt(trace)
                Quaternion((m21 - m12) * s, (m02 - m20) * s, (m10 - m01) * s, 0.25f / s)
            }
            (m00 > m11) && (m00 > m22) -> {
                // m00 is the largest diagonal element: solve for x first.
                val s = 2.0f * sqrt(1.0f + m00 - m11 - m22)
                Quaternion(0.25f * s, (m01 + m10) / s, (m02 + m20) / s, (m21 - m12) / s)
            }
            m11 > m22 -> {
                // m11 is the largest diagonal element: solve for y first.
                val s = 2.0f * sqrt(1.0f + m11 - m00 - m22)
                Quaternion((m01 + m10) / s, 0.25f * s, (m12 + m21) / s, (m02 - m20) / s)
            }
            else -> {
                // m22 is the largest diagonal element: solve for z first.
                val s = 2.0f * sqrt(1.0f + m22 - m00 - m11)
                Quaternion((m02 + m20) / s, (m12 + m21) / s, 0.25f * s, (m10 - m01) / s)
            }
        }
    }

    /**
     * Computes the per-axis scale as the length of each of the first three columns, signed by the
     * corresponding diagonal element; backs the [scale] property.
     */
    private fun scale(): Vector3 {
        // TODO: b/367780918 - Investigate why scale can have negative values when inputs were
        // positive.
        return Vector3(
            nonZeroSign(data[0]) * sqrt(data[0] * data[0] + data[1] * data[1] + data[2] * data[2]),
            nonZeroSign(data[5]) * sqrt(data[4] * data[4] + data[5] * data[5] + data[6] * data[6]),
            nonZeroSign(data[10]) *
                sqrt(data[8] * data[8] + data[9] * data[9] + data[10] * data[10]),
        )
    }

    /**
     * Returns the sign of [value], except that zero maps to `1.0f`. We shouldn't use [sign]
     * directly because we don't want it to ever return 0.
     */
    private fun nonZeroSign(value: Float): Float = if (value == 0.0f) 1.0f else sign(value)

    /**
     * Computes the determinant of a 4x4 matrix by cofactor expansion over elements 0, 4, 8 and 12
     * (the first row of the layout), each multiplied by the determinant of its 3x3 minor.
     */
    private fun determinant(): Float =
        data[0] *
            (data[5] * (data[10] * data[15] - data[14] * data[11]) -
                data[9] * (data[6] * data[15] - data[14] * data[7]) +
                data[13] * (data[6] * data[11] - data[10] * data[7])) -
            data[4] *
                (data[1] * (data[10] * data[15] - data[14] * data[11]) -
                    data[9] * (data[2] * data[15] - data[14] * data[3]) +
                    data[13] * (data[2] * data[11] - data[10] * data[3])) +
            data[8] *
                (data[1] * (data[6] * data[15] - data[14] * data[7]) -
                    data[5] * (data[2] * data[15] - data[14] * data[3]) +
                    data[13] * (data[2] * data[7] - data[6] * data[3])) -
            data[12] *
                (data[1] * (data[6] * data[11] - data[10] * data[7]) -
                    data[5] * (data[2] * data[11] - data[10] * data[3]) +
                    data[9] * (data[2] * data[7] - data[6] * data[3]))

    /**
     * Returns true if this matrix is equal to [other], i.e. [other] is a [Matrix4] whose 16
     * components are element-wise equal to this matrix's components.
     */
    public override fun equals(other: Any?): Boolean {
        if (this === other) return true
        if (other !is Matrix4) return false

        return this.data.contentEquals(other.data)
    }

    /** Hash code computed from the 16 matrix components, consistent with [equals]. */
    public override fun hashCode(): Int = data.contentHashCode()

    /**
     * Returns the matrix formatted as four rows of tab-separated values, laid out as in the class
     * documentation (each printed row reads across the columns).
     */
    public override fun toString(): String =
        "\n[ ${data[0]}\t${data[4]}\t${data[8]}\t${data[12]}" +
            "\n  ${data[1]}\t${data[5]}\t${data[9]}\t${data[13]}" +
            "\n  ${data[2]}\t${data[6]}\t${data[10]}\t${data[14]}" +
            "\n  ${data[3]}\t${data[7]}\t${data[11]}\t${data[15]} ]"

    /**
     * Returns a copy of the matrix.
     *
     * @param data the 16 components for the new matrix; defaults to this matrix's components.
     */
    public fun copy(data: FloatArray = this.data): Matrix4 = Matrix4(data)

    public companion object {
        /** Returns an identity matrix. */
        @JvmField
        public val Identity: Matrix4 =
            Matrix4(floatArrayOf(1f, 0f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 0f, 1f, 0f, 0f, 0f, 0f, 1f))

        /** Returns a zero matrix. */
        @JvmField
        public val Zero: Matrix4 =
            Matrix4(floatArrayOf(0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f, 0f))

        /**
         * Returns a new transformation matrix. The returned matrix is such that it first scales
         * objects, then rotates them, and finally translates them.
         *
         * @param translation the translation stored in elements 12, 13 and 14.
         * @param rotation the rotation; it is normalized before being converted.
         * @param scale the per-axis scale multiplied into the first three columns.
         */
        @JvmStatic
        public fun fromTrs(translation: Vector3, rotation: Quaternion, scale: Vector3): Matrix4 {
            // Implementation details: https://www.songho.ca/opengl/gl_quaternion.html
            val q = rotation.toNormalized()

            // Doubled products of two different components ("dq" + the two component names).
            val dqyx = 2 * q.y * q.x
            val dqxz = 2 * q.x * q.z
            val dqxw = 2 * q.x * q.w
            val dqyw = 2 * q.y * q.w
            val dqzw = 2 * q.z * q.w
            val dqzy = 2 * q.z * q.y

            // Doubled squares of a single component.
            val dsqz = 2 * q.z * q.z
            val dsqy = 2 * q.y * q.y

            val oneMinusDSQX = 1 - 2 * q.x * q.x

            return Matrix4(
                floatArrayOf(
                    // Column 0: rotated X axis scaled by scale.x.
                    (1 - dsqy - dsqz) * scale.x,
                    (dqyx + dqzw) * scale.x,
                    (dqxz - dqyw) * scale.x,
                    0.0f,
                    // Column 1: rotated Y axis scaled by scale.y.
                    (dqyx - dqzw) * scale.y,
                    (oneMinusDSQX - dsqz) * scale.y,
                    (dqzy + dqxw) * scale.y,
                    0.0f,
                    // Column 2: rotated Z axis scaled by scale.z.
                    (dqxz + dqyw) * scale.z,
                    (dqzy - dqxw) * scale.z,
                    (oneMinusDSQX - dsqy) * scale.z,
                    0.0f,
                    // Column 3: translation.
                    translation.x,
                    translation.y,
                    translation.z,
                    1.0f,
                )
            )
        }

        /** Returns a new translation matrix. */
        @JvmStatic
        public fun fromTranslation(translation: Vector3): Matrix4 =
            Matrix4(
                floatArrayOf(
                    1.0f, 0.0f, 0.0f, 0.0f,
                    0.0f, 1.0f, 0.0f, 0.0f,
                    0.0f, 0.0f, 1.0f, 0.0f,
                    translation.x, translation.y, translation.z, 1.0f,
                )
            )

        /** Returns a new scale matrix with a separate scale factor for each axis. */
        @JvmStatic
        public fun fromScale(scale: Vector3): Matrix4 =
            Matrix4(
                floatArrayOf(
                    scale.x, 0.0f, 0.0f, 0.0f,
                    0.0f, scale.y, 0.0f, 0.0f,
                    0.0f, 0.0f, scale.z, 0.0f,
                    0.0f, 0.0f, 0.0f, 1.0f,
                )
            )

        /** Returns a new uniform scale matrix that applies the same [scale] to every axis. */
        @JvmStatic
        public fun fromScale(scale: Float): Matrix4 =
            Matrix4(
                floatArrayOf(
                    scale, 0.0f, 0.0f, 0.0f,
                    0.0f, scale, 0.0f, 0.0f,
                    0.0f, 0.0f, scale, 0.0f,
                    0.0f, 0.0f, 0.0f, 1.0f,
                )
            )

        /** Returns a new rotation matrix (no translation, unit scale). */
        @JvmStatic
        public fun fromQuaternion(quaternion: Quaternion): Matrix4 =
            fromTrs(Vector3.Zero, quaternion, Vector3.One)

        /**
         * Returns a new rigid transformation matrix. The returned matrix is such that it first
         * rotates objects, and then translates them.
         */
        @JvmStatic
        public fun fromPose(pose: Pose): Matrix4 =
            fromTrs(pose.translation, pose.rotation, Vector3.One)
    }
}
369