1 /*
2 * Copyright (c) 2014 The WebRTC project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #ifndef WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
12 #define WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
13
#include <algorithm>
#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdlib>
#include <cstring>
#include <sstream>
#include <string>
#include <vector>

#include "webrtc/base/checks.h"
#include "webrtc/base/constructormagic.h"
#include "webrtc/base/scoped_ptr.h"
22
namespace {

// Square-root helpers that let generic code take the root of any element
// type with a single spelling.
//
// NOTE(review): this unnamed namespace lives in a header, so every
// translation unit that includes the file gets its own copy of these
// helpers. It is kept as-is so that existing unqualified callers continue to
// resolve.

// Real-valued overload. std::sqrt has no overload that returns an integer
// type, so the argument is widened to double and the result is explicitly
// converted back to |T|. For integer |T| the fractional part is discarded.
template <typename T>
T sqrt_wrapper(T x) {
  return static_cast<T>(std::sqrt(static_cast<double>(x)));
}

// Complex overload. Forwards to std::sqrt for std::complex and returns its
// result unchanged.
template <typename S>
std::complex<S> sqrt_wrapper(std::complex<S> x) {
  return std::sqrt(x);
}

}  // namespace
38
39 namespace webrtc {
40
41 // Matrix is a class for doing standard matrix operations on 2 dimensional
42 // matrices of any size. Results of matrix operations are stored in the
43 // calling object. Function overloads exist for both in-place (the calling
44 // object is used as both an operand and the result) and out-of-place (all
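// operands are passed in as parameters) operations. If operand dimensions
// mismatch, the program crashes. Out-of-place operations that start by
// copying an operand (Scale, Add, Subtract and the Pointwise* family) resize
// the calling object as needed; out-of-place Transpose and Multiply instead
// require the calling object to already have the dimensions of the result.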
48 //
49 // 'In-place' operations that inherently change the size of the matrix (eg.
50 // Transpose, Multiply on different-sized matrices) must make temporary copies
51 // (|scratch_elements_| and |scratch_data_|) of existing data to complete the
52 // operations.
53 //
54 // The data is stored contiguously. Data can be accessed internally as a flat
55 // array, |data_|, or as an array of row pointers, |elements_|, but is
56 // available to users only as an array of row pointers through |elements()|.
57 // Memory for storage is allocated when a matrix is resized only if the new
58 // size overflows capacity. Memory needed temporarily for any operations is
59 // similarly resized only if the new size overflows capacity.
60 //
61 // If you pass in storage through the ctor, that storage is copied into the
62 // matrix. TODO(claguna): albeit tricky, allow for data to be referenced
63 // instead of copied, and owned by the user.
64 template <typename T>
65 class Matrix {
66 public:
Matrix()67 Matrix() : num_rows_(0), num_columns_(0) {}
68
69 // Allocates space for the elements and initializes all values to zero.
Matrix(size_t num_rows,size_t num_columns)70 Matrix(size_t num_rows, size_t num_columns)
71 : num_rows_(num_rows), num_columns_(num_columns) {
72 Resize();
73 scratch_data_.resize(num_rows_ * num_columns_);
74 scratch_elements_.resize(num_rows_);
75 }
76
77 // Copies |data| into the new Matrix.
Matrix(const T * data,size_t num_rows,size_t num_columns)78 Matrix(const T* data, size_t num_rows, size_t num_columns)
79 : num_rows_(0), num_columns_(0) {
80 CopyFrom(data, num_rows, num_columns);
81 scratch_data_.resize(num_rows_ * num_columns_);
82 scratch_elements_.resize(num_rows_);
83 }
84
~Matrix()85 virtual ~Matrix() {}
86
87 // Deep copy an existing matrix.
CopyFrom(const Matrix & other)88 void CopyFrom(const Matrix& other) {
89 CopyFrom(&other.data_[0], other.num_rows_, other.num_columns_);
90 }
91
92 // Copy |data| into the Matrix. The current data is lost.
CopyFrom(const T * const data,size_t num_rows,size_t num_columns)93 void CopyFrom(const T* const data, size_t num_rows, size_t num_columns) {
94 Resize(num_rows, num_columns);
95 memcpy(&data_[0], data, num_rows_ * num_columns_ * sizeof(data_[0]));
96 }
97
CopyFromColumn(const T * const * src,size_t column_index,size_t num_rows)98 Matrix& CopyFromColumn(const T* const* src,
99 size_t column_index,
100 size_t num_rows) {
101 Resize(1, num_rows);
102 for (size_t i = 0; i < num_columns_; ++i) {
103 data_[i] = src[i][column_index];
104 }
105
106 return *this;
107 }
108
Resize(size_t num_rows,size_t num_columns)109 void Resize(size_t num_rows, size_t num_columns) {
110 if (num_rows != num_rows_ || num_columns != num_columns_) {
111 num_rows_ = num_rows;
112 num_columns_ = num_columns;
113 Resize();
114 }
115 }
116
117 // Accessors and mutators.
num_rows()118 size_t num_rows() const { return num_rows_; }
num_columns()119 size_t num_columns() const { return num_columns_; }
elements()120 T* const* elements() { return &elements_[0]; }
elements()121 const T* const* elements() const { return &elements_[0]; }
122
Trace()123 T Trace() {
124 RTC_CHECK_EQ(num_rows_, num_columns_);
125
126 T trace = 0;
127 for (size_t i = 0; i < num_rows_; ++i) {
128 trace += elements_[i][i];
129 }
130 return trace;
131 }
132
133 // Matrix Operations. Returns *this to support method chaining.
Transpose()134 Matrix& Transpose() {
135 CopyDataToScratch();
136 Resize(num_columns_, num_rows_);
137 return Transpose(scratch_elements());
138 }
139
Transpose(const Matrix & operand)140 Matrix& Transpose(const Matrix& operand) {
141 RTC_CHECK_EQ(operand.num_rows_, num_columns_);
142 RTC_CHECK_EQ(operand.num_columns_, num_rows_);
143
144 return Transpose(operand.elements());
145 }
146
147 template <typename S>
Scale(const S & scalar)148 Matrix& Scale(const S& scalar) {
149 for (size_t i = 0; i < data_.size(); ++i) {
150 data_[i] *= scalar;
151 }
152
153 return *this;
154 }
155
156 template <typename S>
Scale(const Matrix & operand,const S & scalar)157 Matrix& Scale(const Matrix& operand, const S& scalar) {
158 CopyFrom(operand);
159 return Scale(scalar);
160 }
161
Add(const Matrix & operand)162 Matrix& Add(const Matrix& operand) {
163 RTC_CHECK_EQ(num_rows_, operand.num_rows_);
164 RTC_CHECK_EQ(num_columns_, operand.num_columns_);
165
166 for (size_t i = 0; i < data_.size(); ++i) {
167 data_[i] += operand.data_[i];
168 }
169
170 return *this;
171 }
172
Add(const Matrix & lhs,const Matrix & rhs)173 Matrix& Add(const Matrix& lhs, const Matrix& rhs) {
174 CopyFrom(lhs);
175 return Add(rhs);
176 }
177
Subtract(const Matrix & operand)178 Matrix& Subtract(const Matrix& operand) {
179 RTC_CHECK_EQ(num_rows_, operand.num_rows_);
180 RTC_CHECK_EQ(num_columns_, operand.num_columns_);
181
182 for (size_t i = 0; i < data_.size(); ++i) {
183 data_[i] -= operand.data_[i];
184 }
185
186 return *this;
187 }
188
Subtract(const Matrix & lhs,const Matrix & rhs)189 Matrix& Subtract(const Matrix& lhs, const Matrix& rhs) {
190 CopyFrom(lhs);
191 return Subtract(rhs);
192 }
193
PointwiseMultiply(const Matrix & operand)194 Matrix& PointwiseMultiply(const Matrix& operand) {
195 RTC_CHECK_EQ(num_rows_, operand.num_rows_);
196 RTC_CHECK_EQ(num_columns_, operand.num_columns_);
197
198 for (size_t i = 0; i < data_.size(); ++i) {
199 data_[i] *= operand.data_[i];
200 }
201
202 return *this;
203 }
204
PointwiseMultiply(const Matrix & lhs,const Matrix & rhs)205 Matrix& PointwiseMultiply(const Matrix& lhs, const Matrix& rhs) {
206 CopyFrom(lhs);
207 return PointwiseMultiply(rhs);
208 }
209
PointwiseDivide(const Matrix & operand)210 Matrix& PointwiseDivide(const Matrix& operand) {
211 RTC_CHECK_EQ(num_rows_, operand.num_rows_);
212 RTC_CHECK_EQ(num_columns_, operand.num_columns_);
213
214 for (size_t i = 0; i < data_.size(); ++i) {
215 data_[i] /= operand.data_[i];
216 }
217
218 return *this;
219 }
220
PointwiseDivide(const Matrix & lhs,const Matrix & rhs)221 Matrix& PointwiseDivide(const Matrix& lhs, const Matrix& rhs) {
222 CopyFrom(lhs);
223 return PointwiseDivide(rhs);
224 }
225
PointwiseSquareRoot()226 Matrix& PointwiseSquareRoot() {
227 for (size_t i = 0; i < data_.size(); ++i) {
228 data_[i] = sqrt_wrapper(data_[i]);
229 }
230
231 return *this;
232 }
233
PointwiseSquareRoot(const Matrix & operand)234 Matrix& PointwiseSquareRoot(const Matrix& operand) {
235 CopyFrom(operand);
236 return PointwiseSquareRoot();
237 }
238
PointwiseAbsoluteValue()239 Matrix& PointwiseAbsoluteValue() {
240 for (size_t i = 0; i < data_.size(); ++i) {
241 data_[i] = abs(data_[i]);
242 }
243
244 return *this;
245 }
246
PointwiseAbsoluteValue(const Matrix & operand)247 Matrix& PointwiseAbsoluteValue(const Matrix& operand) {
248 CopyFrom(operand);
249 return PointwiseAbsoluteValue();
250 }
251
PointwiseSquare()252 Matrix& PointwiseSquare() {
253 for (size_t i = 0; i < data_.size(); ++i) {
254 data_[i] *= data_[i];
255 }
256
257 return *this;
258 }
259
PointwiseSquare(const Matrix & operand)260 Matrix& PointwiseSquare(const Matrix& operand) {
261 CopyFrom(operand);
262 return PointwiseSquare();
263 }
264
Multiply(const Matrix & lhs,const Matrix & rhs)265 Matrix& Multiply(const Matrix& lhs, const Matrix& rhs) {
266 RTC_CHECK_EQ(lhs.num_columns_, rhs.num_rows_);
267 RTC_CHECK_EQ(num_rows_, lhs.num_rows_);
268 RTC_CHECK_EQ(num_columns_, rhs.num_columns_);
269
270 return Multiply(lhs.elements(), rhs.num_rows_, rhs.elements());
271 }
272
Multiply(const Matrix & rhs)273 Matrix& Multiply(const Matrix& rhs) {
274 RTC_CHECK_EQ(num_columns_, rhs.num_rows_);
275
276 CopyDataToScratch();
277 Resize(num_rows_, rhs.num_columns_);
278 return Multiply(scratch_elements(), rhs.num_rows_, rhs.elements());
279 }
280
ToString()281 std::string ToString() const {
282 std::ostringstream ss;
283 ss << std::endl << "Matrix" << std::endl;
284
285 for (size_t i = 0; i < num_rows_; ++i) {
286 for (size_t j = 0; j < num_columns_; ++j) {
287 ss << elements_[i][j] << " ";
288 }
289 ss << std::endl;
290 }
291 ss << std::endl;
292
293 return ss.str();
294 }
295
296 protected:
SetNumRows(const size_t num_rows)297 void SetNumRows(const size_t num_rows) { num_rows_ = num_rows; }
SetNumColumns(const size_t num_columns)298 void SetNumColumns(const size_t num_columns) { num_columns_ = num_columns; }
data()299 T* data() { return &data_[0]; }
data()300 const T* data() const { return &data_[0]; }
scratch_elements()301 const T* const* scratch_elements() const { return &scratch_elements_[0]; }
302
303 // Resize the matrix. If an increase in capacity is required, the current
304 // data is lost.
Resize()305 void Resize() {
306 size_t size = num_rows_ * num_columns_;
307 data_.resize(size);
308 elements_.resize(num_rows_);
309
310 for (size_t i = 0; i < num_rows_; ++i) {
311 elements_[i] = &data_[i * num_columns_];
312 }
313 }
314
315 // Copies data_ into scratch_data_ and updates scratch_elements_ accordingly.
CopyDataToScratch()316 void CopyDataToScratch() {
317 scratch_data_ = data_;
318 scratch_elements_.resize(num_rows_);
319
320 for (size_t i = 0; i < num_rows_; ++i) {
321 scratch_elements_[i] = &scratch_data_[i * num_columns_];
322 }
323 }
324
325 private:
326 size_t num_rows_;
327 size_t num_columns_;
328 std::vector<T> data_;
329 std::vector<T*> elements_;
330
331 // Stores temporary copies of |data_| and |elements_| for in-place operations
332 // where referring to original data is necessary.
333 std::vector<T> scratch_data_;
334 std::vector<T*> scratch_elements_;
335
336 // Helpers for Transpose and Multiply operations that unify in-place and
337 // out-of-place solutions.
Transpose(const T * const * src)338 Matrix& Transpose(const T* const* src) {
339 for (size_t i = 0; i < num_rows_; ++i) {
340 for (size_t j = 0; j < num_columns_; ++j) {
341 elements_[i][j] = src[j][i];
342 }
343 }
344
345 return *this;
346 }
347
Multiply(const T * const * lhs,size_t num_rows_rhs,const T * const * rhs)348 Matrix& Multiply(const T* const* lhs,
349 size_t num_rows_rhs,
350 const T* const* rhs) {
351 for (size_t row = 0; row < num_rows_; ++row) {
352 for (size_t col = 0; col < num_columns_; ++col) {
353 T cur_element = 0;
354 for (size_t i = 0; i < num_rows_rhs; ++i) {
355 cur_element += lhs[row][i] * rhs[i][col];
356 }
357
358 elements_[row][col] = cur_element;
359 }
360 }
361
362 return *this;
363 }
364
365 RTC_DISALLOW_COPY_AND_ASSIGN(Matrix);
366 };
367
368 } // namespace webrtc
369
370 #endif // WEBRTC_MODULES_AUDIO_PROCESSING_BEAMFORMER_MATRIX_H_
371