1 /*
2  * Copyright © Microsoft Corporation
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  */
23 
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>

#include <initializer_list>
#include <memory>
#include <stdexcept>
#include <utility>
#include <vector>

#include <directx/d3d12.h>
#include <dxgi1_4.h>
#include <gtest/gtest.h>
#include <wrl.h>
32 
33 #include "clc_compiler.h"
34 
35 using std::runtime_error;
36 using Microsoft::WRL::ComPtr;
37 
38 inline D3D12_CPU_DESCRIPTOR_HANDLE
offset_cpu_handle(D3D12_CPU_DESCRIPTOR_HANDLE handle,UINT offset)39 offset_cpu_handle(D3D12_CPU_DESCRIPTOR_HANDLE handle, UINT offset)
40 {
41    handle.ptr += offset;
42    return handle;
43 }
44 
/*
 * Round `value` up to the next multiple of `alignment`.
 * `alignment` must be non-zero (asserted). A value that is already a
 * multiple is returned unchanged. There is no overflow check: values
 * within `alignment - 1` of SIZE_MAX wrap around.
 */
inline size_t
align(size_t value, unsigned alignment)
{
   assert(alignment > 0);
   /* Bump past the next boundary, then truncate back down onto it. */
   const size_t bumped = value + (alignment - 1);
   const size_t whole_units = bumped / alignment;
   return whole_units * alignment;
}
51 
52 class ComputeTest : public ::testing::Test {
53 protected:
54    struct Shader {
55       std::shared_ptr<struct clc_binary> obj;
56       std::shared_ptr<struct clc_parsed_spirv> metadata;
57       std::shared_ptr<struct clc_dxil_object> dxil;
58    };
59 
60    static void
61    enable_d3d12_debug_layer();
62 
63    static IDXGIFactory4 *
64    get_dxgi_factory();
65 
66    static IDXGIAdapter1 *
67    choose_adapter(IDXGIFactory4 *factory);
68 
69    static ID3D12Device *
70    create_device(IDXGIAdapter1 *adapter);
71 
72    struct Resources {
addResources73       void add(ComPtr<ID3D12Resource> res,
74                D3D12_DESCRIPTOR_RANGE_TYPE type,
75                unsigned spaceid,
76                unsigned resid)
77       {
78          descs.push_back(res);
79 
80          if(!ranges.empty() &&
81             ranges.back().RangeType == type &&
82             ranges.back().RegisterSpace == spaceid &&
83             ranges.back().BaseShaderRegister + ranges.back().NumDescriptors == resid) {
84             ranges.back().NumDescriptors++;
85             return;
86          }
87 
88          D3D12_DESCRIPTOR_RANGE1 range;
89 
90          range.RangeType = type;
91          range.NumDescriptors = 1;
92          range.BaseShaderRegister = resid;
93          range.RegisterSpace = spaceid;
94          range.OffsetInDescriptorsFromTableStart = descs.size() - 1;
95          range.Flags = D3D12_DESCRIPTOR_RANGE_FLAG_DESCRIPTORS_STATIC_KEEPING_BUFFER_BOUNDS_CHECKS;
96          ranges.push_back(range);
97       }
98 
99       std::vector<D3D12_DESCRIPTOR_RANGE1> ranges;
100       std::vector<ComPtr<ID3D12Resource>> descs;
101    };
102 
103    ComPtr<ID3D12RootSignature>
104    create_root_signature(const Resources &resources);
105 
106    ComPtr<ID3D12PipelineState>
107    create_pipeline_state(ComPtr<ID3D12RootSignature> &root_sig,
108                          const struct clc_dxil_object &dxil);
109 
110    ComPtr<ID3D12Resource>
111    create_buffer(int size, D3D12_HEAP_TYPE heap_type);
112 
113    ComPtr<ID3D12Resource>
114    create_upload_buffer_with_data(const void *data, size_t size);
115 
116    ComPtr<ID3D12Resource>
117    create_sized_buffer_with_data(size_t buffer_size, const void *data,
118                                  size_t data_size);
119 
120    ComPtr<ID3D12Resource>
create_buffer_with_data(const void * data,size_t size)121    create_buffer_with_data(const void *data, size_t size)
122    {
123       return create_sized_buffer_with_data(size, data, size);
124    }
125 
126    void
127    get_buffer_data(ComPtr<ID3D12Resource> res,
128                    void *buf, size_t size);
129 
130    void
131    resource_barrier(ComPtr<ID3D12Resource> &res,
132                     D3D12_RESOURCE_STATES state_before,
133                     D3D12_RESOURCE_STATES state_after);
134 
135    void
136    execute_cmdlist();
137 
138    void
139    create_uav_buffer(ComPtr<ID3D12Resource> res,
140                      size_t width, size_t byte_stride,
141                      D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle);
142 
143    void create_cbv(ComPtr<ID3D12Resource> res, size_t size,
144                    D3D12_CPU_DESCRIPTOR_HANDLE cpu_handle);
145 
146    ComPtr<ID3D12Resource>
147    add_uav_resource(Resources &resources, unsigned spaceid, unsigned resid,
148                     const void *data = NULL, size_t num_elems = 0,
149                     size_t elem_size = 0);
150 
151    ComPtr<ID3D12Resource>
152    add_cbv_resource(Resources &resources, unsigned spaceid, unsigned resid,
153                     const void *data, size_t size);
154 
155    void
156    SetUp() override;
157 
158    void
159    TearDown() override;
160 
161    Shader
162    compile(const std::vector<const char *> &sources,
163            const std::vector<const char *> &compile_args = {},
164            bool create_library = false);
165 
166    Shader
167    link(const std::vector<Shader> &sources,
168         bool create_library = false);
169 
170    Shader
171    assemble(const char *source);
172 
173    void
174    configure(Shader &shader,
175              const struct clc_runtime_kernel_conf *conf);
176 
177    void
178    validate(Shader &shader);
179 
180    template <typename T>
181    Shader
specialize(Shader & shader,uint32_t id,T const & val)182    specialize(Shader &shader, uint32_t id, T const& val)
183    {
184       Shader new_shader;
185       new_shader.obj = std::shared_ptr<clc_binary>(new clc_binary{}, [](clc_binary *spirv)
186          {
187             clc_free_spirv(spirv);
188             delete spirv;
189          });
190       if (!shader.metadata)
191          configure(shader, NULL);
192 
193       clc_spirv_specialization spec;
194       spec.id = id;
195       memcpy(&spec.value, &val, sizeof(val));
196       clc_spirv_specialization_consts consts;
197       consts.specializations = &spec;
198       consts.num_specializations = 1;
199       if (!clc_specialize_spirv(shader.obj.get(), shader.metadata.get(), &consts, new_shader.obj.get()))
200          throw runtime_error("failed to specialize");
201 
202       configure(new_shader, NULL);
203 
204       return new_shader;
205    }
206 
207    enum ShaderArgDirection {
208       SHADER_ARG_INPUT = 1,
209       SHADER_ARG_OUTPUT = 2,
210       SHADER_ARG_INOUT = SHADER_ARG_INPUT | SHADER_ARG_OUTPUT,
211    };
212 
213    class RawShaderArg {
214    public:
RawShaderArg(enum ShaderArgDirection dir)215       RawShaderArg(enum ShaderArgDirection dir) : dir(dir) { }
216       virtual size_t get_elem_size() const = 0;
217       virtual size_t get_num_elems() const = 0;
218       virtual const void *get_data() const = 0;
219       virtual void *get_data() = 0;
get_direction()220       enum ShaderArgDirection get_direction() { return dir; }
221    private:
222       enum ShaderArgDirection dir;
223    };
224 
225    class NullShaderArg : public RawShaderArg {
226    public:
NullShaderArg()227       NullShaderArg() : RawShaderArg(SHADER_ARG_INPUT) { }
get_elem_size()228       size_t get_elem_size() const override { return 0; }
get_num_elems()229       size_t get_num_elems() const override { return 0; }
get_data()230       const void *get_data() const override { return NULL; }
get_data()231       void *get_data() override { return NULL; }
232    };
233 
234    template <typename T>
235    class ShaderArg : public std::vector<T>, public RawShaderArg
236    {
237    public:
238       ShaderArg(const T &v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) :
239          std::vector<T>({ v }), RawShaderArg(dir) { }
240       ShaderArg(const std::vector<T> &v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) :
241          std::vector<T>(v), RawShaderArg(dir) { }
242       ShaderArg(const std::initializer_list<T> v, enum ShaderArgDirection dir = SHADER_ARG_INOUT) :
243          std::vector<T>(v), RawShaderArg(dir) { }
244 
245       ShaderArg<T>& operator =(const T &v)
246       {
247          this->clear();
248          this->push_back(v);
249          return *this;
250       }
251 
252       operator T&() { return this->at(0); }
253       operator const T&() const { return this->at(0); }
254 
255       ShaderArg<T>& operator =(const std::vector<T> &v)
256       {
257          *this = v;
258          return *this;
259       }
260 
261       ShaderArg<T>& operator =(std::initializer_list<T> v)
262       {
263          *this = v;
264          return *this;
265       }
266 
get_elem_size()267       size_t get_elem_size() const override { return sizeof(T); }
get_num_elems()268       size_t get_num_elems() const override { return this->size(); }
get_data()269       const void *get_data() const override { return this->data(); }
get_data()270       void *get_data() override { return this->data(); }
271    };
272 
273    struct CompileArgs
274    {
275       unsigned x, y, z;
276       std::vector<const char *> compiler_command_line;
277       clc_work_properties_data work_props;
278    };
279 
280 private:
gather_args(std::vector<RawShaderArg * > & args)281    void gather_args(std::vector<RawShaderArg *> &args) { }
282 
283    template <typename T, typename... Rest>
gather_args(std::vector<RawShaderArg * > & args,T & arg,Rest &...rest)284    void gather_args(std::vector<RawShaderArg *> &args, T &arg, Rest&... rest)
285    {
286       args.push_back(&arg);
287       gather_args(args, rest...);
288    }
289 
290    void run_shader_with_raw_args(Shader shader,
291                                  const CompileArgs &compile_args,
292                                  const std::vector<RawShaderArg *> &args);
293 
294 protected:
295    template <typename... Args>
run_shader(Shader shader,const CompileArgs & compile_args,Args &...args)296    void run_shader(Shader shader,
297                    const CompileArgs &compile_args,
298                    Args&... args)
299    {
300       std::vector<RawShaderArg *> raw_args;
301       gather_args(raw_args, args...);
302       run_shader_with_raw_args(shader, compile_args, raw_args);
303    }
304 
305    template <typename... Args>
run_shader(const std::vector<const char * > & sources,unsigned x,unsigned y,unsigned z,Args &...args)306    void run_shader(const std::vector<const char *> &sources,
307                    unsigned x, unsigned y, unsigned z,
308                    Args&... args)
309    {
310       std::vector<RawShaderArg *> raw_args;
311       gather_args(raw_args, args...);
312       CompileArgs compile_args = { x, y, z };
313       run_shader_with_raw_args(compile(sources), compile_args, raw_args);
314    }
315 
316    template <typename... Args>
run_shader(const std::vector<const char * > & sources,const CompileArgs & compile_args,Args &...args)317    void run_shader(const std::vector<const char *> &sources,
318                    const CompileArgs &compile_args,
319                    Args&... args)
320    {
321       std::vector<RawShaderArg *> raw_args;
322       gather_args(raw_args, args...);
323       run_shader_with_raw_args(
324          compile(sources, compile_args.compiler_command_line),
325          compile_args, raw_args);
326    }
327 
328    template <typename... Args>
run_shader(const char * source,unsigned x,unsigned y,unsigned z,Args &...args)329    void run_shader(const char *source,
330                    unsigned x, unsigned y, unsigned z,
331                    Args&... args)
332    {
333       std::vector<RawShaderArg *> raw_args;
334       gather_args(raw_args, args...);
335       CompileArgs compile_args = { x, y, z };
336       run_shader_with_raw_args(compile({ source }), compile_args, raw_args);
337    }
338 
339    IDXGIFactory4 *factory;
340    IDXGIAdapter1 *adapter;
341    ID3D12Device *dev;
342    ID3D12Fence *cmdqueue_fence;
343    ID3D12CommandQueue *cmdqueue;
344    ID3D12CommandAllocator *cmdalloc;
345    ID3D12GraphicsCommandList *cmdlist;
346    ID3D12DescriptorHeap *uav_heap;
347 
348    struct clc_libclc *compiler_ctx;
349 
350    UINT uav_heap_incr;
351    int fence_value;
352 
353    HANDLE event;
354    static PFN_D3D12_SERIALIZE_VERSIONED_ROOT_SIGNATURE D3D12SerializeVersionedRootSignature;
355 };
356