• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 //
2 // Copyright (c) 2019-2020 Advanced Micro Devices, Inc. All rights reserved.
3 //
4 // Permission is hereby granted, free of charge, to any person obtaining a copy
5 // of this software and associated documentation files (the "Software"), to deal
6 // in the Software without restriction, including without limitation the rights
7 // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
8 // copies of the Software, and to permit persons to whom the Software is
9 // furnished to do so, subject to the following conditions:
10 //
11 // The above copyright notice and this permission notice shall be included in
12 // all copies or substantial portions of the Software.
13 //
14 // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
15 // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
16 // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
17 // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
18 // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
19 // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
20 // THE SOFTWARE.
21 //
22 
23 #pragma once
24 
25 #define WIN32_LEAN_AND_MEAN
26 #define NOMINMAX
27 
#include <dxgi1_4.h>
#include <d3d12.h>

#include <Windows.h>
#include <atlbase.h> // For CComPtr

#include <iostream>
#include <fstream>
#include <vector>
#include <memory>
#include <algorithm>
#include <numeric>
#include <array>
#include <type_traits>
#include <utility>
#include <chrono>
#include <string>
#include <exception>
#include <stdexcept> // For std::runtime_error

#include <cassert>
#include <cmath>     // For std::sqrt, std::sin, std::cos, std::tan
#include <cstdint>   // For uint32_t
#include <cstdlib>
#include <cstdio>
#include <cstdarg>
51 
// Short global aliases for the time types of std::chrono::high_resolution_clock.
using time_point = std::chrono::high_resolution_clock::time_point;
using duration = std::chrono::high_resolution_clock::duration;
54 
// Two-level stringification: the outer macro expands its argument first,
// the inner one turns the expanded tokens into a string literal.
#define STRINGIZE2(x) #x
#define STRINGIZE(x) STRINGIZE2(x)

// The current source line number as a narrow string literal.
#define LINE_STRING STRINGIZE(__LINE__)

// Reports an unrecoverable error: trips assert() when assertions are enabled,
// then throws std::runtime_error carrying msg.
#define FAIL(msg) \
    do \
    { \
        assert(0 && msg); \
        throw std::runtime_error(msg); \
    } while(false)

// Calls FAIL with "file(line): !( expr )" when expr evaluates to false.
#define CHECK_BOOL(expr) \
    do \
    { \
        if(!(expr)) \
            FAIL(__FILE__ "(" LINE_STRING "): !( " #expr " )"); \
    } while(false)

// Calls FAIL with "file(line): FAILED( expr )" when FAILED(expr) is true.
#define CHECK_HR(expr) \
    do \
    { \
        if(FAILED(expr)) \
            FAIL(__FILE__ "(" LINE_STRING "): FAILED( " #expr " )"); \
    } while(false)
65 
// Integer division of x by y, rounded towards positive infinity.
//
// T must be an integral type and y must be positive.
// The result is computed from the truncated quotient and the remainder rather
// than from (x + y - 1) / y, so it cannot overflow when x is close to the
// largest value of T, and it is also exact when x is a negative multiple of y.
template <typename T>
inline constexpr T CeilDiv(T x, T y)
{
    static_assert(std::is_integral<T>::value, "CeilDiv requires an integral type.");
    // Truncating division already rounds up for negative quotients; only a
    // positive remainder means the quotient was rounded down.
    return static_cast<T>(x / y + ((x % y > 0) ? 1 : 0));
}
// Divides x by y after adding half of y to x, which rounds the quotient to
// the nearest value for non-negative integer operands.
template <typename T>
inline constexpr T RoundDiv(T x, T y)
{
    return (x + y / static_cast<T>(2)) / y;
}
76 
// Rounds val up to the nearest multiple of align.
// align does not have to be a power of two, because the result is computed
// with a division and a multiplication rather than with bit masks.
template <typename T>
inline constexpr T AlignUp(T val, T align)
{
    return ((val + align - 1) / align) * align;
}
82 
// Pi as a single-precision constant. The literal carries more digits than
// float can hold; it is rounded to the nearest representable value.
static constexpr float PI = 3.14159265358979323846264338327950288419716939937510582f;
84 
85 static const D3D12_RANGE EMPTY_RANGE = {0, 0};
86 
// Two-component float vector.
struct vec2
{
    float x, y;

    // Leaves x and y uninitialized.
    vec2() { }
    vec2(float x, float y) : x(x), y(y) { }

    // Component access by index: 0 -> x, 1 -> y.
    // The component is selected explicitly instead of via pointer arithmetic
    // on &x, which is undefined behavior across distinct data members.
    float& operator[](uint32_t index)
    {
        assert(index < 2);
        return index == 0 ? x : y;
    }
    const float& operator[](uint32_t index) const
    {
        assert(index < 2);
        return index == 0 ? x : y;
    }

    // Component-wise sum and difference, and multiplication by a scalar.
    vec2 operator+(const vec2& rhs) const { return vec2(x + rhs.x, y + rhs.y); }
    vec2 operator-(const vec2& rhs) const { return vec2(x - rhs.x, y - rhs.y); }
    vec2 operator*(float s) const { return vec2(x * s, y * s); }

    // Returns this vector scaled to unit length.
    // The length is not checked: a zero-length vector divides by zero.
    vec2 Normalized() const
    {
        return (*this) * (1.f / std::sqrt(x * x + y * y));
    }
};
106 
// Three-component float vector.
struct vec3
{
    float x, y, z;

    // Leaves x, y and z uninitialized.
    vec3() { }
    vec3(float x, float y, float z) : x(x), y(y), z(z) { }

    // Component access by index: 0 -> x, 1 -> y, 2 -> z.
    // The component is selected explicitly instead of via pointer arithmetic
    // on &x, which is undefined behavior across distinct data members.
    float& operator[](uint32_t index)
    {
        assert(index < 3);
        return index == 0 ? x : (index == 1 ? y : z);
    }
    const float& operator[](uint32_t index) const
    {
        assert(index < 3);
        return index == 0 ? x : (index == 1 ? y : z);
    }

    // Component-wise sum and difference, and multiplication by a scalar.
    vec3 operator+(const vec3& rhs) const { return vec3(x + rhs.x, y + rhs.y, z + rhs.z); }
    vec3 operator-(const vec3& rhs) const { return vec3(x - rhs.x, y - rhs.y, z - rhs.z); }
    vec3 operator*(float s) const { return vec3(x * s, y * s, z * s); }

    // Returns this vector scaled to unit length.
    // The length is not checked: a zero-length vector divides by zero.
    vec3 Normalized() const
    {
        return (*this) * (1.f / std::sqrt(x * x + y * y + z * z));
    }
};
126 
Dot(const vec3 & lhs,const vec3 & rhs)127 inline float Dot(const vec3& lhs, const vec3& rhs)
128 {
129     return lhs.x * rhs.x + lhs.y * rhs.y + lhs.z * rhs.z;
130 }
// Cross (vector) product lhs x rhs of two 3D vectors.
inline vec3 Cross(const vec3& lhs, const vec3& rhs)
{
    const float cx = lhs.y * rhs.z - lhs.z * rhs.y;
    const float cy = lhs.z * rhs.x - lhs.x * rhs.z;
    const float cz = lhs.x * rhs.y - lhs.y * rhs.x;
    return vec3(cx, cy, cz);
}
138 
139 struct vec4
140 {
141     float x, y, z, w;
142 
vec4vec4143     vec4() { }
vec4vec4144     vec4(float x, float y, float z, float w) : x(x), y(y), z(z), w(w) { }
vec4vec4145     vec4(const vec3& v, float w) : x(v.x), y(v.y), z(v.z), w(w) { }
146 
147     float& operator[](uint32_t index) { return *(&x + index); }
148     const float& operator[](uint32_t index) const { return *(&x + index); }
149 
150     vec4 operator+(const vec4& rhs) const { return vec4(x + rhs.x, y + rhs.y, z + rhs.z, w + rhs.w); }
151     vec4 operator-(const vec4& rhs) const { return vec4(x - rhs.x, y - rhs.y, z - rhs.z, w - rhs.w); }
152     vec4 operator*(float s) const { return vec4(x * s, y * s, z * s, w * s); }
153 };
154 
155 struct mat4
156 {
157     union
158     {
159         struct
160         {
161             float _11, _12, _13, _14;
162             float _21, _22, _23, _24;
163             float _31, _32, _33, _34;
164             float _41, _42, _43, _44;
165         };
166         float m[4][4]; // [row][column]
167     };
168 
mat4mat4169     mat4() { }
170 
mat4mat4171     mat4(
172         float _11, float _12, float _13, float _14,
173         float _21, float _22, float _23, float _24,
174         float _31, float _32, float _33, float _34,
175         float _41, float _42, float _43, float _44) :
176         _11(_11), _12(_12), _13(_13), _14(_14),
177         _21(_21), _22(_22), _23(_23), _24(_24),
178         _31(_31), _32(_32), _33(_33), _34(_34),
179         _41(_41), _42(_42), _43(_43), _44(_44)
180     {
181     }
182 
mat4mat4183     mat4(
184         const vec4& row1,
185         const vec4& row2,
186         const vec4& row3,
187         const vec4& row4) :
188         _11(row1.x), _12(row1.y), _13(row1.z), _14(row1.w),
189         _21(row2.x), _22(row2.y), _23(row2.z), _24(row2.w),
190         _31(row3.x), _32(row3.y), _33(row3.z), _34(row3.w),
191         _41(row4.x), _42(row4.y), _43(row4.z), _44(row4.w)
192     {
193     }
194 
mat4mat4195     mat4(const float* data) :
196         _11(data[ 0]), _12(data[ 1]), _13(data[ 2]), _14(data[ 3]),
197         _21(data[ 4]), _22(data[ 5]), _23(data[ 6]), _24(data[ 7]),
198         _31(data[ 8]), _32(data[ 9]), _33(data[10]), _34(data[11]),
199         _41(data[12]), _42(data[13]), _43(data[14]), _44(data[15])
200     {
201     }
202 
203     mat4 operator*(const mat4 &rhs) const
204     {
205         return mat4(
206             _11 * rhs._11 + _12 * rhs._21 + _13 * rhs._31 + _14 * rhs._41,
207             _11 * rhs._12 + _12 * rhs._22 + _13 * rhs._32 + _14 * rhs._42,
208             _11 * rhs._13 + _12 * rhs._23 + _13 * rhs._33 + _14 * rhs._43,
209             _11 * rhs._14 + _12 * rhs._24 + _13 * rhs._34 + _14 * rhs._44,
210 
211             _21 * rhs._11 + _22 * rhs._21 + _23 * rhs._31 + _24 * rhs._41,
212             _21 * rhs._12 + _22 * rhs._22 + _23 * rhs._32 + _24 * rhs._42,
213             _21 * rhs._13 + _22 * rhs._23 + _23 * rhs._33 + _24 * rhs._43,
214             _21 * rhs._14 + _22 * rhs._24 + _23 * rhs._34 + _24 * rhs._44,
215 
216             _31 * rhs._11 + _32 * rhs._21 + _33 * rhs._31 + _34 * rhs._41,
217             _31 * rhs._12 + _32 * rhs._22 + _33 * rhs._32 + _34 * rhs._42,
218             _31 * rhs._13 + _32 * rhs._23 + _33 * rhs._33 + _34 * rhs._43,
219             _31 * rhs._14 + _32 * rhs._24 + _33 * rhs._34 + _34 * rhs._44,
220 
221             _41 * rhs._11 + _42 * rhs._21 + _43 * rhs._31 + _44 * rhs._41,
222             _41 * rhs._12 + _42 * rhs._22 + _43 * rhs._32 + _44 * rhs._42,
223             _41 * rhs._13 + _42 * rhs._23 + _43 * rhs._33 + _44 * rhs._43,
224             _41 * rhs._14 + _42 * rhs._24 + _43 * rhs._34 + _44 * rhs._44);
225     }
226 
Identitymat4227     static mat4 Identity()
228     {
229         return mat4(
230             1.f, 0.f, 0.f, 0.f,
231             0.f, 1.f, 0.f, 0.f,
232             0.f, 0.f, 1.f, 0.f,
233             0.f, 0.f, 0.f, 1.f);
234     }
235 
Translationmat4236     static mat4 Translation(const vec3& v)
237     {
238         return mat4(
239             1.f, 0.f, 0.f, 0.f,
240             0.f, 1.f, 0.f, 0.f,
241             0.f, 0.f, 1.f, 0.f,
242             v.x, v.y, v.z, 1.f);
243     }
244 
Scalingmat4245     static mat4 Scaling(float s)
246     {
247         return mat4(
248             s,   0.f, 0.f, 0.f,
249             0.f, s,   0.f, 0.f,
250             0.f, 0.f, s,   0.f,
251             0.f, 0.f, 0.f, 1.f);
252     }
253 
Scalingmat4254     static mat4 Scaling(const vec3& s)
255     {
256         return mat4(
257             s.x, 0.f, 0.f, 0.f,
258             0.f, s.y, 0.f, 0.f,
259             0.f, 0.f, s.z, 0.f,
260             0.f, 0.f, 0.f, 1.f);
261     }
262 
RotationXmat4263     static mat4 RotationX(float angle)
264     {
265         const float s = sin(angle), c = cos(angle);
266         return mat4(
267             1.f, 0.f, 0.f, 0.f,
268             0.f, c,   s,   0.f,
269             0.f, -s,  c,   0.f,
270             0.f, 0.f, 0.f, 1.f);
271     }
272 
RotationYmat4273     static mat4 RotationY(float angle)
274     {
275         const float s = sin(angle), c = cos(angle);
276         return mat4(
277             c,   s,  0.f,  0.f,
278             -s,  c,  0.f,  0.f,
279             0.f, 0.f, 1.f, 0.f,
280             0.f, 0.f, 0.f, 1.f);
281     }
282 
RotationZmat4283     static mat4 RotationZ(float angle)
284     {
285         const float s = sin(angle), c = cos(angle);
286         return mat4(
287             c,   0.f, -s,  0.f,
288             0.f, 1.f, 0.f, 0.f,
289             s,   0.f, c,   0.f,
290             0.f, 0.f, 0.f, 1.f);
291     }
292 
Perspectivemat4293     static mat4 Perspective(float fovY, float aspectRatio, float zNear, float zFar)
294     {
295         float yScale = 1.0f / tan(fovY * 0.5f);
296         float xScale = yScale / aspectRatio;
297         return mat4(
298             xScale, 0.0f, 0.0f, 0.0f,
299             0.0f, yScale, 0.0f, 0.0f,
300             0.0f, 0.0f, zFar / (zFar - zNear), 1.0f,
301             0.0f, 0.0f, -zNear * zFar / (zFar - zNear), 0.0f);
302     }
303 
LookAtmat4304     static mat4 LookAt(vec3 at, vec3 eye, vec3 up)
305     {
306         vec3 zAxis = (at - eye).Normalized();
307         vec3 xAxis = Cross(up, zAxis).Normalized();
308         vec3 yAxis = Cross(zAxis, xAxis);
309         return mat4(
310             xAxis.x, yAxis.x, zAxis.x, 0.0f,
311             xAxis.y, yAxis.y, zAxis.y, 0.0f,
312             xAxis.z, yAxis.z, zAxis.z, 0.0f,
313             -Dot(xAxis, eye), -Dot(yAxis, eye), -Dot(zAxis, eye), 1.0f);
314     }
315 
Transposedmat4316     mat4 Transposed() const
317     {
318         return mat4(
319             _11, _21, _31, _41,
320             _12, _22, _32, _42,
321             _13, _23, _33, _43,
322             _14, _24, _34, _44);
323     }
324 };
325 
326 class RandomNumberGenerator
327 {
328 public:
RandomNumberGenerator()329     RandomNumberGenerator() : m_Value{GetTickCount()} {}
RandomNumberGenerator(uint32_t seed)330     RandomNumberGenerator(uint32_t seed) : m_Value{seed} { }
Seed(uint32_t seed)331     void Seed(uint32_t seed) { m_Value = seed; }
Generate()332     uint32_t Generate() { return GenerateFast() ^ (GenerateFast() >> 7); }
GenerateBool()333     bool GenerateBool() { return (GenerateFast() & 0x4) != 0; }
334 
335 private:
336     uint32_t m_Value;
GenerateFast()337     uint32_t GenerateFast() { return m_Value = (m_Value * 196314165 + 907633515); }
338 };
339 
340 // Wrapper for RandomNumberGenerator compatible with STL "UniformRandomNumberGenerator" idea.
341 struct MyUniformRandomNumberGenerator
342 {
343     typedef uint32_t result_type;
MyUniformRandomNumberGeneratorMyUniformRandomNumberGenerator344     MyUniformRandomNumberGenerator(RandomNumberGenerator& gen) : m_Gen(gen) { }
operatorMyUniformRandomNumberGenerator345     uint32_t operator()() { return m_Gen.Generate(); }
346 
347 private:
348     RandomNumberGenerator& m_Gen;
349 };
350 
// File helpers; only declared here, the definitions are not part of this header.
// NOTE(review): judging by the names and signatures, ReadFile presumably loads
// the contents of fileName into out, and SaveFile presumably writes dataSize
// bytes starting at data to filePath - confirm against the definitions,
// including how failures are reported.
void ReadFile(std::vector<char>& out, const wchar_t* fileName);
void SaveFile(const wchar_t* filePath, const void* data, size_t dataSize);
353 
// Message categories accepted by SetConsoleColor and the PrintMessage family.
enum class CONSOLE_COLOR
{
    INFO    = 0,
    NORMAL  = 1,
    WARNING = 2,
    // NOTE(review): the trailing underscore presumably avoids a clash with
    // the ERROR macro from the Windows headers - confirm.
    ERROR_  = 3,
    // Equals the number of enumerators declared before it.
    COUNT   = 4
};
362 
363 void SetConsoleColor(CONSOLE_COLOR color);
364 
365 void PrintMessage(CONSOLE_COLOR color, const char* msg);
366 void PrintMessage(CONSOLE_COLOR color, const wchar_t* msg);
367 
Print(const char * msg)368 inline void Print(const char* msg) { PrintMessage(CONSOLE_COLOR::NORMAL, msg); }
Print(const wchar_t * msg)369 inline void Print(const wchar_t* msg) { PrintMessage(CONSOLE_COLOR::NORMAL, msg); }
PrintWarning(const char * msg)370 inline void PrintWarning(const char* msg) { PrintMessage(CONSOLE_COLOR::WARNING, msg); }
PrintWarning(const wchar_t * msg)371 inline void PrintWarning(const wchar_t* msg) { PrintMessage(CONSOLE_COLOR::WARNING, msg); }
PrintError(const char * msg)372 inline void PrintError(const char* msg) { PrintMessage(CONSOLE_COLOR::ERROR_, msg); }
PrintError(const wchar_t * msg)373 inline void PrintError(const wchar_t* msg) { PrintMessage(CONSOLE_COLOR::ERROR_, msg); }
374 
375 void PrintMessageV(CONSOLE_COLOR color, const char* format, va_list argList);
376 void PrintMessageV(CONSOLE_COLOR color, const wchar_t* format, va_list argList);
377 void PrintMessageF(CONSOLE_COLOR color, const char* format, ...);
378 void PrintMessageF(CONSOLE_COLOR color, const wchar_t* format, ...);
379 void PrintWarningF(const char* format, ...);
380 void PrintWarningF(const wchar_t* format, ...);
381 void PrintErrorF(const char* format, ...);
382 void PrintErrorF(const wchar_t* format, ...);
383 
384