1 // Copyright 2020 The Dawn Authors 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "dawn_native/QueryHelper.h" 16 17 #include "dawn_native/BindGroup.h" 18 #include "dawn_native/BindGroupLayout.h" 19 #include "dawn_native/Buffer.h" 20 #include "dawn_native/CommandEncoder.h" 21 #include "dawn_native/ComputePassEncoder.h" 22 #include "dawn_native/ComputePipeline.h" 23 #include "dawn_native/Device.h" 24 #include "dawn_native/InternalPipelineStore.h" 25 #include "dawn_native/utils/WGPUHelpers.h" 26 27 namespace dawn_native { 28 29 namespace { 30 31 // Assert the offsets in dawn_native::TimestampParams are same with the ones in the shader 32 static_assert(offsetof(dawn_native::TimestampParams, first) == 0, ""); 33 static_assert(offsetof(dawn_native::TimestampParams, count) == 4, ""); 34 static_assert(offsetof(dawn_native::TimestampParams, offset) == 8, ""); 35 static_assert(offsetof(dawn_native::TimestampParams, period) == 12, ""); 36 37 static const char sConvertTimestampsToNanoseconds[] = R"( 38 struct Timestamp { 39 low : u32; 40 high : u32; 41 }; 42 43 [[block]] struct TimestampArr { 44 t : array<Timestamp>; 45 }; 46 47 [[block]] struct AvailabilityArr { 48 v : array<u32>; 49 }; 50 51 [[block]] struct TimestampParams { 52 first : u32; 53 count : u32; 54 offset : u32; 55 period : f32; 56 }; 57 58 [[group(0), binding(0)]] 59 var<storage, read_write> timestamps : TimestampArr; 60 [[group(0), binding(1)]] 61 var<storage, read> availability : AvailabilityArr; 62 [[group(0), binding(2)]] var<uniform> params : TimestampParams; 63 64 65 let sizeofTimestamp : u32 = 8u; 66 67 [[stage(compute), workgroup_size(8, 1, 1)]] 68 fn main([[builtin(global_invocation_id)]] GlobalInvocationID : vec3<u32>) { 69 if (GlobalInvocationID.x >= params.count) { return; } 70 71 var index = GlobalInvocationID.x + params.offset / sizeofTimestamp; 72 73 var timestamp = timestamps.t[index]; 74 75 // Return 0 for the unavailable value. 76 if (availability.v[GlobalInvocationID.x + params.first] == 0u) { 77 timestamps.t[index].low = 0u; 78 timestamps.t[index].high = 0u; 79 return; 80 } 81 82 // Multiply the values in timestamps buffer by the period. 83 var period = params.period; 84 var w = 0u; 85 86 // If the product of low 32-bits and the period does not exceed the maximum of u32, 87 // directly do the multiplication, otherwise, use two u32 to represent the high 88 // 16-bits and low 16-bits of this u32, then multiply them by the period separately. 89 if (timestamp.low <= u32(f32(0xFFFFFFFFu) / period)) { 90 timestamps.t[index].low = u32(round(f32(timestamp.low) * period)); 91 } else { 92 var lo = timestamp.low & 0xFFFFu; 93 var hi = timestamp.low >> 16u; 94 95 var t0 = u32(round(f32(lo) * period)); 96 var t1 = u32(round(f32(hi) * period)) + (t0 >> 16u); 97 w = t1 >> 16u; 98 99 var result = t1 << 16u; 100 result = result | (t0 & 0xFFFFu); 101 timestamps.t[index].low = result; 102 } 103 104 // Get the nearest integer to the float result. For high 32-bits, the round 105 // function will greatly help reduce the accuracy loss of the final result. 106 timestamps.t[index].high = u32(round(f32(timestamp.high) * period)) + w; 107 } 108 )"; 109 GetOrCreateTimestampComputePipeline(DeviceBase * device)110 ResultOrError<ComputePipelineBase*> GetOrCreateTimestampComputePipeline( 111 DeviceBase* device) { 112 InternalPipelineStore* store = device->GetInternalPipelineStore(); 113 114 if (store->timestampComputePipeline == nullptr) { 115 // Create compute shader module if not cached before. 116 if (store->timestampCS == nullptr) { 117 DAWN_TRY_ASSIGN( 118 store->timestampCS, 119 utils::CreateShaderModule(device, sConvertTimestampsToNanoseconds)); 120 } 121 122 // Create binding group layout 123 Ref<BindGroupLayoutBase> bgl; 124 DAWN_TRY_ASSIGN( 125 bgl, utils::MakeBindGroupLayout( 126 device, 127 { 128 {0, wgpu::ShaderStage::Compute, kInternalStorageBufferBinding}, 129 {1, wgpu::ShaderStage::Compute, 130 wgpu::BufferBindingType::ReadOnlyStorage}, 131 {2, wgpu::ShaderStage::Compute, wgpu::BufferBindingType::Uniform}, 132 }, 133 /* allowInternalBinding */ true)); 134 135 // Create pipeline layout 136 Ref<PipelineLayoutBase> layout; 137 DAWN_TRY_ASSIGN(layout, utils::MakeBasicPipelineLayout(device, bgl)); 138 139 // Create ComputePipeline. 140 ComputePipelineDescriptor computePipelineDesc = {}; 141 // Generate the layout based on shader module. 142 computePipelineDesc.layout = layout.Get(); 143 computePipelineDesc.compute.module = store->timestampCS.Get(); 144 computePipelineDesc.compute.entryPoint = "main"; 145 146 DAWN_TRY_ASSIGN(store->timestampComputePipeline, 147 device->CreateComputePipeline(&computePipelineDesc)); 148 } 149 150 return store->timestampComputePipeline.Get(); 151 } 152 153 } // anonymous namespace 154 EncodeConvertTimestampsToNanoseconds(CommandEncoder * encoder,BufferBase * timestamps,BufferBase * availability,BufferBase * params)155 MaybeError EncodeConvertTimestampsToNanoseconds(CommandEncoder* encoder, 156 BufferBase* timestamps, 157 BufferBase* availability, 158 BufferBase* params) { 159 DeviceBase* device = encoder->GetDevice(); 160 161 ComputePipelineBase* pipeline; 162 DAWN_TRY_ASSIGN(pipeline, GetOrCreateTimestampComputePipeline(device)); 163 164 // Prepare bind group layout. 165 Ref<BindGroupLayoutBase> layout; 166 DAWN_TRY_ASSIGN(layout, pipeline->GetBindGroupLayout(0)); 167 168 // Create bind group after all binding entries are set. 169 Ref<BindGroupBase> bindGroup; 170 DAWN_TRY_ASSIGN(bindGroup, 171 utils::MakeBindGroup(device, layout, 172 {{0, timestamps}, {1, availability}, {2, params}})); 173 174 // Create compute encoder and issue dispatch. 175 ComputePassDescriptor passDesc = {}; 176 // TODO(dawn:723): change to not use AcquireRef for reentrant object creation. 177 Ref<ComputePassEncoder> pass = AcquireRef(encoder->APIBeginComputePass(&passDesc)); 178 pass->APISetPipeline(pipeline); 179 pass->APISetBindGroup(0, bindGroup.Get()); 180 pass->APIDispatch( 181 static_cast<uint32_t>((timestamps->GetSize() / sizeof(uint64_t) + 7) / 8)); 182 pass->APIEndPass(); 183 184 return {}; 185 } 186 187 } // namespace dawn_native 188