• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright 2015 The ANGLE Project Authors. All rights reserved.
3 // Use of this source code is governed by a BSD-style license that can be
4 // found in the LICENSE file.
5 //
6 // Matrix:
7 //   Utility class implementing various matrix operations.
8 //   Supports matrices with between 1 and 4 rows/columns (1-wide matrices act as vectors).
9 //
10 // TODO: Check if we can merge Matrix.h in sample_util with this and replace it with this
11 // implementation.
12 // TODO: Rename this file to Matrix.h once we remove Matrix.h in sample_util.
13 
14 #ifndef COMMON_MATRIX_UTILS_H_
15 #define COMMON_MATRIX_UTILS_H_
16 
17 #include <vector>
18 
19 #include "common/debug.h"
20 #include "common/mathutil.h"
21 #include "common/vector_utils.h"
22 
23 namespace angle
24 {
25 
26 template <typename T>
27 class Matrix
28 {
29   public:
Matrix(const std::vector<T> & elements,const unsigned int numRows,const unsigned int numCols)30     Matrix(const std::vector<T> &elements, const unsigned int numRows, const unsigned int numCols)
31         : mElements(elements), mRows(numRows), mCols(numCols)
32     {
33         ASSERT(rows() >= 1 && rows() <= 4);
34         ASSERT(columns() >= 1 && columns() <= 4);
35     }
36 
Matrix(const std::vector<T> & elements,const unsigned int size)37     Matrix(const std::vector<T> &elements, const unsigned int size)
38         : mElements(elements), mRows(size), mCols(size)
39     {
40         ASSERT(rows() >= 1 && rows() <= 4);
41         ASSERT(columns() >= 1 && columns() <= 4);
42     }
43 
Matrix(const T * elements,const unsigned int size)44     Matrix(const T *elements, const unsigned int size) : mRows(size), mCols(size)
45     {
46         ASSERT(rows() >= 1 && rows() <= 4);
47         ASSERT(columns() >= 1 && columns() <= 4);
48         for (size_t i = 0; i < size * size; i++)
49             mElements.push_back(elements[i]);
50     }
51 
operator()52     const T &operator()(const unsigned int rowIndex, const unsigned int columnIndex) const
53     {
54         ASSERT(rowIndex < mRows);
55         ASSERT(columnIndex < mCols);
56         return mElements[rowIndex * columns() + columnIndex];
57     }
58 
operator()59     T &operator()(const unsigned int rowIndex, const unsigned int columnIndex)
60     {
61         ASSERT(rowIndex < mRows);
62         ASSERT(columnIndex < mCols);
63         return mElements[rowIndex * columns() + columnIndex];
64     }
65 
at(const unsigned int rowIndex,const unsigned int columnIndex)66     const T &at(const unsigned int rowIndex, const unsigned int columnIndex) const
67     {
68         ASSERT(rowIndex < mRows);
69         ASSERT(columnIndex < mCols);
70         return operator()(rowIndex, columnIndex);
71     }
72 
73     Matrix<T> operator*(const Matrix<T> &m)
74     {
75         ASSERT(columns() == m.rows());
76 
77         unsigned int resultRows = rows();
78         unsigned int resultCols = m.columns();
79         Matrix<T> result(std::vector<T>(resultRows * resultCols), resultRows, resultCols);
80         for (unsigned int i = 0; i < resultRows; i++)
81         {
82             for (unsigned int j = 0; j < resultCols; j++)
83             {
84                 T tmp = 0.0f;
85                 for (unsigned int k = 0; k < columns(); k++)
86                     tmp += at(i, k) * m(k, j);
87                 result(i, j) = tmp;
88             }
89         }
90 
91         return result;
92     }
93 
94     void operator*=(const Matrix<T> &m)
95     {
96         ASSERT(columns() == m.rows());
97         Matrix<T> res  = (*this) * m;
98         size_t numElts = res.elements().size();
99         mElements.resize(numElts);
100         memcpy(mElements.data(), res.data(), numElts * sizeof(float));
101     }
102 
103     bool operator==(const Matrix<T> &m) const
104     {
105         ASSERT(columns() == m.columns());
106         ASSERT(rows() == m.rows());
107         return mElements == m.elements();
108     }
109 
110     bool operator!=(const Matrix<T> &m) const { return !(mElements == m.elements()); }
111 
nearlyEqual(T epsilon,const Matrix<T> & m)112     bool nearlyEqual(T epsilon, const Matrix<T> &m) const
113     {
114         ASSERT(columns() == m.columns());
115         ASSERT(rows() == m.rows());
116         const auto &otherElts = m.elements();
117         for (size_t i = 0; i < otherElts.size(); i++)
118         {
119             if ((mElements[i] - otherElts[i] > epsilon) && (otherElts[i] - mElements[i] > epsilon))
120                 return false;
121         }
122         return true;
123     }
124 
size()125     unsigned int size() const
126     {
127         ASSERT(rows() == columns());
128         return rows();
129     }
130 
rows()131     unsigned int rows() const { return mRows; }
132 
columns()133     unsigned int columns() const { return mCols; }
134 
elements()135     std::vector<T> elements() const { return mElements; }
data()136     T *data() { return mElements.data(); }
constData()137     const T *constData() const { return mElements.data(); }
138 
compMult(const Matrix<T> & mat1)139     Matrix<T> compMult(const Matrix<T> &mat1) const
140     {
141         Matrix result(std::vector<T>(mElements.size()), rows(), columns());
142         for (unsigned int i = 0; i < rows(); i++)
143         {
144             for (unsigned int j = 0; j < columns(); j++)
145             {
146                 T lhs        = at(i, j);
147                 T rhs        = mat1(i, j);
148                 result(i, j) = rhs * lhs;
149             }
150         }
151 
152         return result;
153     }
154 
outerProduct(const Matrix<T> & mat1)155     Matrix<T> outerProduct(const Matrix<T> &mat1) const
156     {
157         unsigned int cols = mat1.columns();
158         Matrix result(std::vector<T>(rows() * cols), rows(), cols);
159         for (unsigned int i = 0; i < rows(); i++)
160             for (unsigned int j = 0; j < cols; j++)
161                 result(i, j) = at(i, 0) * mat1(0, j);
162 
163         return result;
164     }
165 
transpose()166     Matrix<T> transpose() const
167     {
168         Matrix result(std::vector<T>(mElements.size()), columns(), rows());
169         for (unsigned int i = 0; i < columns(); i++)
170             for (unsigned int j = 0; j < rows(); j++)
171                 result(i, j) = at(j, i);
172 
173         return result;
174     }
175 
determinant()176     T determinant() const
177     {
178         ASSERT(rows() == columns());
179 
180         switch (size())
181         {
182             case 2:
183                 return at(0, 0) * at(1, 1) - at(0, 1) * at(1, 0);
184 
185             case 3:
186                 return at(0, 0) * at(1, 1) * at(2, 2) + at(0, 1) * at(1, 2) * at(2, 0) +
187                        at(0, 2) * at(1, 0) * at(2, 1) - at(0, 2) * at(1, 1) * at(2, 0) -
188                        at(0, 1) * at(1, 0) * at(2, 2) - at(0, 0) * at(1, 2) * at(2, 1);
189 
190             case 4:
191             {
192                 const float minorMatrices[4][3 * 3] = {{
193                                                            at(1, 1),
194                                                            at(2, 1),
195                                                            at(3, 1),
196                                                            at(1, 2),
197                                                            at(2, 2),
198                                                            at(3, 2),
199                                                            at(1, 3),
200                                                            at(2, 3),
201                                                            at(3, 3),
202                                                        },
203                                                        {
204                                                            at(1, 0),
205                                                            at(2, 0),
206                                                            at(3, 0),
207                                                            at(1, 2),
208                                                            at(2, 2),
209                                                            at(3, 2),
210                                                            at(1, 3),
211                                                            at(2, 3),
212                                                            at(3, 3),
213                                                        },
214                                                        {
215                                                            at(1, 0),
216                                                            at(2, 0),
217                                                            at(3, 0),
218                                                            at(1, 1),
219                                                            at(2, 1),
220                                                            at(3, 1),
221                                                            at(1, 3),
222                                                            at(2, 3),
223                                                            at(3, 3),
224                                                        },
225                                                        {
226                                                            at(1, 0),
227                                                            at(2, 0),
228                                                            at(3, 0),
229                                                            at(1, 1),
230                                                            at(2, 1),
231                                                            at(3, 1),
232                                                            at(1, 2),
233                                                            at(2, 2),
234                                                            at(3, 2),
235                                                        }};
236                 return at(0, 0) * Matrix<T>(minorMatrices[0], 3).determinant() -
237                        at(0, 1) * Matrix<T>(minorMatrices[1], 3).determinant() +
238                        at(0, 2) * Matrix<T>(minorMatrices[2], 3).determinant() -
239                        at(0, 3) * Matrix<T>(minorMatrices[3], 3).determinant();
240             }
241 
242             default:
243                 UNREACHABLE();
244                 break;
245         }
246 
247         return T();
248     }
249 
inverse()250     Matrix<T> inverse() const
251     {
252         ASSERT(rows() == columns());
253 
254         Matrix<T> cof(std::vector<T>(mElements.size()), rows(), columns());
255         switch (size())
256         {
257             case 2:
258                 cof(0, 0) = at(1, 1);
259                 cof(0, 1) = -at(1, 0);
260                 cof(1, 0) = -at(0, 1);
261                 cof(1, 1) = at(0, 0);
262                 break;
263 
264             case 3:
265                 cof(0, 0) = at(1, 1) * at(2, 2) - at(2, 1) * at(1, 2);
266                 cof(0, 1) = -(at(1, 0) * at(2, 2) - at(2, 0) * at(1, 2));
267                 cof(0, 2) = at(1, 0) * at(2, 1) - at(2, 0) * at(1, 1);
268                 cof(1, 0) = -(at(0, 1) * at(2, 2) - at(2, 1) * at(0, 2));
269                 cof(1, 1) = at(0, 0) * at(2, 2) - at(2, 0) * at(0, 2);
270                 cof(1, 2) = -(at(0, 0) * at(2, 1) - at(2, 0) * at(0, 1));
271                 cof(2, 0) = at(0, 1) * at(1, 2) - at(1, 1) * at(0, 2);
272                 cof(2, 1) = -(at(0, 0) * at(1, 2) - at(1, 0) * at(0, 2));
273                 cof(2, 2) = at(0, 0) * at(1, 1) - at(1, 0) * at(0, 1);
274                 break;
275 
276             case 4:
277                 cof(0, 0) = at(1, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(1, 3) +
278                             at(3, 1) * at(1, 2) * at(2, 3) - at(1, 1) * at(3, 2) * at(2, 3) -
279                             at(2, 1) * at(1, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(1, 3);
280                 cof(0, 1) = -(at(1, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(1, 3) +
281                               at(3, 0) * at(1, 2) * at(2, 3) - at(1, 0) * at(3, 2) * at(2, 3) -
282                               at(2, 0) * at(1, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(1, 3));
283                 cof(0, 2) = at(1, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(1, 3) +
284                             at(3, 0) * at(1, 1) * at(2, 3) - at(1, 0) * at(3, 1) * at(2, 3) -
285                             at(2, 0) * at(1, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(1, 3);
286                 cof(0, 3) = -(at(1, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(1, 2) +
287                               at(3, 0) * at(1, 1) * at(2, 2) - at(1, 0) * at(3, 1) * at(2, 2) -
288                               at(2, 0) * at(1, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(1, 2));
289                 cof(1, 0) = -(at(0, 1) * at(2, 2) * at(3, 3) + at(2, 1) * at(3, 2) * at(0, 3) +
290                               at(3, 1) * at(0, 2) * at(2, 3) - at(0, 1) * at(3, 2) * at(2, 3) -
291                               at(2, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(2, 2) * at(0, 3));
292                 cof(1, 1) = at(0, 0) * at(2, 2) * at(3, 3) + at(2, 0) * at(3, 2) * at(0, 3) +
293                             at(3, 0) * at(0, 2) * at(2, 3) - at(0, 0) * at(3, 2) * at(2, 3) -
294                             at(2, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(2, 2) * at(0, 3);
295                 cof(1, 2) = -(at(0, 0) * at(2, 1) * at(3, 3) + at(2, 0) * at(3, 1) * at(0, 3) +
296                               at(3, 0) * at(0, 1) * at(2, 3) - at(0, 0) * at(3, 1) * at(2, 3) -
297                               at(2, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(2, 1) * at(0, 3));
298                 cof(1, 3) = at(0, 0) * at(2, 1) * at(3, 2) + at(2, 0) * at(3, 1) * at(0, 2) +
299                             at(3, 0) * at(0, 1) * at(2, 2) - at(0, 0) * at(3, 1) * at(2, 2) -
300                             at(2, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(2, 1) * at(0, 2);
301                 cof(2, 0) = at(0, 1) * at(1, 2) * at(3, 3) + at(1, 1) * at(3, 2) * at(0, 3) +
302                             at(3, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(3, 2) * at(1, 3) -
303                             at(1, 1) * at(0, 2) * at(3, 3) - at(3, 1) * at(1, 2) * at(0, 3);
304                 cof(2, 1) = -(at(0, 0) * at(1, 2) * at(3, 3) + at(1, 0) * at(3, 2) * at(0, 3) +
305                               at(3, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(3, 2) * at(1, 3) -
306                               at(1, 0) * at(0, 2) * at(3, 3) - at(3, 0) * at(1, 2) * at(0, 3));
307                 cof(2, 2) = at(0, 0) * at(1, 1) * at(3, 3) + at(1, 0) * at(3, 1) * at(0, 3) +
308                             at(3, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(3, 1) * at(1, 3) -
309                             at(1, 0) * at(0, 1) * at(3, 3) - at(3, 0) * at(1, 1) * at(0, 3);
310                 cof(2, 3) = -(at(0, 0) * at(1, 1) * at(3, 2) + at(1, 0) * at(3, 1) * at(0, 2) +
311                               at(3, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(3, 1) * at(1, 2) -
312                               at(1, 0) * at(0, 1) * at(3, 2) - at(3, 0) * at(1, 1) * at(0, 2));
313                 cof(3, 0) = -(at(0, 1) * at(1, 2) * at(2, 3) + at(1, 1) * at(2, 2) * at(0, 3) +
314                               at(2, 1) * at(0, 2) * at(1, 3) - at(0, 1) * at(2, 2) * at(1, 3) -
315                               at(1, 1) * at(0, 2) * at(2, 3) - at(2, 1) * at(1, 2) * at(0, 3));
316                 cof(3, 1) = at(0, 0) * at(1, 2) * at(2, 3) + at(1, 0) * at(2, 2) * at(0, 3) +
317                             at(2, 0) * at(0, 2) * at(1, 3) - at(0, 0) * at(2, 2) * at(1, 3) -
318                             at(1, 0) * at(0, 2) * at(2, 3) - at(2, 0) * at(1, 2) * at(0, 3);
319                 cof(3, 2) = -(at(0, 0) * at(1, 1) * at(2, 3) + at(1, 0) * at(2, 1) * at(0, 3) +
320                               at(2, 0) * at(0, 1) * at(1, 3) - at(0, 0) * at(2, 1) * at(1, 3) -
321                               at(1, 0) * at(0, 1) * at(2, 3) - at(2, 0) * at(1, 1) * at(0, 3));
322                 cof(3, 3) = at(0, 0) * at(1, 1) * at(2, 2) + at(1, 0) * at(2, 1) * at(0, 2) +
323                             at(2, 0) * at(0, 1) * at(1, 2) - at(0, 0) * at(2, 1) * at(1, 2) -
324                             at(1, 0) * at(0, 1) * at(2, 2) - at(2, 0) * at(1, 1) * at(0, 2);
325                 break;
326 
327             default:
328                 UNREACHABLE();
329                 break;
330         }
331 
332         // The inverse of A is the transpose of the cofactor matrix times the reciprocal of the
333         // determinant of A.
334         Matrix<T> adjugateMatrix(cof.transpose());
335         T det = determinant();
336         Matrix<T> result(std::vector<T>(mElements.size()), rows(), columns());
337         for (unsigned int i = 0; i < rows(); i++)
338             for (unsigned int j = 0; j < columns(); j++)
339                 result(i, j) = (det != static_cast<T>(0)) ? adjugateMatrix(i, j) / det : T();
340 
341         return result;
342     }
343 
setToIdentity()344     void setToIdentity()
345     {
346         ASSERT(rows() == columns());
347 
348         const auto one  = T(1);
349         const auto zero = T(0);
350 
351         for (auto &e : mElements)
352             e = zero;
353 
354         for (unsigned int i = 0; i < rows(); ++i)
355         {
356             const auto pos = i * columns() + (i % columns());
357             mElements[pos] = one;
358         }
359     }
360 
361     template <unsigned int Size>
setToIdentity(T (& matrix)[Size])362     static void setToIdentity(T (&matrix)[Size])
363     {
364         static_assert(gl::iSquareRoot<Size>() != 0, "Matrix is not square.");
365 
366         const auto cols = gl::iSquareRoot<Size>();
367         const auto one  = T(1);
368         const auto zero = T(0);
369 
370         for (auto &e : matrix)
371             e = zero;
372 
373         for (unsigned int i = 0; i < cols; ++i)
374         {
375             const auto pos = i * cols + (i % cols);
376             matrix[pos]    = one;
377         }
378     }
379 
380   protected:
381     std::vector<T> mElements;
382     unsigned int mRows;
383     unsigned int mCols;
384 };
385 
386 class Mat4 : public Matrix<float>
387 {
388   public:
389     Mat4();
390     Mat4(const Matrix<float> generalMatrix);
391     Mat4(const std::vector<float> &elements);
392     Mat4(const float *elements);
393     Mat4(float m00,
394          float m01,
395          float m02,
396          float m03,
397          float m10,
398          float m11,
399          float m12,
400          float m13,
401          float m20,
402          float m21,
403          float m22,
404          float m23,
405          float m30,
406          float m31,
407          float m32,
408          float m33);
409 
410     static Mat4 Rotate(float angle, const Vector3 &axis);
411     static Mat4 Translate(const Vector3 &t);
412     static Mat4 Scale(const Vector3 &s);
413     static Mat4 Frustum(float l, float r, float b, float t, float n, float f);
414     static Mat4 Perspective(float fov, float aspectRatio, float n, float f);
415     static Mat4 Ortho(float l, float r, float b, float t, float n, float f);
416 
417     Mat4 product(const Mat4 &m);
418     Vector4 product(const Vector4 &b);
419     void dump();
420 };
421 
422 }  // namespace angle
423 
424 #endif  // COMMON_MATRIX_UTILS_H_
425