• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 // Copyright (c) 2016 The vulkano developers
2 // Licensed under the Apache License, Version 2.0
3 // <LICENSE-APACHE or
4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT
5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
6 // at your option. All files in the project carrying such
7 // notice may not be copied, modified, or distributed except
8 // according to those terms.
9 
10 //! Stage of a graphics pipeline.
11 //!
12 //! In Vulkan, shaders are grouped in *shader modules*. Each shader module is built from SPIR-V
13 //! code and can contain one or more entry points. Note that for the moment the official
14 //! GLSL-to-SPIR-V compiler does not support multiple entry points.
15 //!
16 //! The vulkano library does not provide any functionality that checks and introspects the SPIR-V
17 //! code, therefore the whole shader-related API is unsafe. You are encouraged to use the
18 //! `vulkano-shaders` crate that will generate Rust code that wraps around vulkano's shaders API.
19 
20 use crate::check_errors;
21 use crate::descriptor_set::layout::DescriptorSetDesc;
22 use crate::device::Device;
23 use crate::format::Format;
24 use crate::pipeline::input_assembly::PrimitiveTopology;
25 use crate::pipeline::layout::PipelineLayoutPcRange;
26 use crate::sync::PipelineStages;
27 use crate::OomError;
28 use crate::VulkanObject;
29 use smallvec::SmallVec;
30 use std::borrow::Cow;
31 use std::error;
32 use std::ffi::CStr;
33 use std::fmt;
34 use std::mem;
35 use std::mem::MaybeUninit;
36 use std::ops::BitOr;
37 use std::ops::Range;
38 use std::ptr;
39 use std::sync::Arc;
40 
/// Contains SPIR-V code with one or more entry points.
///
/// Note that it is advised to wrap around a `ShaderModule` with a struct that is different for
/// each shader.
///
/// The underlying Vulkan shader module is destroyed when this object is dropped.
#[derive(Debug)]
pub struct ShaderModule {
    // Raw Vulkan handle of the shader module.
    module: ash::vk::ShaderModule,
    // Device the module was created on; it is also used to destroy the module on drop.
    device: Arc<Device>,
}
52 
53 impl ShaderModule {
54     /// Builds a new shader module from SPIR-V bytes.
55     ///
56     /// # Safety
57     ///
58     /// - The SPIR-V code is not validated.
59     /// - The SPIR-V code may require some features that are not enabled. This isn't checked by
60     ///   this function either.
61     ///
new(device: Arc<Device>, spirv: &[u8]) -> Result<Arc<ShaderModule>, OomError>62     pub unsafe fn new(device: Arc<Device>, spirv: &[u8]) -> Result<Arc<ShaderModule>, OomError> {
63         debug_assert!((spirv.len() % 4) == 0);
64         Self::from_ptr(device, spirv.as_ptr() as *const _, spirv.len())
65     }
66 
67     /// Builds a new shader module from SPIR-V 32-bit words.
68     ///
69     /// # Safety
70     ///
71     /// - The SPIR-V code is not validated.
72     /// - The SPIR-V code may require some features that are not enabled. This isn't checked by
73     ///   this function either.
74     ///
from_words( device: Arc<Device>, spirv: &[u32], ) -> Result<Arc<ShaderModule>, OomError>75     pub unsafe fn from_words(
76         device: Arc<Device>,
77         spirv: &[u32],
78     ) -> Result<Arc<ShaderModule>, OomError> {
79         Self::from_ptr(device, spirv.as_ptr(), spirv.len() * mem::size_of::<u32>())
80     }
81 
82     /// Builds a new shader module from SPIR-V.
83     ///
84     /// # Safety
85     ///
86     /// - The SPIR-V code is not validated.
87     /// - The SPIR-V code may require some features that are not enabled. This isn't checked by
88     ///   this function either.
89     ///
from_ptr( device: Arc<Device>, spirv: *const u32, spirv_len: usize, ) -> Result<Arc<ShaderModule>, OomError>90     unsafe fn from_ptr(
91         device: Arc<Device>,
92         spirv: *const u32,
93         spirv_len: usize,
94     ) -> Result<Arc<ShaderModule>, OomError> {
95         let module = {
96             let infos = ash::vk::ShaderModuleCreateInfo {
97                 flags: ash::vk::ShaderModuleCreateFlags::empty(),
98                 code_size: spirv_len,
99                 p_code: spirv,
100                 ..Default::default()
101             };
102 
103             let fns = device.fns();
104             let mut output = MaybeUninit::uninit();
105             check_errors(fns.v1_0.create_shader_module(
106                 device.internal_object(),
107                 &infos,
108                 ptr::null(),
109                 output.as_mut_ptr(),
110             ))?;
111             output.assume_init()
112         };
113 
114         Ok(Arc::new(ShaderModule {
115             module: module,
116             device: device,
117         }))
118     }
119 
120     /// Gets access to an entry point contained in this module.
121     ///
122     /// This is purely a *logical* operation. It returns a struct that *represents* the entry
123     /// point but doesn't actually do anything.
124     ///
125     /// # Safety
126     ///
127     /// - The user must check that the entry point exists in the module, as this is not checked
128     ///   by Vulkan.
129     /// - The input, output and layout must correctly describe the input, output and layout used
130     ///   by this stage.
131     ///
graphics_entry_point<'a, D>( &'a self, name: &'a CStr, descriptor_set_layout_descs: D, push_constant_range: Option<PipelineLayoutPcRange>, spec_constants: &'static [SpecializationMapEntry], input: ShaderInterface, output: ShaderInterface, ty: GraphicsShaderType, ) -> GraphicsEntryPoint<'a> where D: IntoIterator<Item = DescriptorSetDesc>,132     pub unsafe fn graphics_entry_point<'a, D>(
133         &'a self,
134         name: &'a CStr,
135         descriptor_set_layout_descs: D,
136         push_constant_range: Option<PipelineLayoutPcRange>,
137         spec_constants: &'static [SpecializationMapEntry],
138         input: ShaderInterface,
139         output: ShaderInterface,
140         ty: GraphicsShaderType,
141     ) -> GraphicsEntryPoint<'a>
142     where
143         D: IntoIterator<Item = DescriptorSetDesc>,
144     {
145         GraphicsEntryPoint {
146             module: self,
147             name,
148             descriptor_set_layout_descs: descriptor_set_layout_descs.into_iter().collect(),
149             push_constant_range,
150             spec_constants,
151             input,
152             output,
153             ty,
154         }
155     }
156 
157     /// Gets access to an entry point contained in this module.
158     ///
159     /// This is purely a *logical* operation. It returns a struct that *represents* the entry
160     /// point but doesn't actually do anything.
161     ///
162     /// # Safety
163     ///
164     /// - The user must check that the entry point exists in the module, as this is not checked
165     ///   by Vulkan.
166     /// - The layout must correctly describe the layout used by this stage.
167     ///
168     #[inline]
compute_entry_point<'a, D>( &'a self, name: &'a CStr, descriptor_set_layout_descs: D, push_constant_range: Option<PipelineLayoutPcRange>, spec_constants: &'static [SpecializationMapEntry], ) -> ComputeEntryPoint<'a> where D: IntoIterator<Item = DescriptorSetDesc>,169     pub unsafe fn compute_entry_point<'a, D>(
170         &'a self,
171         name: &'a CStr,
172         descriptor_set_layout_descs: D,
173         push_constant_range: Option<PipelineLayoutPcRange>,
174         spec_constants: &'static [SpecializationMapEntry],
175     ) -> ComputeEntryPoint<'a>
176     where
177         D: IntoIterator<Item = DescriptorSetDesc>,
178     {
179         ComputeEntryPoint {
180             module: self,
181             name,
182             descriptor_set_layout_descs: descriptor_set_layout_descs.into_iter().collect(),
183             push_constant_range,
184             spec_constants,
185         }
186     }
187 }
188 
189 unsafe impl VulkanObject for ShaderModule {
190     type Object = ash::vk::ShaderModule;
191 
192     #[inline]
internal_object(&self) -> ash::vk::ShaderModule193     fn internal_object(&self) -> ash::vk::ShaderModule {
194         self.module
195     }
196 }
197 
198 impl Drop for ShaderModule {
199     #[inline]
drop(&mut self)200     fn drop(&mut self) {
201         unsafe {
202             let fns = self.device.fns();
203             fns.v1_0
204                 .destroy_shader_module(self.device.internal_object(), self.module, ptr::null());
205         }
206     }
207 }
208 
/// Common interface of the shader entry point types.
///
/// In this file it is implemented by `GraphicsEntryPoint` and `ComputeEntryPoint`.
///
/// # Safety
///
/// NOTE(review): the contract implementors must uphold is not spelled out here. Presumably the
/// returned descriptions must accurately describe the entry point in the module; confirm.
pub unsafe trait EntryPointAbstract {
    /// Returns the module this entry point comes from.
    fn module(&self) -> &ShaderModule;

    /// Returns the name of the entry point.
    fn name(&self) -> &CStr;

    /// Returns a description of the descriptor set layouts.
    fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc];

    /// Returns the push constant ranges.
    fn push_constant_range(&self) -> &Option<PipelineLayoutPcRange>;

    /// Returns the layout of the specialization constants.
    fn spec_constants(&self) -> &[SpecializationMapEntry];
}
225 
/// Represents a graphics shader entry point in a shader module.
///
/// Can be obtained by calling `graphics_entry_point()` on the shader module.
#[derive(Clone, Debug)]
pub struct GraphicsEntryPoint<'a> {
    // Module that contains the entry point.
    module: &'a ShaderModule,
    // Name of the entry point.
    name: &'a CStr,

    // Descriptions of the descriptor set layouts.
    descriptor_set_layout_descs: SmallVec<[DescriptorSetDesc; 16]>,
    // Push constant range, if any.
    push_constant_range: Option<PipelineLayoutPcRange>,
    // Layout of the specialization constants.
    spec_constants: &'static [SpecializationMapEntry],
    // Input attributes of the stage.
    input: ShaderInterface,
    // Output attributes of the stage.
    output: ShaderInterface,
    // Type of graphics shader.
    ty: GraphicsShaderType,
}
241 
242 impl<'a> GraphicsEntryPoint<'a> {
243     /// Returns the input attributes used by the shader stage.
244     #[inline]
input(&self) -> &ShaderInterface245     pub fn input(&self) -> &ShaderInterface {
246         &self.input
247     }
248 
249     /// Returns the output attributes used by the shader stage.
250     #[inline]
output(&self) -> &ShaderInterface251     pub fn output(&self) -> &ShaderInterface {
252         &self.output
253     }
254 
255     /// Returns the type of shader.
256     #[inline]
ty(&self) -> GraphicsShaderType257     pub fn ty(&self) -> GraphicsShaderType {
258         self.ty
259     }
260 }
261 
262 unsafe impl<'a> EntryPointAbstract for GraphicsEntryPoint<'a> {
263     #[inline]
module(&self) -> &ShaderModule264     fn module(&self) -> &ShaderModule {
265         self.module
266     }
267 
268     #[inline]
name(&self) -> &CStr269     fn name(&self) -> &CStr {
270         self.name
271     }
272 
273     #[inline]
descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc]274     fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc] {
275         &self.descriptor_set_layout_descs
276     }
277 
278     #[inline]
push_constant_range(&self) -> &Option<PipelineLayoutPcRange>279     fn push_constant_range(&self) -> &Option<PipelineLayoutPcRange> {
280         &self.push_constant_range
281     }
282 
283     #[inline]
spec_constants(&self) -> &[SpecializationMapEntry]284     fn spec_constants(&self) -> &[SpecializationMapEntry] {
285         self.spec_constants
286     }
287 }
288 
/// Type of a graphics shader, as stored in a `GraphicsEntryPoint`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum GraphicsShaderType {
    /// Vertex shader.
    Vertex,
    /// Tessellation control shader.
    TessellationControl,
    /// Tessellation evaluation shader.
    TessellationEvaluation,
    /// Geometry shader, together with the kind of primitives it expects.
    Geometry(GeometryShaderExecutionMode),
    /// Fragment shader.
    Fragment,
}
297 
/// Declares which type of primitives are expected by the geometry shader.
///
/// Use `matches()` to check whether a primitive topology is accepted by a given mode.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum GeometryShaderExecutionMode {
    /// Accepts point lists.
    Points,
    /// Accepts line lists and line strips.
    Lines,
    /// Accepts line lists and line strips with adjacency.
    LinesWithAdjacency,
    /// Accepts triangle lists, strips and fans.
    Triangles,
    /// Accepts triangle lists and triangle strips with adjacency.
    TrianglesWithAdjacency,
}
307 
308 impl GeometryShaderExecutionMode {
309     /// Returns true if the given primitive topology can be used with this execution mode.
310     #[inline]
matches(&self, input: PrimitiveTopology) -> bool311     pub fn matches(&self, input: PrimitiveTopology) -> bool {
312         match (*self, input) {
313             (GeometryShaderExecutionMode::Points, PrimitiveTopology::PointList) => true,
314             (GeometryShaderExecutionMode::Lines, PrimitiveTopology::LineList) => true,
315             (GeometryShaderExecutionMode::Lines, PrimitiveTopology::LineStrip) => true,
316             (
317                 GeometryShaderExecutionMode::LinesWithAdjacency,
318                 PrimitiveTopology::LineListWithAdjacency,
319             ) => true,
320             (
321                 GeometryShaderExecutionMode::LinesWithAdjacency,
322                 PrimitiveTopology::LineStripWithAdjacency,
323             ) => true,
324             (GeometryShaderExecutionMode::Triangles, PrimitiveTopology::TriangleList) => true,
325             (GeometryShaderExecutionMode::Triangles, PrimitiveTopology::TriangleStrip) => true,
326             (GeometryShaderExecutionMode::Triangles, PrimitiveTopology::TriangleFan) => true,
327             (
328                 GeometryShaderExecutionMode::TrianglesWithAdjacency,
329                 PrimitiveTopology::TriangleListWithAdjacency,
330             ) => true,
331             (
332                 GeometryShaderExecutionMode::TrianglesWithAdjacency,
333                 PrimitiveTopology::TriangleStripWithAdjacency,
334             ) => true,
335             _ => false,
336         }
337     }
338 }
339 
/// Represents the entry point of a compute shader in a shader module.
///
/// Can be obtained by calling `compute_entry_point()` on the shader module.
#[derive(Debug, Clone)]
pub struct ComputeEntryPoint<'a> {
    // Module that contains the entry point.
    module: &'a ShaderModule,
    // Name of the entry point.
    name: &'a CStr,
    // Descriptions of the descriptor set layouts.
    descriptor_set_layout_descs: SmallVec<[DescriptorSetDesc; 16]>,
    // Push constant range, if any.
    push_constant_range: Option<PipelineLayoutPcRange>,
    // Layout of the specialization constants.
    spec_constants: &'static [SpecializationMapEntry],
}
351 
352 unsafe impl<'a> EntryPointAbstract for ComputeEntryPoint<'a> {
353     #[inline]
module(&self) -> &ShaderModule354     fn module(&self) -> &ShaderModule {
355         self.module
356     }
357 
358     #[inline]
name(&self) -> &CStr359     fn name(&self) -> &CStr {
360         self.name
361     }
362 
363     #[inline]
descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc]364     fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc] {
365         &self.descriptor_set_layout_descs
366     }
367 
368     #[inline]
push_constant_range(&self) -> &Option<PipelineLayoutPcRange>369     fn push_constant_range(&self) -> &Option<PipelineLayoutPcRange> {
370         &self.push_constant_range
371     }
372 
373     #[inline]
spec_constants(&self) -> &[SpecializationMapEntry]374     fn spec_constants(&self) -> &[SpecializationMapEntry] {
375         self.spec_constants
376     }
377 }
378 
/// Type that contains the definition of an interface between two shader stages, or between
/// the outside and a shader stage.
///
/// An interface is a list of `ShaderInterfaceEntry` elements, each covering a range of
/// locations.
#[derive(Clone, Debug)]
pub struct ShaderInterface {
    // Elements of the interface. The constructor does not validate them.
    elements: Vec<ShaderInterfaceEntry>,
}
385 
386 impl ShaderInterface {
387     /// Constructs a new `ShaderInterface`.
388     ///
389     /// # Safety
390     ///
391     /// - Must only provide one entry per location.
392     /// - The format of each element must not be larger than 128 bits.
393     // TODO: could this be made safe?
394     #[inline]
new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface395     pub unsafe fn new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface {
396         ShaderInterface { elements }
397     }
398 
399     /// Creates a description of an empty shader interface.
empty() -> ShaderInterface400     pub const fn empty() -> ShaderInterface {
401         ShaderInterface {
402             elements: Vec::new(),
403         }
404     }
405 
406     /// Returns a slice containing the elements of the interface.
407     #[inline]
elements(&self) -> &[ShaderInterfaceEntry]408     pub fn elements(&self) -> &[ShaderInterfaceEntry] {
409         self.elements.as_ref()
410     }
411 
412     /// Checks whether the interface is potentially compatible with another one.
413     ///
414     /// Returns `Ok` if the two interfaces are compatible.
matches(&self, other: &ShaderInterface) -> Result<(), ShaderInterfaceMismatchError>415     pub fn matches(&self, other: &ShaderInterface) -> Result<(), ShaderInterfaceMismatchError> {
416         if self.elements().len() != other.elements().len() {
417             return Err(ShaderInterfaceMismatchError::ElementsCountMismatch {
418                 self_elements: self.elements().len() as u32,
419                 other_elements: other.elements().len() as u32,
420             });
421         }
422 
423         for a in self.elements() {
424             for loc in a.location.clone() {
425                 let b = match other
426                     .elements()
427                     .iter()
428                     .find(|e| loc >= e.location.start && loc < e.location.end)
429                 {
430                     None => {
431                         return Err(ShaderInterfaceMismatchError::MissingElement { location: loc })
432                     }
433                     Some(b) => b,
434                 };
435 
436                 if a.format != b.format {
437                     return Err(ShaderInterfaceMismatchError::FormatMismatch {
438                         location: loc,
439                         self_format: a.format,
440                         other_format: b.format,
441                     });
442                 }
443 
444                 // TODO: enforce this?
445                 /*match (a.name, b.name) {
446                     (Some(ref an), Some(ref bn)) => if an != bn { return false },
447                     _ => ()
448                 };*/
449             }
450         }
451 
452         // Note: since we check that the number of elements is the same, we don't need to iterate
453         // over b's elements.
454 
455         Ok(())
456     }
457 }
458 
/// Entry of a shader interface definition.
#[derive(Debug, Clone)]
pub struct ShaderInterfaceEntry {
    /// Range of locations covered by the element. The end of the range is exclusive.
    pub location: Range<u32>,
    /// Format of each location of the element.
    pub format: Format,
    /// Name of the element, or `None` if the name is unknown.
    pub name: Option<Cow<'static, str>>,
}
469 
/// Error that can happen when the interface mismatches between two shader stages.
///
/// Returned by `ShaderInterface::matches`, where "first" refers to `self` and "second" to the
/// interface passed as argument.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ShaderInterfaceMismatchError {
    /// The number of elements is not the same between the two shader interfaces.
    ElementsCountMismatch {
        /// Number of elements in the first interface.
        self_elements: u32,
        /// Number of elements in the second interface.
        other_elements: u32,
    },

    /// An element is missing from one of the interfaces.
    MissingElement {
        /// Location of the missing element.
        location: u32,
    },

    /// The format of an element does not match.
    FormatMismatch {
        /// Location of the element that mismatches.
        location: u32,
        /// Format in the first interface.
        self_format: Format,
        /// Format in the second interface.
        other_format: Format,
    },
}
497 
// Relies on the default `Error` methods; the message comes from the `Display` implementation.
impl error::Error for ShaderInterfaceMismatchError {}
499 
500 impl fmt::Display for ShaderInterfaceMismatchError {
501     #[inline]
fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error>502     fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
503         write!(
504             fmt,
505             "{}",
506             match *self {
507                 ShaderInterfaceMismatchError::ElementsCountMismatch { .. } => {
508                     "the number of elements mismatches"
509                 }
510                 ShaderInterfaceMismatchError::MissingElement { .. } => "an element is missing",
511                 ShaderInterfaceMismatchError::FormatMismatch { .. } => {
512                     "the format of an element does not match"
513                 }
514             }
515         )
516     }
517 }
518 
/// Trait for types that contain specialization data for shaders.
///
/// Shader modules can contain what is called *specialization constants*. They are the same as
/// constants except that their values can be defined when you create a compute pipeline or a
/// graphics pipeline. Doing so is done by passing a type that implements the
/// `SpecializationConstants` trait and that stores the values in question. The `descriptors()`
/// method of this trait indicates how to grab them.
///
/// Boolean specialization constants must be stored as 32-bit integers, where `0` means `false`
/// and any non-zero value means `true`. Integer and floating-point specialization constants are
/// stored as their Rust equivalent.
///
/// This trait is implemented on `()` for shaders that don't have any specialization constant.
///
/// Note that it is the shader module that chooses which type implementing
/// `SpecializationConstants` can be passed when creating the pipeline, through [the
/// `EntryPointAbstract` trait](trait.EntryPointAbstract.html). Therefore there is generally no
/// point in implementing this trait yourself, unless you are also writing your own
/// implementation of `EntryPointAbstract`.
///
/// # Example
///
/// ```rust
/// use vulkano::pipeline::shader::SpecializationConstants;
/// use vulkano::pipeline::shader::SpecializationMapEntry;
///
/// #[repr(C)]      // `#[repr(C)]` guarantees that the struct has a specific layout
/// struct MySpecConstants {
///     my_integer_constant: i32,
///     a_boolean: u32,
///     floating_point: f32,
/// }
///
/// unsafe impl SpecializationConstants for MySpecConstants {
///     fn descriptors() -> &'static [SpecializationMapEntry] {
///         static DESCRIPTORS: [SpecializationMapEntry; 3] = [
///             SpecializationMapEntry {
///                 constant_id: 0,
///                 offset: 0,
///                 size: 4,
///             },
///             SpecializationMapEntry {
///                 constant_id: 1,
///                 offset: 4,
///                 size: 4,
///             },
///             SpecializationMapEntry {
///                 constant_id: 2,
///                 offset: 8,
///                 size: 4,
///             },
///         ];
///
///         &DESCRIPTORS
///     }
/// }
/// ```
///
/// # Safety
///
/// - The `SpecializationMapEntry` returned must contain valid offsets and sizes.
/// - The size of each `SpecializationMapEntry` must match the size of the corresponding constant
///   (`4` for booleans).
///
pub unsafe trait SpecializationConstants {
    /// Returns descriptors of the struct's layout: one entry per specialization constant.
    fn descriptors() -> &'static [SpecializationMapEntry];
}
587 
588 unsafe impl SpecializationConstants for () {
589     #[inline]
descriptors() -> &'static [SpecializationMapEntry]590     fn descriptors() -> &'static [SpecializationMapEntry] {
591         &[]
592     }
593 }
594 
/// Describes an individual constant to set in the shader.
///
/// Each entry also locates, through `offset` and `size`, the field that holds the value of the
/// constant within the struct implementing `SpecializationConstants`.
// Implementation note: has the same memory representation as a `VkSpecializationMapEntry`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[repr(C)]
pub struct SpecializationMapEntry {
    /// Identifier of the constant in the shader that corresponds to this field.
    ///
    /// For SPIR-V, this must be the value of the `SpecId` decoration applied to the specialization
    /// constant.
    /// For GLSL, this must be the value of `N` in the `layout(constant_id = N)` attribute applied
    /// to a constant.
    pub constant_id: u32,

    /// Offset in bytes within the struct where the data can be found.
    pub offset: u32,

    /// Size of the data in bytes. Must match the size of the constant (`4` for booleans).
    pub size: usize,
}
614 
/// Describes a set of shader stages.
///
/// Each field is `true` if the corresponding stage is part of the set. Two sets can be merged
/// with the `|` operator.
///
/// # Example
///
/// ```rust
/// use vulkano::pipeline::shader::ShaderStages;
///
/// let vertex = ShaderStages { vertex: true, ..ShaderStages::none() };
/// let fragment = ShaderStages { fragment: true, ..ShaderStages::none() };
///
/// let both = vertex | fragment;
/// assert!(both.vertex && both.fragment);
/// assert!(!both.compute);
/// ```
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub struct ShaderStages {
    /// Vertex shader stage.
    pub vertex: bool,
    /// Tessellation control shader stage.
    pub tessellation_control: bool,
    /// Tessellation evaluation shader stage.
    pub tessellation_evaluation: bool,
    /// Geometry shader stage.
    pub geometry: bool,
    /// Fragment shader stage.
    pub fragment: bool,
    /// Compute shader stage.
    pub compute: bool,
}
626 
627 impl ShaderStages {
628     /// Creates a `ShaderStages` struct will all stages set to `true`.
629     // TODO: add example
630     #[inline]
all() -> ShaderStages631     pub const fn all() -> ShaderStages {
632         ShaderStages {
633             vertex: true,
634             tessellation_control: true,
635             tessellation_evaluation: true,
636             geometry: true,
637             fragment: true,
638             compute: true,
639         }
640     }
641 
642     /// Creates a `ShaderStages` struct will all stages set to `false`.
643     // TODO: add example
644     #[inline]
none() -> ShaderStages645     pub const fn none() -> ShaderStages {
646         ShaderStages {
647             vertex: false,
648             tessellation_control: false,
649             tessellation_evaluation: false,
650             geometry: false,
651             fragment: false,
652             compute: false,
653         }
654     }
655 
656     /// Creates a `ShaderStages` struct with all graphics stages set to `true`.
657     // TODO: add example
658     #[inline]
all_graphics() -> ShaderStages659     pub const fn all_graphics() -> ShaderStages {
660         ShaderStages {
661             vertex: true,
662             tessellation_control: true,
663             tessellation_evaluation: true,
664             geometry: true,
665             fragment: true,
666             compute: false,
667         }
668     }
669 
670     /// Creates a `ShaderStages` struct with the compute stage set to `true`.
671     // TODO: add example
672     #[inline]
compute() -> ShaderStages673     pub const fn compute() -> ShaderStages {
674         ShaderStages {
675             vertex: false,
676             tessellation_control: false,
677             tessellation_evaluation: false,
678             geometry: false,
679             fragment: false,
680             compute: true,
681         }
682     }
683 
684     /// Checks whether we have more stages enabled than `other`.
685     // TODO: add example
686     #[inline]
ensure_superset_of( &self, other: &ShaderStages, ) -> Result<(), ShaderStagesSupersetError>687     pub const fn ensure_superset_of(
688         &self,
689         other: &ShaderStages,
690     ) -> Result<(), ShaderStagesSupersetError> {
691         if (self.vertex || !other.vertex)
692             && (self.tessellation_control || !other.tessellation_control)
693             && (self.tessellation_evaluation || !other.tessellation_evaluation)
694             && (self.geometry || !other.geometry)
695             && (self.fragment || !other.fragment)
696             && (self.compute || !other.compute)
697         {
698             Ok(())
699         } else {
700             Err(ShaderStagesSupersetError::NotSuperset)
701         }
702     }
703 
704     /// Checks whether any of the stages in `self` are also present in `other`.
705     // TODO: add example
706     #[inline]
intersects(&self, other: &ShaderStages) -> bool707     pub const fn intersects(&self, other: &ShaderStages) -> bool {
708         (self.vertex && other.vertex)
709             || (self.tessellation_control && other.tessellation_control)
710             || (self.tessellation_evaluation && other.tessellation_evaluation)
711             || (self.geometry && other.geometry)
712             || (self.fragment && other.fragment)
713             || (self.compute && other.compute)
714     }
715 }
716 
717 impl From<ShaderStages> for ash::vk::ShaderStageFlags {
718     #[inline]
from(val: ShaderStages) -> Self719     fn from(val: ShaderStages) -> Self {
720         let mut result = ash::vk::ShaderStageFlags::empty();
721         if val.vertex {
722             result |= ash::vk::ShaderStageFlags::VERTEX;
723         }
724         if val.tessellation_control {
725             result |= ash::vk::ShaderStageFlags::TESSELLATION_CONTROL;
726         }
727         if val.tessellation_evaluation {
728             result |= ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION;
729         }
730         if val.geometry {
731             result |= ash::vk::ShaderStageFlags::GEOMETRY;
732         }
733         if val.fragment {
734             result |= ash::vk::ShaderStageFlags::FRAGMENT;
735         }
736         if val.compute {
737             result |= ash::vk::ShaderStageFlags::COMPUTE;
738         }
739         result
740     }
741 }
742 
743 impl From<ash::vk::ShaderStageFlags> for ShaderStages {
744     #[inline]
from(val: ash::vk::ShaderStageFlags) -> Self745     fn from(val: ash::vk::ShaderStageFlags) -> Self {
746         Self {
747             vertex: val.intersects(ash::vk::ShaderStageFlags::VERTEX),
748             tessellation_control: val.intersects(ash::vk::ShaderStageFlags::TESSELLATION_CONTROL),
749             tessellation_evaluation: val
750                 .intersects(ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION),
751             geometry: val.intersects(ash::vk::ShaderStageFlags::GEOMETRY),
752             fragment: val.intersects(ash::vk::ShaderStageFlags::FRAGMENT),
753             compute: val.intersects(ash::vk::ShaderStageFlags::COMPUTE),
754         }
755     }
756 }
757 
758 impl BitOr for ShaderStages {
759     type Output = ShaderStages;
760 
761     #[inline]
bitor(self, other: ShaderStages) -> ShaderStages762     fn bitor(self, other: ShaderStages) -> ShaderStages {
763         ShaderStages {
764             vertex: self.vertex || other.vertex,
765             tessellation_control: self.tessellation_control || other.tessellation_control,
766             tessellation_evaluation: self.tessellation_evaluation || other.tessellation_evaluation,
767             geometry: self.geometry || other.geometry,
768             fragment: self.fragment || other.fragment,
769             compute: self.compute || other.compute,
770         }
771     }
772 }
773 
774 impl From<ShaderStages> for PipelineStages {
775     #[inline]
from(stages: ShaderStages) -> PipelineStages776     fn from(stages: ShaderStages) -> PipelineStages {
777         PipelineStages {
778             vertex_shader: stages.vertex,
779             tessellation_control_shader: stages.tessellation_control,
780             tessellation_evaluation_shader: stages.tessellation_evaluation,
781             geometry_shader: stages.geometry,
782             fragment_shader: stages.fragment,
783             compute_shader: stages.compute,
784             ..PipelineStages::none()
785         }
786     }
787 }
788 
/// Error when checking that a `ShaderStages` object is a superset of another.
///
/// Like the other error type of this file, it can be compared with `==`, which lets callers
/// and tests assert on the exact error returned.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum ShaderStagesSupersetError {
    /// At least one stage enabled in the other set is not enabled in this one.
    NotSuperset,
}
794 
// Relies on the default `Error` methods; the message comes from the `Display` implementation.
impl error::Error for ShaderStagesSupersetError {}
796 
797 impl fmt::Display for ShaderStagesSupersetError {
798     #[inline]
fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error>799     fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
800         write!(
801             fmt,
802             "{}",
803             match *self {
804                 ShaderStagesSupersetError::NotSuperset => "shader stages not a superset",
805             }
806         )
807     }
808 }
809