• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright 2020 The Dawn Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "dawn_native/QueryHelper.h"
16 
17 #include "dawn_native/BindGroup.h"
18 #include "dawn_native/BindGroupLayout.h"
19 #include "dawn_native/Buffer.h"
20 #include "dawn_native/CommandEncoder.h"
21 #include "dawn_native/ComputePassEncoder.h"
22 #include "dawn_native/ComputePipeline.h"
23 #include "dawn_native/Device.h"
24 #include "dawn_native/InternalPipelineStore.h"
25 #include "dawn_native/utils/WGPUHelpers.h"
26 
27 namespace dawn_native {
28 
29     namespace {
30 
31         // Pin the byte offsets of the fields of dawn_native::TimestampParams (first = 0,
           // count = 4, offset = 8, period = 12) so that the C++ struct keeps the same field order
           // as the TimestampParams struct declared in the shader below (first, count and offset
           // as u32, then period as f32).
32         static_assert(offsetof(dawn_native::TimestampParams, first) == 0, "");
33         static_assert(offsetof(dawn_native::TimestampParams, count) == 4, "");
34         static_assert(offsetof(dawn_native::TimestampParams, offset) == 8, "");
35         static_assert(offsetof(dawn_native::TimestampParams, period) == 12, "");
36 
           // WGSL source of the compute shader that rewrites the entries of `timestamps` in place.
           // Bindings (group 0): 0 = timestamps (storage, read_write), 1 = availability (storage,
           // read), 2 = params (uniform).
           // Each invocation whose global id x is below params.count handles the entry at index
           // x + params.offset / 8:
           //   - if availability.v[x + params.first] is 0, both 32-bit halves are set to 0;
           //   - otherwise both halves are multiplied by params.period in f32 and rounded. The low
           //     half is multiplied directly when it is at most u32(f32(0xFFFFFFFF) / period), and
           //     otherwise as two 16-bit pieces whose overflow `w` is added to the high half.
           // NOTE(review): the arithmetic goes through f32, so results presumably lose precision
           // for large operands, and the fractional part of high * period is rounded away instead
           // of being carried into the low half -- confirm this accuracy is acceptable.
37         static const char sConvertTimestampsToNanoseconds[] = R"(
38             struct Timestamp {
39                 low  : u32;
40                 high : u32;
41             };
42 
43             [[block]] struct TimestampArr {
44                 t : array<Timestamp>;
45             };
46 
47             [[block]] struct AvailabilityArr {
48                 v : array<u32>;
49             };
50 
51             [[block]] struct TimestampParams {
52                 first  : u32;
53                 count  : u32;
54                 offset : u32;
55                 period : f32;
56             };
57 
58             [[group(0), binding(0)]]
59                 var<storage, read_write> timestamps : TimestampArr;
60             [[group(0), binding(1)]]
61                 var<storage, read> availability : AvailabilityArr;
62             [[group(0), binding(2)]] var<uniform> params : TimestampParams;
63 
64 
65             let sizeofTimestamp : u32 = 8u;
66 
67             [[stage(compute), workgroup_size(8, 1, 1)]]
68             fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) {
69                 if (GlobalInvocationID.x >= params.count) { return; }
70 
71                 var index = GlobalInvocationID.x + params.offset / sizeofTimestamp;
72 
73                 var timestamp = timestamps.t[index];
74 
75                 // Return 0 for the unavailable value.
76                 if (availability.v[GlobalInvocationID.x + params.first] == 0u) {
77                     timestamps.t[index].low = 0u;
78                     timestamps.t[index].high = 0u;
79                     return;
80                 }
81 
82                 // Multiply the values in timestamps buffer by the period.
83                 var period = params.period;
84                 var w = 0u;
85 
86                 // If the product of low 32-bits and the period does not exceed the maximum of u32,
87                 // directly do the multiplication, otherwise, use two u32 to represent the high
88                 // 16-bits and low 16-bits of this u32, then multiply them by the period separately.
89                 if (timestamp.low <= u32(f32(0xFFFFFFFFu) / period)) {
90                     timestamps.t[index].low = u32(round(f32(timestamp.low) * period));
91                 } else {
92                     var lo = timestamp.low & 0xFFFFu;
93                     var hi = timestamp.low >> 16u;
94 
95                     var t0 = u32(round(f32(lo) * period));
96                     var t1 = u32(round(f32(hi) * period)) + (t0 >> 16u);
97                     w = t1 >> 16u;
98 
99                     var result = t1 << 16u;
100                     result = result | (t0 & 0xFFFFu);
101                     timestamps.t[index].low = result;
102                 }
103 
104                 // Get the nearest integer to the float result. For high 32-bits, the round
105                 // function will greatly help reduce the accuracy loss of the final result.
106                 timestamps.t[index].high = u32(round(f32(timestamp.high) * period)) + w;
107             }
108         )";
109 
GetOrCreateTimestampComputePipeline(DeviceBase * device)110         ResultOrError<ComputePipelineBase*> GetOrCreateTimestampComputePipeline(
111             DeviceBase* device) {
112             InternalPipelineStore* store = device->GetInternalPipelineStore();
113 
114             if (store->timestampComputePipeline == nullptr) {
115                 // Create compute shader module if not cached before.
116                 if (store->timestampCS == nullptr) {
117                     DAWN_TRY_ASSIGN(
118                         store->timestampCS,
119                         utils::CreateShaderModule(device, sConvertTimestampsToNanoseconds));
120                 }
121 
122                 // Create binding group layout
123                 Ref<BindGroupLayoutBase> bgl;
124                 DAWN_TRY_ASSIGN(
125                     bgl, utils::MakeBindGroupLayout(
126                              device,
127                              {
128                                  {0, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding},
129                                  {1, wgpu::ShaderStage::Compute,
130                                   wgpu::BufferBindingType::ReadOnlyStorage},
131                                  {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform},
132                              },
133                              /* allowInternalBinding */ true));
134 
135                 // Create pipeline layout
136                 Ref<PipelineLayoutBase> layout;
137                 DAWN_TRY_ASSIGN(layout, utils::MakeBasicPipelineLayout(device, bgl));
138 
139                 // Create ComputePipeline.
140                 ComputePipelineDescriptor computePipelineDesc = {};
141                 // Generate the layout based on shader module.
142                 computePipelineDesc.layout = layout.Get();
143                 computePipelineDesc.compute.module = store->timestampCS.Get();
144                 computePipelineDesc.compute.entryPoint = "main";
145 
146                 DAWN_TRY_ASSIGN(store->timestampComputePipeline,
147                                 device->CreateComputePipeline(&computePipelineDesc));
148             }
149 
150             return store->timestampComputePipeline.Get();
151         }
152 
153     }  // anonymous namespace
154 
EncodeConvertTimestampsToNanoseconds(CommandEncoder * encoder,BufferBase * timestamps,BufferBase * availability,BufferBase * params)155     MaybeError EncodeConvertTimestampsToNanoseconds(CommandEncoder* encoder,
156                                                     BufferBase* timestamps,
157                                                     BufferBase* availability,
158                                                     BufferBase* params) {
159         DeviceBase* device = encoder->GetDevice();
160 
161         ComputePipelineBase* pipeline;
162         DAWN_TRY_ASSIGN(pipeline, GetOrCreateTimestampComputePipeline(device));
163 
164         // Prepare bind group layout.
165         Ref<BindGroupLayoutBase> layout;
166         DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0));
167 
168         // Create bind group after all binding entries are set.
169         Ref<BindGroupBase> bindGroup;
170         DAWN_TRY_ASSIGN(bindGroup,
171                         utils::MakeBindGroup(device, layout,
172                                              {{0, timestamps}, {1, availability}, {2, params}}));
173 
174         // Create compute encoder and issue dispatch.
175         ComputePassDescriptor passDesc = {};
176         // TODO(dawn:723): change to not use AcquireRef for reentrant object creation.
177         Ref<ComputePassEncoder> pass = AcquireRef(encoder->APIBeginComputePass(&passDesc));
178         pass->APISetPipeline(pipeline);
179         pass->APISetBindGroup(0, bindGroup.Get());
180         pass->APIDispatch(
181             static_cast<uint32_t>((timestamps->GetSize() / sizeof(uint64_t) + 7) / 8));
182         pass->APIEndPass();
183 
184         return {};
185     }
186 
187 }  // namespace dawn_native
188