• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Matrix:
7 //   Utility class implementing various matrix operations.
8 //   Supports matrices with 1 to 4 rows and 1 to 4 columns.
9 //
10 // TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this
11 // implementation.
12 // TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util.
13 
14 #ifndef COMMON_MATRIX_UTILS_H_
15 #define COMMON_MATRIX_UTILS_H_
16 
17 #include <vector>
18 
19 #include "common/debug.h"
20 #include "common/mathutil.h"
21 #include "common/vector_utils.h"
22 
23 namespace angle
24 {
25 
26 template <typename T>
27 class Matrix
28 {
29   public:
Matrix(const std::vector<T> & elements,const unsigned int numRows,const unsigned int numCols)30     Matrix(const std::vector<T> &elements, const unsigned int numRows, const unsigned int numCols)
31         : mElements(elements), mRows(numRows), mCols(numCols)
32     {
33         ASSERT(rows() >= 1 && rows() <= 4);
34         ASSERT(columns() >= 1 && columns() <= 4);
35     }
36 
Matrix(const std::vector<T> & elements,const unsigned int size)37     Matrix(const std::vector<T> &elements, const unsigned int size)
38         : mElements(elements), mRows(size), mCols(size)
39     {
40         ASSERT(rows() >= 1 && rows() <= 4);
41         ASSERT(columns() >= 1 && columns() <= 4);
42     }
43 
Matrix(const T * elements,const unsigned int size)44     Matrix(const T *elements, const unsigned int size) : mRows(size), mCols(size)
45     {
46         ASSERT(rows() >= 1 && rows() <= 4);
47         ASSERT(columns() >= 1 && columns() <= 4);
48         for (size_t i = 0; i < size * size; i++)
49             mElements.push_back(elements[i]);
50     }
51 
operator()52     const T &operator()(const unsigned int rowIndex, const unsigned int columnIndex) const
53     {
54         ASSERT(rowIndex < mRows);
55         ASSERT(columnIndex < mCols);
56         return mElements[rowIndex * columns() + columnIndex];
57     }
58 
operator()59     T &operator()(const unsigned int rowIndex, const unsigned int columnIndex)
60     {
61         ASSERT(rowIndex < mRows);
62         ASSERT(columnIndex < mCols);
63         return mElements[rowIndex * columns() + columnIndex];
64     }
65 
at(const unsigned int rowIndex,const unsigned int columnIndex)66     const T &at(const unsigned int rowIndex, const unsigned int columnIndex) const
67     {
68         ASSERT(rowIndex < mRows);
69         ASSERT(columnIndex < mCols);
70         return operator()(rowIndex, columnIndex);
71     }
72 
73     Matrix<T> operator*(const Matrix<T> &m)
74     {
75         ASSERT(columns() == m.rows());
76 
77         unsigned int resultRows = rows();
78         unsigned int resultCols = m.columns();
79         Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols);
80         for (unsigned int i = 0; i < resultRows; i++)
81         {
82             for (unsigned int j = 0; j < resultCols; j++)
83             {
84                 T tmp = 0.0f;
85                 for (unsigned int k = 0; k < columns(); k++)
86                     tmp += at(i, k) * m(k, j);
87                 result(i, j) = tmp;
88             }
89         }
90 
91         return result;
92     }
93 
94     void operator*=(const Matrix<T> &m)
95     {
96         ASSERT(columns() == m.rows());
97         Matrix<T> res  = (*this) * m;
98         size_t numElts = res.elements().size();
99         mElements.resize(numElts);
100         memcpy(mElements.data(), res.data(), numElts * sizeof(float));
101     }
102 
103     bool operator==(const Matrix<T> &m) const
104     {
105         ASSERT(columns() == m.columns());
106         ASSERT(rows() == m.rows());
107         return mElements == m.elements();
108     }
109 
110     bool operator!=(const Matrix<T> &m) const { return !(mElements == m.elements()); }
111 
nearlyEqual(T epsilon,const Matrix<T> & m)112     bool nearlyEqual(T epsilon, const Matrix<T> &m) const
113     {
114         ASSERT(columns() == m.columns());
115         ASSERT(rows() == m.rows());
116         const auto &otherElts = m.elements();
117         for (size_t i = 0; i < otherElts.size(); i++)
118         {
119             if ((mElements[i] - otherElts[i] > epsilon) && (otherElts[i] - mElements[i] > epsilon))
120                 return false;
121         }
122         return true;
123     }
124 
size()125     unsigned int size() const
126     {
127         ASSERT(rows() == columns());
128         return rows();
129     }
130 
rows()131     unsigned int rows() const { return mRows; }
132 
columns()133     unsigned int columns() const { return mCols; }
134 
elements()135     std::vector<T> elements() const { return mElements; }
data()136     T *data() { return mElements.data(); }
137 
compMult(const Matrix<T> & mat1)138     Matrix<T> compMult(const Matrix<T> &mat1) const
139     {
140         Matrix result(std::vector<T>(mElements.size()), rows(), columns());
141         for (unsigned int i = 0; i < rows(); i++)
142         {
143             for (unsigned int j = 0; j < columns(); j++)
144             {
145                 T lhs        = at(i, j);
146                 T rhs        = mat1(i, j);
147                 result(i, j) = rhs * lhs;
148             }
149         }
150 
151         return result;
152     }
153 
outerProduct(const Matrix<T> & mat1)154     Matrix<T> outerProduct(const Matrix<T> &mat1) const
155     {
156         unsigned int cols = mat1.columns();
157         Matrix result(std::vector<T>(rows() * cols), rows(), cols);
158         for (unsigned int i = 0; i < rows(); i++)
159             for (unsigned int j = 0; j < cols; j++)
160                 result(i, j) = at(i, 0) * mat1(0, j);
161 
162         return result;
163     }
164 
transpose()165     Matrix<T> transpose() const
166     {
167         Matrix result(std::vector<T>(mElements.size()), columns(), rows());
168         for (unsigned int i = 0; i < columns(); i++)
169             for (unsigned int j = 0; j < rows(); j++)
170                 result(i, j) = at(j, i);
171 
172         return result;
173     }
174 
determinant()175     T determinant() const
176     {
177         ASSERT(rows() == columns());
178 
179         switch (size())
180         {
181             case 2:
182                 return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0);
183 
184             case 3:
185                 return at(0, 0) * at(1, 1) * at(2, 2) + at(0, 1) * at(1, 2) * at(2, 0) +
186                        at(0, 2) * at(1, 0) * at(2, 1) - at(0, 2) * at(1, 1) * at(2, 0) -
187                        at(0, 1) * at(1, 0) * at(2, 2) - at(0, 0) * at(1, 2) * at(2, 1);
188 
189             case 4:
190             {
191                 const float minorMatrices[4][3 * 3] = {{
192                                                            at(1, 1),
193                                                            at(2, 1),
194                                                            at(3, 1),
195                                                            at(1, 2),
196                                                            at(2, 2),
197                                                            at(3, 2),
198                                                            at(1, 3),
199                                                            at(2, 3),
200                                                            at(3, 3),
201                                                        },
202                                                        {
203                                                            at(1, 0),
204                                                            at(2, 0),
205                                                            at(3, 0),
206                                                            at(1, 2),
207                                                            at(2, 2),
208                                                            at(3, 2),
209                                                            at(1, 3),
210                                                            at(2, 3),
211                                                            at(3, 3),
212                                                        },
213                                                        {
214                                                            at(1, 0),
215                                                            at(2, 0),
216                                                            at(3, 0),
217                                                            at(1, 1),
218                                                            at(2, 1),
219                                                            at(3, 1),
220                                                            at(1, 3),
221                                                            at(2, 3),
222                                                            at(3, 3),
223                                                        },
224                                                        {
225                                                            at(1, 0),
226                                                            at(2, 0),
227                                                            at(3, 0),
228                                                            at(1, 1),
229                                                            at(2, 1),
230                                                            at(3, 1),
231                                                            at(1, 2),
232                                                            at(2, 2),
233                                                            at(3, 2),
234                                                        }};
235                 return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() -
236                        at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() +
237                        at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() -
238                        at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant();
239             }
240 
241             default:
242                 UNREACHABLE();
243                 break;
244         }
245 
246         return T();
247     }
248 
inverse()249     Matrix<T> inverse() const
250     {
251         ASSERT(rows() == columns());
252 
253         Matrix<T> cof(std::vector<T>(mElements.size()), rows(), columns());
254         switch (size())
255         {
256             case 2:
257                 cof(0, 0) = at(1, 1);
258                 cof(0, 1) = -at(1, 0);
259                 cof(1, 0) = -at(0, 1);
260                 cof(1, 1) = at(0, 0);
261                 break;
262 
263             case 3:
264                 cof(0, 0) = at(1, 1) * at(2, 2) - at(2, 1) * at(1, 2);
265                 cof(0, 1) = -(at(1, 0) * at(2, 2) - at(2, 0) * at(1, 2));
266                 cof(0, 2) = at(1, 0) * at(2, 1) - at(2, 0) * at(1, 1);
267                 cof(1, 0) = -(at(0, 1) * at(2, 2) - at(2, 1) * at(0, 2));
268                 cof(1, 1) = at(0, 0) * at(2, 2) - at(2, 0) * at(0, 2);
269                 cof(1, 2) = -(at(0, 0) * at(2, 1) - at(2, 0) * at(0, 1));
270                 cof(2, 0) = at(0, 1) * at(1, 2) - at(1, 1) * at(0, 2);
271                 cof(2, 1) = -(at(0, 0) * at(1, 2) - at(1, 0) * at(0, 2));
272                 cof(2, 2) = at(0, 0) * at(1, 1) - at(1, 0) * at(0, 1);
273                 break;
274 
275             case 4:
276                 cof(0, 0) = at(1, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(1, 3) +
277                             at(3, 1) * at(1, 2) * at(2, 3) - at(1, 1) * at(3, 2) * at(2, 3) -
278                             at(2, 1) * at(1, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(1, 3);
279                 cof(0, 1) = -(at(1, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(1, 3) +
280                               at(3, 0) * at(1, 2) * at(2, 3) - at(1, 0) * at(3, 2) * at(2, 3) -
281                               at(2, 0) * at(1, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(1, 3));
282                 cof(0, 2) = at(1, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(1, 3) +
283                             at(3, 0) * at(1, 1) * at(2, 3) - at(1, 0) * at(3, 1) * at(2, 3) -
284                             at(2, 0) * at(1, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(1, 3);
285                 cof(0, 3) = -(at(1, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(1, 2) +
286                               at(3, 0) * at(1, 1) * at(2, 2) - at(1, 0) * at(3, 1) * at(2, 2) -
287                               at(2, 0) * at(1, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(1, 2));
288                 cof(1, 0) = -(at(0, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(0, 3) +
289                               at(3, 1) * at(0, 2) * at(2, 3) - at(0, 1) * at(3, 2) * at(2, 3) -
290                               at(2, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(0, 3));
291                 cof(1, 1) = at(0, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(0, 3) +
292                             at(3, 0) * at(0, 2) * at(2, 3) - at(0, 0) * at(3, 2) * at(2, 3) -
293                             at(2, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(0, 3);
294                 cof(1, 2) = -(at(0, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(0, 3) +
295                               at(3, 0) * at(0, 1) * at(2, 3) - at(0, 0) * at(3, 1) * at(2, 3) -
296                               at(2, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(0, 3));
297                 cof(1, 3) = at(0, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(0, 2) +
298                             at(3, 0) * at(0, 1) * at(2, 2) - at(0, 0) * at(3, 1) * at(2, 2) -
299                             at(2, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(0, 2);
300                 cof(2, 0) = at(0, 1) * at(1, 2) * at(3, 3) + at(1, 1) * at(3, 2) * at(0, 3) +
301                             at(3, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(3, 2) * at(1, 3) -
302                             at(1, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(1, 2) * at(0, 3);
303                 cof(2, 1) = -(at(0, 0) * at(1, 2) * at(3, 3) + at(1, 0) * at(3, 2) * at(0, 3) +
304                               at(3, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(3, 2) * at(1, 3) -
305                               at(1, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(1, 2) * at(0, 3));
306                 cof(2, 2) = at(0, 0) * at(1, 1) * at(3, 3) + at(1, 0) * at(3, 1) * at(0, 3) +
307                             at(3, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(3, 1) * at(1, 3) -
308                             at(1, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(1, 1) * at(0, 3);
309                 cof(2, 3) = -(at(0, 0) * at(1, 1) * at(3, 2) + at(1, 0) * at(3, 1) * at(0, 2) +
310                               at(3, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(3, 1) * at(1, 2) -
311                               at(1, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(1, 1) * at(0, 2));
312                 cof(3, 0) = -(at(0, 1) * at(1, 2) * at(2, 3) + at(1, 1) * at(2, 2) * at(0, 3) +
313                               at(2, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(2, 2) * at(1, 3) -
314                               at(1, 1) * at(0, 2) * at(2, 3) - at(2, 1) * at(1, 2) * at(0, 3));
315                 cof(3, 1) = at(0, 0) * at(1, 2) * at(2, 3) + at(1, 0) * at(2, 2) * at(0, 3) +
316                             at(2, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(2, 2) * at(1, 3) -
317                             at(1, 0) * at(0, 2) * at(2, 3) - at(2, 0) * at(1, 2) * at(0, 3);
318                 cof(3, 2) = -(at(0, 0) * at(1, 1) * at(2, 3) + at(1, 0) * at(2, 1) * at(0, 3) +
319                               at(2, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(2, 1) * at(1, 3) -
320                               at(1, 0) * at(0, 1) * at(2, 3) - at(2, 0) * at(1, 1) * at(0, 3));
321                 cof(3, 3) = at(0, 0) * at(1, 1) * at(2, 2) + at(1, 0) * at(2, 1) * at(0, 2) +
322                             at(2, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(2, 1) * at(1, 2) -
323                             at(1, 0) * at(0, 1) * at(2, 2) - at(2, 0) * at(1, 1) * at(0, 2);
324                 break;
325 
326             default:
327                 UNREACHABLE();
328                 break;
329         }
330 
331         // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the
332         // determinant of A.
333         Matrix<T> adjugateMatrix(cof.transpose());
334         T det = determinant();
335         Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns());
336         for (unsigned int i = 0; i < rows(); i++)
337             for (unsigned int j = 0; j < columns(); j++)
338                 result(i, j) = (det != static_cast<T>(0)) ? adjugateMatrix(i, j) / det : T();
339 
340         return result;
341     }
342 
setToIdentity()343     void setToIdentity()
344     {
345         ASSERT(rows() == columns());
346 
347         const auto one  = T(1);
348         const auto zero = T(0);
349 
350         for (auto &e : mElements)
351             e = zero;
352 
353         for (unsigned int i = 0; i < rows(); ++i)
354         {
355             const auto pos = i * columns() + (i % columns());
356             mElements[pos] = one;
357         }
358     }
359 
360     template <unsigned int Size>
setToIdentity(T (& matrix)[Size])361     static void setToIdentity(T (&matrix)[Size])
362     {
363         static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square.");
364 
365         const auto cols = gl::iSquareRoot<Size>();
366         const auto one  = T(1);
367         const auto zero = T(0);
368 
369         for (auto &e : matrix)
370             e = zero;
371 
372         for (unsigned int i = 0; i < cols; ++i)
373         {
374             const auto pos = i * cols + (i % cols);
375             matrix[pos]    = one;
376         }
377     }
378 
379   protected:
380     std::vector<T> mElements;
381     unsigned int mRows;
382     unsigned int mCols;
383 };
384 
385 class Mat4 : public Matrix<float>
386 {
387   public:
388     Mat4();
389     Mat4(const Matrix<float> generalMatrix);
390     Mat4(const std::vector<float> &elements);
391     Mat4(const float *elements);
392     Mat4(float m00,
393          float m01,
394          float m02,
395          float m03,
396          float m10,
397          float m11,
398          float m12,
399          float m13,
400          float m20,
401          float m21,
402          float m22,
403          float m23,
404          float m30,
405          float m31,
406          float m32,
407          float m33);
408 
409     static Mat4 Rotate(float angle, const Vector3 &axis);
410     static Mat4 Translate(const Vector3 &t);
411     static Mat4 Scale(const Vector3 &s);
412     static Mat4 Frustum(float l, float r, float b, float t, float n, float f);
413     static Mat4 Perspective(float fov, float aspectRatio, float n, float f);
414     static Mat4 Ortho(float l, float r, float b, float t, float n, float f);
415 
416     Mat4 product(const Mat4 &m);
417     Vector4 product(const Vector4 &b);
418     void dump();
419 };
420 
421 }  // namespace angle
422 
423 #endif  // COMMON_MATRIX_UTILS_H_
424