• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.android.server.wifi.util;
18 
19 /**
20  * Utility for doing basic matrix calculations
21  */
22 public class Matrix {
23     public final int n;
24     public final int m;
25     public final double[] mem;
26 
27     /**
28      * Creates a new matrix, initialized to zeros
29      *
30      * @param rows - number of rows (n)
31      * @param cols - number of columns (m)
32      */
Matrix(int rows, int cols)33     public Matrix(int rows, int cols) {
34         n = rows;
35         m = cols;
36         mem = new double[rows * cols];
37     }
38 
39     /**
40      * Creates a new matrix using the provided array of values
41      * <p>
42      * Values are in row-major order.
43      *
44      * @param stride is the number of columns.
45      * @param values is the array of values.
46      * @throws IllegalArgumentException if length of values array not a multiple of stride
47      */
Matrix(int stride, double[] values)48     public Matrix(int stride, double[] values) {
49         n = (values.length + stride - 1) / stride;
50         m = stride;
51         mem = values;
52         if (mem.length != n * m) throw new IllegalArgumentException();
53     }
54 
55     /**
56      * Creates a new matrix duplicating the given one
57      *
58      * @param that is the source Matrix.
59      */
Matrix(Matrix that)60     public Matrix(Matrix that) {
61         n = that.n;
62         m = that.m;
63         mem = new double[that.mem.length];
64         for (int i = 0; i < mem.length; i++) {
65             mem[i] = that.mem[i];
66         }
67     }
68 
69     /**
70      * Gets the matrix coefficient from row i, column j
71      *
72      * @param i row number
73      * @param j column number
74      * @return Coefficient at i,j
75      * @throws IndexOutOfBoundsException if an index is out of bounds
76      */
get(int i, int j)77     public double get(int i, int j) {
78         if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException();
79         return mem[i * m + j];
80     }
81 
82     /**
83      * Store a matrix coefficient in row i, column j
84      *
85      * @param i row number
86      * @param j column number
87      * @param v Coefficient to store at i,j
88      * @throws IndexOutOfBoundsException if an index is out of bounds
89      */
put(int i, int j, double v)90     public void put(int i, int j, double v) {
91         if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException();
92         mem[i * m + j] = v;
93     }
94 
95     /**
96      * Forms the sum of two matrices, this and that
97      *
98      * @param that is the other matrix
99      * @return newly allocated matrix representing the sum of this and that
100      * @throws IllegalArgumentException if shapes differ
101      */
plus(Matrix that)102     public Matrix plus(Matrix that) {
103         return plus(that, new Matrix(n, m));
104 
105     }
106 
107     /**
108      * Forms the sum of two matrices, this and that
109      *
110      * @param that   is the other matrix
111      * @param result is space to hold the result
112      * @return result, filled with the matrix sum
113      * @throws IllegalArgumentException if shapes differ
114      */
plus(Matrix that, Matrix result)115     public Matrix plus(Matrix that, Matrix result) {
116         if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) {
117             throw new IllegalArgumentException();
118         }
119         for (int i = 0; i < mem.length; i++) {
120             result.mem[i] = this.mem[i] + that.mem[i];
121         }
122         return result;
123     }
124 
125     /**
126      * Forms the difference of two matrices, this and that
127      *
128      * @param that is the other matrix
129      * @return newly allocated matrix representing the difference of this and that
130      * @throws IllegalArgumentException if shapes differ
131      */
minus(Matrix that)132     public Matrix minus(Matrix that) {
133         return minus(that, new Matrix(n, m));
134     }
135 
136     /**
137      * Forms the difference of two matrices, this and that
138      *
139      * @param that   is the other matrix
140      * @param result is space to hold the result
141      * @return result, filled with the matrix difference
142      * @throws IllegalArgumentException if shapes differ
143      */
minus(Matrix that, Matrix result)144     public Matrix minus(Matrix that, Matrix result) {
145         if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) {
146             throw new IllegalArgumentException();
147         }
148         for (int i = 0; i < mem.length; i++) {
149             result.mem[i] = this.mem[i] - that.mem[i];
150         }
151         return result;
152     }
153 
154     /**
155      * Forms a scalar product
156      *
157      * @param scalar is the value to multiply by
158      * @return newly allocated matrix representing the product this and scalar
159      */
times(double scalar)160     public Matrix times(double scalar) {
161         return times(scalar, new Matrix(n, m));
162     }
163 
164     /**
165      * Forms a scalar product
166      *
167      * @param scalar is the value to multiply by
168      * @param result is space to hold the result
169      * @return result, filled with the matrix difference
170      * @throws IllegalArgumentException if shapes differ
171      */
times(double scalar, Matrix result)172     public Matrix times(double scalar, Matrix result) {
173         if (!(this.n == result.n && this.m == result.m)) {
174             throw new IllegalArgumentException();
175         }
176         for (int i = 0; i < mem.length; i++) {
177             result.mem[i] = this.mem[i] * scalar;
178         }
179         return result;
180     }
181 
182     /**
183      * Forms the matrix product of two matrices, this and that
184      *
185      * @param that is the other matrix
186      * @return newly allocated matrix representing the matrix product of this and that
187      * @throws IllegalArgumentException if shapes are not conformant
188      */
dot(Matrix that)189     public Matrix dot(Matrix that) {
190         return dot(that, new Matrix(this.n, that.m));
191     }
192 
193     /**
194      * Forms the matrix product of two matrices, this and that
195      * <p>
196      * Caller supplies an object to contain the result, as well as scratch space
197      *
198      * @param that   is the other matrix
199      * @param result is space to hold the result
200      * @return result, filled with the matrix product
201      * @throws IllegalArgumentException if shapes are not conformant
202      */
dot(Matrix that, Matrix result)203     public Matrix dot(Matrix that, Matrix result) {
204         if (!(this.n == result.n && this.m == that.n && that.m == result.m)) {
205             throw new IllegalArgumentException("shape error" + this + that + result);
206         }
207         for (int i = 0; i < n; i++) {
208             for (int j = 0; j < that.m; j++) {
209                 double s = 0.0;
210                 for (int k = 0; k < m; k++) {
211                     s += this.get(i, k) * that.get(k, j);
212                 }
213                 result.put(i, j, s);
214             }
215         }
216         return result;
217     }
218 
219     /**
220      * Forms the matrix transpose
221      *
222      * @return newly allocated transpose matrix
223      */
transpose()224     public Matrix transpose() {
225         return transpose(new Matrix(m, n));
226     }
227 
228     /**
229      * Forms the matrix transpose
230      * <p>
231      * Caller supplies an object to contain the result
232      *
233      * @param result is space to hold the result
234      * @return result, filled with the matrix transpose
235      * @throws IllegalArgumentException if result shape is wrong
236      */
transpose(Matrix result)237     public Matrix transpose(Matrix result) {
238         if (!(this.n == result.m && this.m == result.n)) throw new IllegalArgumentException();
239         for (int i = 0; i < n; i++) {
240             for (int j = 0; j < m; j++) {
241                 result.put(j, i, get(i, j));
242             }
243         }
244         return result;
245     }
246 
247     /**
248      * Forms the inverse of a square matrix
249      *
250      * @return newly allocated matrix representing the matrix inverse
251      * @throws ArithmeticException if the matrix is not invertible
252      */
inverse()253     public Matrix inverse() {
254         return inverse(new Matrix(n, m), new Matrix(n, 2 * m));
255     }
256 
257     /**
258      * Forms the inverse of a square matrix
259      *
260      * @param result  is space to hold the result
261      * @param scratch is workspace of dimension n by 2*n
262      * @return result, filled with the matrix inverse
263      * @throws ArithmeticException if the matrix is not invertible
264      * @throws IllegalArgumentException if shape of scratch or result is wrong
265      */
inverse(Matrix result, Matrix scratch)266     public Matrix inverse(Matrix result, Matrix scratch) {
267         if (!(n == m && n == result.n && m == result.m && n == scratch.n && 2 * m == scratch.m)) {
268             throw new IllegalArgumentException();
269         }
270 
271         for (int i = 0; i < n; i++) {
272             for (int j = 0; j < m; j++) {
273                 scratch.put(i, j, get(i, j));
274                 scratch.put(i, m + j, i == j ? 1.0 : 0.0);
275             }
276         }
277 
278         for (int i = 0; i < n; i++) {
279             int ibest = i;
280             double vbest = Math.abs(scratch.get(ibest, ibest));
281             for (int ii = i + 1; ii < n; ii++) {
282                 double v = Math.abs(scratch.get(ii, i));
283                 if (v > vbest) {
284                     ibest = ii;
285                     vbest = v;
286                 }
287             }
288             if (ibest != i) {
289                 for (int j = 0; j < scratch.m; j++) {
290                     double t = scratch.get(i, j);
291                     scratch.put(i, j, scratch.get(ibest, j));
292                     scratch.put(ibest, j, t);
293                 }
294             }
295             double d = scratch.get(i, i);
296             if (d == 0.0) throw new ArithmeticException("Singular matrix");
297             for (int j = 0; j < scratch.m; j++) {
298                 scratch.put(i, j, scratch.get(i, j) / d);
299             }
300             for (int ii = i + 1; ii < n; ii++) {
301                 d = scratch.get(ii, i);
302                 for (int j = 0; j < scratch.m; j++) {
303                     scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
304                 }
305             }
306         }
307         for (int i = n - 1; i >= 0; i--) {
308             for (int ii = 0; ii < i; ii++) {
309                 double d = scratch.get(ii, i);
310                 for (int j = 0; j < scratch.m; j++) {
311                     scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
312                 }
313             }
314         }
315         for (int i = 0; i < result.n; i++) {
316             for (int j = 0; j < result.m; j++) {
317                 result.put(i, j, scratch.get(i, m + j));
318             }
319         }
320         return result;
321     }
322     /**
323      * Forms the matrix product with the transpose of a second matrix
324      *
325      * @param that is the other matrix
326      * @return newly allocated matrix representing the matrix product of this and that.transpose()
327      * @throws IllegalArgumentException if shapes are not conformant
328      */
dotTranspose(Matrix that)329     public Matrix dotTranspose(Matrix that) {
330         return dotTranspose(that, new Matrix(this.n, that.n));
331     }
332 
333     /**
334      * Forms the matrix product with the transpose of a second matrix
335      * <p>
336      * Caller supplies an object to contain the result, as well as scratch space
337      *
338      * @param that is the other matrix
339      * @param result is space to hold the result
340      * @return result, filled with the matrix product of this and that.transpose()
341      * @throws IllegalArgumentException if shapes are not conformant
342      */
dotTranspose(Matrix that, Matrix result)343     public Matrix dotTranspose(Matrix that, Matrix result) {
344         if (!(this.n == result.n && this.m == that.m && that.n == result.m)) {
345             throw new IllegalArgumentException();
346         }
347         for (int i = 0; i < n; i++) {
348             for (int j = 0; j < that.n; j++) {
349                 double s = 0.0;
350                 for (int k = 0; k < m; k++) {
351                     s += this.get(i, k) * that.get(j, k);
352                 }
353                 result.put(i, j, s);
354             }
355         }
356         return result;
357     }
358 
359     /**
360      * Tests for equality
361      */
362     @Override
equals(Object that)363     public boolean equals(Object that) {
364         if (this == that) return true;
365         if (!(that instanceof Matrix)) return false;
366         Matrix other = (Matrix) that;
367         if (n != other.n) return false;
368         if (m != other.m) return false;
369         for (int i = 0; i < mem.length; i++) {
370             if (mem[i] != other.mem[i]) return false;
371         }
372         return true;
373     }
374 
375     /**
376      * Calculates a hash code
377      */
378     @Override
hashCode()379     public int hashCode() {
380         int h = n * 101 + m;
381         for (int i = 0; i < mem.length; i++) {
382             h = h * 37 + Double.hashCode(mem[i]);
383         }
384         return h;
385     }
386 
387     /**
388      * Makes a string representation
389      *
390      * @return string like "[a, b; c, d]"
391      */
392     @Override
toString()393     public String toString() {
394         StringBuilder sb = new StringBuilder(n * m * 8);
395         sb.append("[");
396         for (int i = 0; i < mem.length; i++) {
397             if (i > 0) sb.append(i % m == 0 ? "; " : ", ");
398             sb.append(mem[i]);
399         }
400         sb.append("]");
401         return sb.toString();
402     }
403 
404 }
405