/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.android.server.wifi.util;

/**
 * Utility for doing basic matrix calculations
 * <p>
 * Coefficients are stored in a flat array in row-major order. Most operations come in two
 * flavors: one that allocates a new matrix for the result, and one that writes into a
 * caller-supplied result matrix so that no allocation is needed.
 */
public class Matrix {
    /** Number of rows. */
    public final int n;
    /** Number of columns. */
    public final int m;
    /** Coefficients in row-major order; the entry for row i, column j is at index i * m + j. */
    public final double[] mem;

    /**
     * Creates a new matrix, initialized to zeros
     *
     * @param rows - number of rows (n)
     * @param cols - number of columns (m)
     */
    public Matrix(int rows, int cols) {
        n = rows;
        m = cols;
        mem = new double[rows * cols];
    }

    /**
     * Creates a new matrix using the provided array of values
     * <p>
     * Values are in row-major order. The array is used directly (it is not copied), so later
     * changes to the array are visible through the matrix and vice versa.
     *
     * @param stride is the number of columns.
     * @param values is the array of values.
     * @throws IllegalArgumentException if stride is not positive, or if the length of the
     *         values array is not a multiple of stride
     */
    public Matrix(int stride, double[] values) {
        // Reject a non-positive stride up front; otherwise a zero stride would surface as an
        // ArithmeticException from the division below instead of the documented exception.
        if (stride <= 0) {
            throw new IllegalArgumentException("stride must be positive: " + stride);
        }
        if (values.length % stride != 0) {
            throw new IllegalArgumentException(
                    "length " + values.length + " is not a multiple of stride " + stride);
        }
        n = values.length / stride;
        m = stride;
        mem = values;
    }

    /**
     * Creates a new matrix duplicating the given one
     *
     * @param that is the source Matrix.
     */
    public Matrix(Matrix that) {
        n = that.n;
        m = that.m;
        mem = that.mem.clone();
    }

    /**
     * Gets the matrix coefficient from row i, column j
     *
     * @param i row number
     * @param j column number
     * @return Coefficient at i,j
     * @throws IndexOutOfBoundsException if an index is out of bounds
     */
    public double get(int i, int j) {
        if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException();
        return mem[i * m + j];
    }

    /**
     * Store a matrix coefficient in row i, column j
     *
     * @param i row number
     * @param j column number
     * @param v Coefficient to store at i,j
     * @throws IndexOutOfBoundsException if an index is out of bounds
     */
    public void put(int i, int j, double v) {
        if (!(0 <= i && i < n && 0 <= j && j < m)) throw new IndexOutOfBoundsException();
        mem[i * m + j] = v;
    }

    /**
     * Forms the sum of two matrices, this and that
     *
     * @param that is the other matrix
     * @return newly allocated matrix representing the sum of this and that
     * @throws IllegalArgumentException if shapes differ
     */
    public Matrix plus(Matrix that) {
        return plus(that, new Matrix(n, m));
    }

    /**
     * Forms the sum of two matrices, this and that
     *
     * @param that is the other matrix
     * @param result is space to hold the result
     * @return result, filled with the matrix sum
     * @throws IllegalArgumentException if shapes differ
     */
    public Matrix plus(Matrix that, Matrix result) {
        if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) {
            throw new IllegalArgumentException();
        }
        for (int i = 0; i < mem.length; i++) {
            result.mem[i] = this.mem[i] + that.mem[i];
        }
        return result;
    }

    /**
     * Forms the difference of two matrices, this and that
     *
     * @param that is the other matrix
     * @return newly allocated matrix representing the difference of this and that
     * @throws IllegalArgumentException if shapes differ
     */
    public Matrix minus(Matrix that) {
        return minus(that, new Matrix(n, m));
    }

    /**
     * Forms the difference of two matrices, this and that
     *
     * @param that is the other matrix
     * @param result is space to hold the result
     * @return result, filled with the matrix difference
     * @throws IllegalArgumentException if shapes differ
     */
    public Matrix minus(Matrix that, Matrix result) {
        if (!(this.n == that.n && this.m == that.m && this.n == result.n && this.m == result.m)) {
            throw new IllegalArgumentException();
        }
        for (int i = 0; i < mem.length; i++) {
            result.mem[i] = this.mem[i] - that.mem[i];
        }
        return result;
    }

    /**
     * Forms a scalar product
     *
     * @param scalar is the value to multiply by
     * @return newly allocated matrix representing the product this and scalar
     */
    public Matrix times(double scalar) {
        return times(scalar, new Matrix(n, m));
    }

    /**
     * Forms a scalar product
     *
     * @param scalar is the value to multiply by
     * @param result is space to hold the result
     * @return result, filled with the product of this and scalar
     * @throws IllegalArgumentException if shapes differ
     */
    public Matrix times(double scalar, Matrix result) {
        if (!(this.n == result.n && this.m == result.m)) {
            throw new IllegalArgumentException();
        }
        for (int i = 0; i < mem.length; i++) {
            result.mem[i] = this.mem[i] * scalar;
        }
        return result;
    }

    /**
     * Forms the matrix product of two matrices, this and that
     *
     * @param that is the other matrix
     * @return newly allocated matrix representing the matrix product of this and that
     * @throws IllegalArgumentException if shapes are not conformant
     */
    public Matrix dot(Matrix that) {
        return dot(that, new Matrix(this.n, that.m));
    }

    /**
     * Forms the matrix product of two matrices, this and that
     * <p>
     * Caller supplies an object to contain the result. The result is written while the
     * operands are still being read, so result should be a different object from both this
     * and that.
     *
     * @param that is the other matrix
     * @param result is space to hold the result
     * @return result, filled with the matrix product
     * @throws IllegalArgumentException if shapes are not conformant
     */
    public Matrix dot(Matrix that, Matrix result) {
        if (!(this.n == result.n && this.m == that.n && that.m == result.m)) {
            throw new IllegalArgumentException("shape error" + this + that + result);
        }
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < that.m; j++) {
                double s = 0.0;
                for (int k = 0; k < m; k++) {
                    s += this.get(i, k) * that.get(k, j);
                }
                result.put(i, j, s);
            }
        }
        return result;
    }

    /**
     * Forms the matrix transpose
     *
     * @return newly allocated transpose matrix
     */
    public Matrix transpose() {
        return transpose(new Matrix(m, n));
    }

    /**
     * Forms the matrix transpose
     * <p>
     * Caller supplies an object to contain the result. The result is written while this is
     * still being read, so result should be a different object from this.
     *
     * @param result is space to hold the result
     * @return result, filled with the matrix transpose
     * @throws IllegalArgumentException if result shape is wrong
     */
    public Matrix transpose(Matrix result) {
        if (!(this.n == result.m && this.m == result.n)) throw new IllegalArgumentException();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                result.put(j, i, get(i, j));
            }
        }
        return result;
    }

    /**
     * Forms the inverse of a square matrix
     *
     * @return newly allocated matrix representing the matrix inverse
     * @throws ArithmeticException if the matrix is not invertible
     * @throws IllegalArgumentException if the matrix is not square
     */
    public Matrix inverse() {
        return inverse(new Matrix(n, m), new Matrix(n, 2 * m));
    }

    /**
     * Forms the inverse of a square matrix
     * <p>
     * Uses Gauss-Jordan elimination with row pivoting on the augmented matrix [this | I]
     * held in scratch; this matrix itself is not modified.
     *
     * @param result is space to hold the result
     * @param scratch is workspace of dimension n by 2*n
     * @return result, filled with the matrix inverse
     * @throws ArithmeticException if the matrix is not invertible
     * @throws IllegalArgumentException if shape of scratch or result is wrong
     */
    public Matrix inverse(Matrix result, Matrix scratch) {
        if (!(n == m && n == result.n && m == result.m && n == scratch.n && 2 * m == scratch.m)) {
            throw new IllegalArgumentException();
        }

        // Load scratch with this matrix on the left and the identity on the right.
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < m; j++) {
                scratch.put(i, j, get(i, j));
                scratch.put(i, m + j, i == j ? 1.0 : 0.0);
            }
        }

        // Forward pass: reduce the left half to upper triangular form with a unit diagonal.
        for (int i = 0; i < n; i++) {
            // Pick the row at or below i with the largest magnitude in column i as the pivot.
            int ibest = i;
            double vbest = Math.abs(scratch.get(ibest, ibest));
            for (int ii = i + 1; ii < n; ii++) {
                double v = Math.abs(scratch.get(ii, i));
                if (v > vbest) {
                    ibest = ii;
                    vbest = v;
                }
            }
            if (ibest != i) {
                // Swap the pivot row into position i.
                for (int j = 0; j < scratch.m; j++) {
                    double t = scratch.get(i, j);
                    scratch.put(i, j, scratch.get(ibest, j));
                    scratch.put(ibest, j, t);
                }
            }
            double d = scratch.get(i, i);
            if (d == 0.0) throw new ArithmeticException("Singular matrix");
            // Scale the pivot row so that the pivot becomes 1.
            for (int j = 0; j < scratch.m; j++) {
                scratch.put(i, j, scratch.get(i, j) / d);
            }
            // Eliminate column i from every row below the pivot row.
            for (int ii = i + 1; ii < n; ii++) {
                d = scratch.get(ii, i);
                for (int j = 0; j < scratch.m; j++) {
                    scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
                }
            }
        }

        // Backward pass: eliminate the entries above the diagonal, bottom row first.
        for (int i = n - 1; i >= 0; i--) {
            for (int ii = 0; ii < i; ii++) {
                double d = scratch.get(ii, i);
                for (int j = 0; j < scratch.m; j++) {
                    scratch.put(ii, j, scratch.get(ii, j) - d * scratch.get(i, j));
                }
            }
        }

        // The right half of scratch now holds the inverse.
        for (int i = 0; i < result.n; i++) {
            for (int j = 0; j < result.m; j++) {
                result.put(i, j, scratch.get(i, m + j));
            }
        }
        return result;
    }

    /**
     * Forms the matrix product with the transpose of a second matrix
     *
     * @param that is the other matrix
     * @return newly allocated matrix representing the matrix product of this and that.transpose()
     * @throws IllegalArgumentException if shapes are not conformant
     */
    public Matrix dotTranspose(Matrix that) {
        return dotTranspose(that, new Matrix(this.n, that.n));
    }

    /**
     * Forms the matrix product with the transpose of a second matrix
     * <p>
     * Caller supplies an object to contain the result. The result is written while the
     * operands are still being read, so result should be a different object from both this
     * and that.
     *
     * @param that is the other matrix
     * @param result is space to hold the result
     * @return result, filled with the matrix product of this and that.transpose()
     * @throws IllegalArgumentException if shapes are not conformant
     */
    public Matrix dotTranspose(Matrix that, Matrix result) {
        if (!(this.n == result.n && this.m == that.m && that.n == result.m)) {
            throw new IllegalArgumentException();
        }
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < that.n; j++) {
                double s = 0.0;
                for (int k = 0; k < m; k++) {
                    s += this.get(i, k) * that.get(j, k);
                }
                result.put(i, j, s);
            }
        }
        return result;
    }

    /**
     * Tests for equality
     * <p>
     * Two matrices are equal when they have the same shape and every pair of corresponding
     * coefficients compares equal with {@code ==}. Consequently 0.0 and -0.0 are treated as
     * equal, and a NaN coefficient is never equal to anything (a matrix is still equal to
     * itself, because identity is checked first).
     *
     * @param that is the object to compare with
     * @return true if that is a Matrix of the same shape with equal coefficients
     */
    @Override
    public boolean equals(Object that) {
        if (this == that) return true;
        if (!(that instanceof Matrix)) return false;
        Matrix other = (Matrix) that;
        if (n != other.n) return false;
        if (m != other.m) return false;
        for (int i = 0; i < mem.length; i++) {
            if (mem[i] != other.mem[i]) return false;
        }
        return true;
    }

    /**
     * Calculates a hash code
     * <p>
     * Consistent with {@link #equals(Object)}: matrices that compare equal produce the same
     * hash code, including when they differ only in the sign of a zero coefficient.
     *
     * @return hash code derived from the shape and the coefficients
     */
    @Override
    public int hashCode() {
        int h = n * 101 + m;
        for (int i = 0; i < mem.length; i++) {
            // equals() treats -0.0 and 0.0 as the same value, but Double.hashCode() does not,
            // so fold both zeros onto 0.0 before hashing.
            double v = (mem[i] == 0.0) ? 0.0 : mem[i];
            h = h * 37 + Double.hashCode(v);
        }
        return h;
    }

    /**
     * Makes a string representation
     * <p>
     * Coefficients within a row are separated by ", " and rows are separated by "; ".
     *
     * @return string like "[a, b; c, d]"
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(n * m * 8);
        sb.append("[");
        for (int i = 0; i < mem.length; i++) {
            if (i > 0) sb.append(i % m == 0 ? "; " : ", ");
            sb.append(mem[i]);
        }
        sb.append("]");
        return sb.toString();
    }

}