• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 The vulkano developers
2 // Licensed under the Apache License, Version 2.0
3 // <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT
5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
6 // at your option. All files in the project carrying such
7 // notice may not be copied, modified, or distributed except
8 // according to those terms.
9 
10 use crate::check_errors;
11 use crate::descriptor_set::layout::DescriptorSetLayout;
12 use crate::device::Device;
13 use crate::device::DeviceOwned;
14 use crate::pipeline::cache::PipelineCache;
15 use crate::pipeline::layout::PipelineLayout;
16 use crate::pipeline::layout::PipelineLayoutCreationError;
17 use crate::pipeline::layout::PipelineLayoutSupersetError;
18 use crate::pipeline::shader::EntryPointAbstract;
19 use crate::pipeline::shader::SpecializationConstants;
20 use crate::Error;
21 use crate::OomError;
22 use crate::SafeDeref;
23 use crate::VulkanObject;
24 use std::error;
25 use std::fmt;
26 use std::marker::PhantomData;
27 use std::mem;
28 use std::mem::MaybeUninit;
29 use std::ptr;
30 use std::sync::Arc;
31 
32 /// A pipeline object that describes to the Vulkan implementation how it should perform compute
33 /// operations.
34 ///
35 /// The template parameter contains the descriptor set to use with this pipeline.
36 ///
37 /// All compute pipeline objects implement the `ComputePipelineAbstract` trait. You can turn any
38 /// `Arc<ComputePipeline>` into an `Arc<ComputePipelineAbstract>` if necessary.
39 ///
40 /// Pass an optional `Arc` to a `PipelineCache` to enable pipeline caching. The vulkan
41 /// implementation will handle the `PipelineCache` and check if it is available.
42 /// Check the documentation of the `PipelineCache` for more information.
43 pub struct ComputePipeline {
44     inner: Inner,
45     pipeline_layout: Arc<PipelineLayout>,
46 }
47 
48 struct Inner {
49     pipeline: ash::vk::Pipeline,
50     device: Arc<Device>,
51 }
52 
53 impl ComputePipeline {
54     /// Builds a new `ComputePipeline`.
new<Cs, Css>( device: Arc<Device>, shader: &Cs, spec_constants: &Css, cache: Option<Arc<PipelineCache>>, ) -> Result<ComputePipeline, ComputePipelineCreationError> where Cs: EntryPointAbstract, Css: SpecializationConstants,55     pub fn new<Cs, Css>(
56         device: Arc<Device>,
57         shader: &Cs,
58         spec_constants: &Css,
59         cache: Option<Arc<PipelineCache>>,
60     ) -> Result<ComputePipeline, ComputePipelineCreationError>
61     where
62         Cs: EntryPointAbstract,
63         Css: SpecializationConstants,
64     {
65         unsafe {
66             let descriptor_set_layouts = shader
67                 .descriptor_set_layout_descs()
68                 .iter()
69                 .map(|desc| {
70                     Ok(Arc::new(DescriptorSetLayout::new(
71                         device.clone(),
72                         desc.clone(),
73                     )?))
74                 })
75                 .collect::<Result<Vec<_>, OomError>>()?;
76             let pipeline_layout = Arc::new(PipelineLayout::new(
77                 device.clone(),
78                 descriptor_set_layouts,
79                 shader.push_constant_range().iter().cloned(),
80             )?);
81             ComputePipeline::with_unchecked_pipeline_layout(
82                 device,
83                 shader,
84                 spec_constants,
85                 pipeline_layout,
86                 cache,
87             )
88         }
89     }
90 
91     /// Builds a new `ComputePipeline` with a specific pipeline layout.
92     ///
93     /// An error will be returned if the pipeline layout isn't a superset of what the shader
94     /// uses.
with_pipeline_layout<Cs, Css>( device: Arc<Device>, shader: &Cs, spec_constants: &Css, pipeline_layout: Arc<PipelineLayout>, cache: Option<Arc<PipelineCache>>, ) -> Result<ComputePipeline, ComputePipelineCreationError> where Cs: EntryPointAbstract, Css: SpecializationConstants,95     pub fn with_pipeline_layout<Cs, Css>(
96         device: Arc<Device>,
97         shader: &Cs,
98         spec_constants: &Css,
99         pipeline_layout: Arc<PipelineLayout>,
100         cache: Option<Arc<PipelineCache>>,
101     ) -> Result<ComputePipeline, ComputePipelineCreationError>
102     where
103         Cs: EntryPointAbstract,
104         Css: SpecializationConstants,
105     {
106         if Css::descriptors() != shader.spec_constants() {
107             return Err(ComputePipelineCreationError::IncompatibleSpecializationConstants);
108         }
109 
110         unsafe {
111             pipeline_layout.ensure_superset_of(
112                 shader.descriptor_set_layout_descs(),
113                 shader.push_constant_range(),
114             )?;
115             ComputePipeline::with_unchecked_pipeline_layout(
116                 device,
117                 shader,
118                 spec_constants,
119                 pipeline_layout,
120                 cache,
121             )
122         }
123     }
124 
125     /// Same as `with_pipeline_layout`, but doesn't check whether the pipeline layout is a
126     /// superset of what the shader expects.
with_unchecked_pipeline_layout<Cs, Css>( device: Arc<Device>, shader: &Cs, spec_constants: &Css, pipeline_layout: Arc<PipelineLayout>, cache: Option<Arc<PipelineCache>>, ) -> Result<ComputePipeline, ComputePipelineCreationError> where Cs: EntryPointAbstract, Css: SpecializationConstants,127     pub unsafe fn with_unchecked_pipeline_layout<Cs, Css>(
128         device: Arc<Device>,
129         shader: &Cs,
130         spec_constants: &Css,
131         pipeline_layout: Arc<PipelineLayout>,
132         cache: Option<Arc<PipelineCache>>,
133     ) -> Result<ComputePipeline, ComputePipelineCreationError>
134     where
135         Cs: EntryPointAbstract,
136         Css: SpecializationConstants,
137     {
138         let fns = device.fns();
139 
140         let pipeline = {
141             let spec_descriptors = Css::descriptors();
142             let specialization = ash::vk::SpecializationInfo {
143                 map_entry_count: spec_descriptors.len() as u32,
144                 p_map_entries: spec_descriptors.as_ptr() as *const _,
145                 data_size: mem::size_of_val(spec_constants),
146                 p_data: spec_constants as *const Css as *const _,
147             };
148 
149             let stage = ash::vk::PipelineShaderStageCreateInfo {
150                 flags: ash::vk::PipelineShaderStageCreateFlags::empty(),
151                 stage: ash::vk::ShaderStageFlags::COMPUTE,
152                 module: shader.module().internal_object(),
153                 p_name: shader.name().as_ptr(),
154                 p_specialization_info: if specialization.data_size == 0 {
155                     ptr::null()
156                 } else {
157                     &specialization
158                 },
159                 ..Default::default()
160             };
161 
162             let infos = ash::vk::ComputePipelineCreateInfo {
163                 flags: ash::vk::PipelineCreateFlags::empty(),
164                 stage,
165                 layout: pipeline_layout.internal_object(),
166                 base_pipeline_handle: ash::vk::Pipeline::null(),
167                 base_pipeline_index: 0,
168                 ..Default::default()
169             };
170 
171             let cache_handle = match cache {
172                 Some(ref cache) => cache.internal_object(),
173                 None => ash::vk::PipelineCache::null(),
174             };
175 
176             let mut output = MaybeUninit::uninit();
177             check_errors(fns.v1_0.create_compute_pipelines(
178                 device.internal_object(),
179                 cache_handle,
180                 1,
181                 &infos,
182                 ptr::null(),
183                 output.as_mut_ptr(),
184             ))?;
185             output.assume_init()
186         };
187 
188         Ok(ComputePipeline {
189             inner: Inner {
190                 device: device.clone(),
191                 pipeline: pipeline,
192             },
193             pipeline_layout: pipeline_layout,
194         })
195     }
196 }
197 
198 impl fmt::Debug for ComputePipeline {
199     #[inline]
fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error>200     fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
201         write!(fmt, "<Vulkan compute pipeline {:?}>", self.inner.pipeline)
202     }
203 }
204 
205 impl ComputePipeline {
206     /// Returns the `Device` this compute pipeline was created with.
207     #[inline]
device(&self) -> &Arc<Device>208     pub fn device(&self) -> &Arc<Device> {
209         &self.inner.device
210     }
211 }
212 
213 /// Trait implemented on all compute pipelines.
214 pub unsafe trait ComputePipelineAbstract: DeviceOwned {
215     /// Returns an opaque object that represents the inside of the compute pipeline.
inner(&self) -> ComputePipelineSys216     fn inner(&self) -> ComputePipelineSys;
217 
218     /// Returns the pipeline layout used in this compute pipeline.
layout(&self) -> &Arc<PipelineLayout>219     fn layout(&self) -> &Arc<PipelineLayout>;
220 }
221 
222 unsafe impl ComputePipelineAbstract for ComputePipeline {
223     #[inline]
inner(&self) -> ComputePipelineSys224     fn inner(&self) -> ComputePipelineSys {
225         ComputePipelineSys(self.inner.pipeline, PhantomData)
226     }
227 
228     #[inline]
layout(&self) -> &Arc<PipelineLayout>229     fn layout(&self) -> &Arc<PipelineLayout> {
230         &self.pipeline_layout
231     }
232 }
233 
234 unsafe impl<T> ComputePipelineAbstract for T
235 where
236     T: SafeDeref,
237     T::Target: ComputePipelineAbstract,
238 {
239     #[inline]
inner(&self) -> ComputePipelineSys240     fn inner(&self) -> ComputePipelineSys {
241         (**self).inner()
242     }
243 
244     #[inline]
layout(&self) -> &Arc<PipelineLayout>245     fn layout(&self) -> &Arc<PipelineLayout> {
246         (**self).layout()
247     }
248 }
249 
250 /// Opaque object that represents the inside of the compute pipeline. Can be made into a trait
251 /// object.
252 #[derive(Debug, Copy, Clone)]
253 pub struct ComputePipelineSys<'a>(ash::vk::Pipeline, PhantomData<&'a ()>);
254 
255 unsafe impl<'a> VulkanObject for ComputePipelineSys<'a> {
256     type Object = ash::vk::Pipeline;
257 
258     #[inline]
internal_object(&self) -> ash::vk::Pipeline259     fn internal_object(&self) -> ash::vk::Pipeline {
260         self.0
261     }
262 }
263 
264 unsafe impl DeviceOwned for ComputePipeline {
265     #[inline]
device(&self) -> &Arc<Device>266     fn device(&self) -> &Arc<Device> {
267         self.device()
268     }
269 }
270 
271 unsafe impl VulkanObject for ComputePipeline {
272     type Object = ash::vk::Pipeline;
273 
274     #[inline]
internal_object(&self) -> ash::vk::Pipeline275     fn internal_object(&self) -> ash::vk::Pipeline {
276         self.inner.pipeline
277     }
278 }
279 
280 impl Drop for Inner {
281     #[inline]
drop(&mut self)282     fn drop(&mut self) {
283         unsafe {
284             let fns = self.device.fns();
285             fns.v1_0
286                 .destroy_pipeline(self.device.internal_object(), self.pipeline, ptr::null());
287         }
288     }
289 }
290 
291 /// Error that can happen when creating a compute pipeline.
292 #[derive(Clone, Debug, PartialEq, Eq)]
293 pub enum ComputePipelineCreationError {
294     /// Not enough memory.
295     OomError(OomError),
296     /// Error while creating the pipeline layout object.
297     PipelineLayoutCreationError(PipelineLayoutCreationError),
298     /// The pipeline layout is not compatible with what the shader expects.
299     IncompatiblePipelineLayout(PipelineLayoutSupersetError),
300     /// The provided specialization constants are not compatible with what the shader expects.
301     IncompatibleSpecializationConstants,
302 }
303 
304 impl error::Error for ComputePipelineCreationError {
305     #[inline]
source(&self) -> Option<&(dyn error::Error + 'static)>306     fn source(&self) -> Option<&(dyn error::Error + 'static)> {
307         match *self {
308             ComputePipelineCreationError::OomError(ref err) => Some(err),
309             ComputePipelineCreationError::PipelineLayoutCreationError(ref err) => Some(err),
310             ComputePipelineCreationError::IncompatiblePipelineLayout(ref err) => Some(err),
311             ComputePipelineCreationError::IncompatibleSpecializationConstants => None,
312         }
313     }
314 }
315 
316 impl fmt::Display for ComputePipelineCreationError {
317     #[inline]
fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error>318     fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
319         write!(
320             fmt,
321             "{}",
322             match *self {
323                 ComputePipelineCreationError::OomError(_) => "not enough memory available",
324                 ComputePipelineCreationError::PipelineLayoutCreationError(_) => {
325                     "error while creating the pipeline layout object"
326                 }
327                 ComputePipelineCreationError::IncompatiblePipelineLayout(_) => {
328                     "the pipeline layout is not compatible with what the shader expects"
329                 }
330                 ComputePipelineCreationError::IncompatibleSpecializationConstants => {
331                     "the provided specialization constants are not compatible with what the shader expects"
332                 }
333             }
334         )
335     }
336 }
337 
338 impl From<OomError> for ComputePipelineCreationError {
339     #[inline]
from(err: OomError) -> ComputePipelineCreationError340     fn from(err: OomError) -> ComputePipelineCreationError {
341         ComputePipelineCreationError::OomError(err)
342     }
343 }
344 
345 impl From<PipelineLayoutCreationError> for ComputePipelineCreationError {
346     #[inline]
from(err: PipelineLayoutCreationError) -> ComputePipelineCreationError347     fn from(err: PipelineLayoutCreationError) -> ComputePipelineCreationError {
348         ComputePipelineCreationError::PipelineLayoutCreationError(err)
349     }
350 }
351 
352 impl From<PipelineLayoutSupersetError> for ComputePipelineCreationError {
353     #[inline]
from(err: PipelineLayoutSupersetError) -> ComputePipelineCreationError354     fn from(err: PipelineLayoutSupersetError) -> ComputePipelineCreationError {
355         ComputePipelineCreationError::IncompatiblePipelineLayout(err)
356     }
357 }
358 
359 impl From<Error> for ComputePipelineCreationError {
360     #[inline]
from(err: Error) -> ComputePipelineCreationError361     fn from(err: Error) -> ComputePipelineCreationError {
362         match err {
363             err @ Error::OutOfHostMemory => {
364                 ComputePipelineCreationError::OomError(OomError::from(err))
365             }
366             err @ Error::OutOfDeviceMemory => {
367                 ComputePipelineCreationError::OomError(OomError::from(err))
368             }
369             _ => panic!("unexpected error: {:?}", err),
370         }
371     }
372 }
373 
#[cfg(test)]
mod tests {
    use crate::buffer::BufferUsage;
    use crate::buffer::CpuAccessibleBuffer;
    use crate::command_buffer::AutoCommandBufferBuilder;
    use crate::command_buffer::CommandBufferUsage;
    use crate::descriptor_set::layout::DescriptorBufferDesc;
    use crate::descriptor_set::layout::DescriptorDesc;
    use crate::descriptor_set::layout::DescriptorDescTy;
    use crate::descriptor_set::layout::DescriptorSetDesc;
    use crate::descriptor_set::PersistentDescriptorSet;
    use crate::pipeline::shader::ShaderModule;
    use crate::pipeline::shader::ShaderStages;
    use crate::pipeline::shader::SpecializationConstants;
    use crate::pipeline::shader::SpecializationMapEntry;
    use crate::pipeline::ComputePipeline;
    use crate::pipeline::ComputePipelineAbstract;
    use crate::sync::now;
    use crate::sync::GpuFuture;
    use std::ffi::CStr;
    use std::sync::Arc;

    // TODO: test for basic creation
    // TODO: test for pipeline layout error

    #[test]
    fn spec_constants() {
        // This test checks whether specialization constants work.
        // It executes a single compute shader (one invocation) that writes the value of a
        // specialization constant to a buffer. The buffer content is then checked for the
        // right value.

        let (device, queue) = gfx_dev_and_queue!();

        let module = unsafe {
            /*
            #version 450

            layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;

            layout(constant_id = 83) const int VALUE = 0xdeadbeef;

            layout(set = 0, binding = 0) buffer Output {
                int write;
            } write;

            void main() {
                write.write = VALUE;
            }
            */
            const MODULE: [u8; 480] = [
                3, 2, 35, 7, 0, 0, 1, 0, 1, 0, 8, 0, 14, 0, 0, 0, 0, 0, 0, 0, 17, 0, 2, 0, 1, 0, 0,
                0, 11, 0, 6, 0, 1, 0, 0, 0, 71, 76, 83, 76, 46, 115, 116, 100, 46, 52, 53, 48, 0,
                0, 0, 0, 14, 0, 3, 0, 0, 0, 0, 0, 1, 0, 0, 0, 15, 0, 5, 0, 5, 0, 0, 0, 4, 0, 0, 0,
                109, 97, 105, 110, 0, 0, 0, 0, 16, 0, 6, 0, 4, 0, 0, 0, 17, 0, 0, 0, 1, 0, 0, 0, 1,
                0, 0, 0, 1, 0, 0, 0, 3, 0, 3, 0, 2, 0, 0, 0, 194, 1, 0, 0, 5, 0, 4, 0, 4, 0, 0, 0,
                109, 97, 105, 110, 0, 0, 0, 0, 5, 0, 4, 0, 7, 0, 0, 0, 79, 117, 116, 112, 117, 116,
                0, 0, 6, 0, 5, 0, 7, 0, 0, 0, 0, 0, 0, 0, 119, 114, 105, 116, 101, 0, 0, 0, 5, 0,
                4, 0, 9, 0, 0, 0, 119, 114, 105, 116, 101, 0, 0, 0, 5, 0, 4, 0, 11, 0, 0, 0, 86,
                65, 76, 85, 69, 0, 0, 0, 72, 0, 5, 0, 7, 0, 0, 0, 0, 0, 0, 0, 35, 0, 0, 0, 0, 0, 0,
                0, 71, 0, 3, 0, 7, 0, 0, 0, 3, 0, 0, 0, 71, 0, 4, 0, 9, 0, 0, 0, 34, 0, 0, 0, 0, 0,
                0, 0, 71, 0, 4, 0, 9, 0, 0, 0, 33, 0, 0, 0, 0, 0, 0, 0, 71, 0, 4, 0, 11, 0, 0, 0,
                1, 0, 0, 0, 83, 0, 0, 0, 19, 0, 2, 0, 2, 0, 0, 0, 33, 0, 3, 0, 3, 0, 0, 0, 2, 0, 0,
                0, 21, 0, 4, 0, 6, 0, 0, 0, 32, 0, 0, 0, 1, 0, 0, 0, 30, 0, 3, 0, 7, 0, 0, 0, 6, 0,
                0, 0, 32, 0, 4, 0, 8, 0, 0, 0, 2, 0, 0, 0, 7, 0, 0, 0, 59, 0, 4, 0, 8, 0, 0, 0, 9,
                0, 0, 0, 2, 0, 0, 0, 43, 0, 4, 0, 6, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 50, 0, 4, 0,
                6, 0, 0, 0, 11, 0, 0, 0, 239, 190, 173, 222, 32, 0, 4, 0, 12, 0, 0, 0, 2, 0, 0, 0,
                6, 0, 0, 0, 54, 0, 5, 0, 2, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 248, 0, 2,
                0, 5, 0, 0, 0, 65, 0, 5, 0, 12, 0, 0, 0, 13, 0, 0, 0, 9, 0, 0, 0, 10, 0, 0, 0, 62,
                0, 3, 0, 13, 0, 0, 0, 11, 0, 0, 0, 253, 0, 1, 0, 56, 0, 1, 0,
            ];
            ShaderModule::new(device.clone(), &MODULE).unwrap()
        };

        let shader = unsafe {
            module.compute_entry_point(
                // Checked conversion of a NUL-terminated literal; no raw pointer needed.
                CStr::from_bytes_with_nul(b"main\0").unwrap(),
                [DescriptorSetDesc::new([Some(DescriptorDesc {
                    ty: DescriptorDescTy::Buffer(DescriptorBufferDesc {
                        dynamic: Some(false),
                        storage: true,
                    }),
                    array_count: 1,
                    stages: ShaderStages {
                        compute: true,
                        ..ShaderStages::none()
                    },
                    // The shader writes `write.write = VALUE`, so the binding is not read-only.
                    // Declaring it read-only would let the synchronization layer treat the GPU
                    // write as a read.
                    readonly: false,
                })])],
                None,
                SpecConsts::descriptors(),
            )
        };

        #[derive(Debug, Copy, Clone)]
        #[allow(non_snake_case)]
        #[repr(C)]
        struct SpecConsts {
            VALUE: i32,
        }
        unsafe impl SpecializationConstants for SpecConsts {
            fn descriptors() -> &'static [SpecializationMapEntry] {
                static DESCRIPTORS: [SpecializationMapEntry; 1] = [SpecializationMapEntry {
                    constant_id: 83,
                    offset: 0,
                    size: 4,
                }];
                &DESCRIPTORS
            }
        }

        let pipeline = Arc::new(
            ComputePipeline::new(
                device.clone(),
                &shader,
                &SpecConsts { VALUE: 0x12345678 },
                None,
            )
            .unwrap(),
        );

        let data_buffer =
            CpuAccessibleBuffer::from_data(device.clone(), BufferUsage::all(), false, 0).unwrap();
        let layout = pipeline.layout().descriptor_set_layouts().get(0).unwrap();
        let set = PersistentDescriptorSet::start(layout.clone())
            .add_buffer(data_buffer.clone())
            .unwrap()
            .build()
            .unwrap();

        let mut cbb = AutoCommandBufferBuilder::primary(
            device.clone(),
            queue.family(),
            CommandBufferUsage::OneTimeSubmit,
        )
        .unwrap();
        cbb.dispatch([1, 1, 1], pipeline.clone(), set, ()).unwrap();
        let cb = cbb.build().unwrap();

        let future = now(device.clone())
            .then_execute(queue.clone(), cb)
            .unwrap()
            .then_signal_fence_and_flush()
            .unwrap();
        future.wait(None).unwrap();

        let data_buffer_content = data_buffer.read().unwrap();
        assert_eq!(*data_buffer_content, 0x12345678);
    }
}
524