• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //===- Math.h - PBQP Vector and Matrix classes ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_CODEGEN_PBQP_MATH_H
10 #define LLVM_CODEGEN_PBQP_MATH_H
11 
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/STLExtras.h"
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <memory>
18 
19 namespace llvm {
20 namespace PBQP {
21 
/// Scalar type of every cost entry stored in a PBQP Vector or Matrix.
using PBQPNum = float;
23 
24 /// PBQP Vector class.
25 class Vector {
26   friend hash_code hash_value(const Vector &);
27 
28 public:
29   /// Construct a PBQP vector of the given size.
Vector(unsigned Length)30   explicit Vector(unsigned Length)
31     : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {}
32 
33   /// Construct a PBQP vector with initializer.
Vector(unsigned Length,PBQPNum InitVal)34   Vector(unsigned Length, PBQPNum InitVal)
35     : Length(Length), Data(std::make_unique<PBQPNum []>(Length)) {
36     std::fill(Data.get(), Data.get() + Length, InitVal);
37   }
38 
39   /// Copy construct a PBQP vector.
Vector(const Vector & V)40   Vector(const Vector &V)
41     : Length(V.Length), Data(std::make_unique<PBQPNum []>(Length)) {
42     std::copy(V.Data.get(), V.Data.get() + Length, Data.get());
43   }
44 
45   /// Move construct a PBQP vector.
Vector(Vector && V)46   Vector(Vector &&V)
47     : Length(V.Length), Data(std::move(V.Data)) {
48     V.Length = 0;
49   }
50 
51   /// Comparison operator.
52   bool operator==(const Vector &V) const {
53     assert(Length != 0 && Data && "Invalid vector");
54     if (Length != V.Length)
55       return false;
56     return std::equal(Data.get(), Data.get() + Length, V.Data.get());
57   }
58 
59   /// Return the length of the vector
getLength()60   unsigned getLength() const {
61     assert(Length != 0 && Data && "Invalid vector");
62     return Length;
63   }
64 
65   /// Element access.
66   PBQPNum& operator[](unsigned Index) {
67     assert(Length != 0 && Data && "Invalid vector");
68     assert(Index < Length && "Vector element access out of bounds.");
69     return Data[Index];
70   }
71 
72   /// Const element access.
73   const PBQPNum& operator[](unsigned Index) const {
74     assert(Length != 0 && Data && "Invalid vector");
75     assert(Index < Length && "Vector element access out of bounds.");
76     return Data[Index];
77   }
78 
79   /// Add another vector to this one.
80   Vector& operator+=(const Vector &V) {
81     assert(Length != 0 && Data && "Invalid vector");
82     assert(Length == V.Length && "Vector length mismatch.");
83     std::transform(Data.get(), Data.get() + Length, V.Data.get(), Data.get(),
84                    std::plus<PBQPNum>());
85     return *this;
86   }
87 
88   /// Returns the index of the minimum value in this vector
minIndex()89   unsigned minIndex() const {
90     assert(Length != 0 && Data && "Invalid vector");
91     return std::min_element(Data.get(), Data.get() + Length) - Data.get();
92   }
93 
94 private:
95   unsigned Length;
96   std::unique_ptr<PBQPNum []> Data;
97 };
98 
99 /// Return a hash_value for the given vector.
hash_value(const Vector & V)100 inline hash_code hash_value(const Vector &V) {
101   unsigned *VBegin = reinterpret_cast<unsigned*>(V.Data.get());
102   unsigned *VEnd = reinterpret_cast<unsigned*>(V.Data.get() + V.Length);
103   return hash_combine(V.Length, hash_combine_range(VBegin, VEnd));
104 }
105 
106 /// Output a textual representation of the given vector on the given
107 ///        output stream.
108 template <typename OStream>
109 OStream& operator<<(OStream &OS, const Vector &V) {
110   assert((V.getLength() != 0) && "Zero-length vector badness.");
111 
112   OS << "[ " << V[0];
113   for (unsigned i = 1; i < V.getLength(); ++i)
114     OS << ", " << V[i];
115   OS << " ]";
116 
117   return OS;
118 }
119 
120 /// PBQP Matrix class
121 class Matrix {
122 private:
123   friend hash_code hash_value(const Matrix &);
124 
125 public:
126   /// Construct a PBQP Matrix with the given dimensions.
Matrix(unsigned Rows,unsigned Cols)127   Matrix(unsigned Rows, unsigned Cols) :
128     Rows(Rows), Cols(Cols), Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
129   }
130 
131   /// Construct a PBQP Matrix with the given dimensions and initial
132   /// value.
Matrix(unsigned Rows,unsigned Cols,PBQPNum InitVal)133   Matrix(unsigned Rows, unsigned Cols, PBQPNum InitVal)
134     : Rows(Rows), Cols(Cols),
135       Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
136     std::fill(Data.get(), Data.get() + (Rows * Cols), InitVal);
137   }
138 
139   /// Copy construct a PBQP matrix.
Matrix(const Matrix & M)140   Matrix(const Matrix &M)
141     : Rows(M.Rows), Cols(M.Cols),
142       Data(std::make_unique<PBQPNum []>(Rows * Cols)) {
143     std::copy(M.Data.get(), M.Data.get() + (Rows * Cols), Data.get());
144   }
145 
146   /// Move construct a PBQP matrix.
Matrix(Matrix && M)147   Matrix(Matrix &&M)
148     : Rows(M.Rows), Cols(M.Cols), Data(std::move(M.Data)) {
149     M.Rows = M.Cols = 0;
150   }
151 
152   /// Comparison operator.
153   bool operator==(const Matrix &M) const {
154     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
155     if (Rows != M.Rows || Cols != M.Cols)
156       return false;
157     return std::equal(Data.get(), Data.get() + (Rows * Cols), M.Data.get());
158   }
159 
160   /// Return the number of rows in this matrix.
getRows()161   unsigned getRows() const {
162     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
163     return Rows;
164   }
165 
166   /// Return the number of cols in this matrix.
getCols()167   unsigned getCols() const {
168     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
169     return Cols;
170   }
171 
172   /// Matrix element access.
173   PBQPNum* operator[](unsigned R) {
174     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
175     assert(R < Rows && "Row out of bounds.");
176     return Data.get() + (R * Cols);
177   }
178 
179   /// Matrix element access.
180   const PBQPNum* operator[](unsigned R) const {
181     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
182     assert(R < Rows && "Row out of bounds.");
183     return Data.get() + (R * Cols);
184   }
185 
186   /// Returns the given row as a vector.
getRowAsVector(unsigned R)187   Vector getRowAsVector(unsigned R) const {
188     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
189     Vector V(Cols);
190     for (unsigned C = 0; C < Cols; ++C)
191       V[C] = (*this)[R][C];
192     return V;
193   }
194 
195   /// Returns the given column as a vector.
getColAsVector(unsigned C)196   Vector getColAsVector(unsigned C) const {
197     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
198     Vector V(Rows);
199     for (unsigned R = 0; R < Rows; ++R)
200       V[R] = (*this)[R][C];
201     return V;
202   }
203 
204   /// Matrix transpose.
transpose()205   Matrix transpose() const {
206     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
207     Matrix M(Cols, Rows);
208     for (unsigned r = 0; r < Rows; ++r)
209       for (unsigned c = 0; c < Cols; ++c)
210         M[c][r] = (*this)[r][c];
211     return M;
212   }
213 
214   /// Add the given matrix to this one.
215   Matrix& operator+=(const Matrix &M) {
216     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
217     assert(Rows == M.Rows && Cols == M.Cols &&
218            "Matrix dimensions mismatch.");
219     std::transform(Data.get(), Data.get() + (Rows * Cols), M.Data.get(),
220                    Data.get(), std::plus<PBQPNum>());
221     return *this;
222   }
223 
224   Matrix operator+(const Matrix &M) {
225     assert(Rows != 0 && Cols != 0 && Data && "Invalid matrix");
226     Matrix Tmp(*this);
227     Tmp += M;
228     return Tmp;
229   }
230 
231 private:
232   unsigned Rows, Cols;
233   std::unique_ptr<PBQPNum []> Data;
234 };
235 
236 /// Return a hash_code for the given matrix.
hash_value(const Matrix & M)237 inline hash_code hash_value(const Matrix &M) {
238   unsigned *MBegin = reinterpret_cast<unsigned*>(M.Data.get());
239   unsigned *MEnd =
240     reinterpret_cast<unsigned*>(M.Data.get() + (M.Rows * M.Cols));
241   return hash_combine(M.Rows, M.Cols, hash_combine_range(MBegin, MEnd));
242 }
243 
244 /// Output a textual representation of the given matrix on the given
245 ///        output stream.
246 template <typename OStream>
247 OStream& operator<<(OStream &OS, const Matrix &M) {
248   assert((M.getRows() != 0) && "Zero-row matrix badness.");
249   for (unsigned i = 0; i < M.getRows(); ++i)
250     OS << M.getRowAsVector(i) << "\n";
251   return OS;
252 }
253 
254 template <typename Metadata>
255 class MDVector : public Vector {
256 public:
MDVector(const Vector & v)257   MDVector(const Vector &v) : Vector(v), md(*this) {}
MDVector(Vector && v)258   MDVector(Vector &&v) : Vector(std::move(v)), md(*this) { }
259 
getMetadata()260   const Metadata& getMetadata() const { return md; }
261 
262 private:
263   Metadata md;
264 };
265 
266 template <typename Metadata>
hash_value(const MDVector<Metadata> & V)267 inline hash_code hash_value(const MDVector<Metadata> &V) {
268   return hash_value(static_cast<const Vector&>(V));
269 }
270 
271 template <typename Metadata>
272 class MDMatrix : public Matrix {
273 public:
MDMatrix(const Matrix & m)274   MDMatrix(const Matrix &m) : Matrix(m), md(*this) {}
MDMatrix(Matrix && m)275   MDMatrix(Matrix &&m) : Matrix(std::move(m)), md(*this) { }
276 
getMetadata()277   const Metadata& getMetadata() const { return md; }
278 
279 private:
280   Metadata md;
281 };
282 
283 template <typename Metadata>
hash_value(const MDMatrix<Metadata> & M)284 inline hash_code hash_value(const MDMatrix<Metadata> &M) {
285   return hash_value(static_cast<const Matrix&>(M));
286 }
287 
288 } // end namespace PBQP
289 } // end namespace llvm
290 
291 #endif // LLVM_CODEGEN_PBQP_MATH_H
292