1 // Copyright (c) 2016 The vulkano developers 2 // Licensed under the Apache License, Version 2.0 3 // <LICENSE-APACHE or 4 // https://www.apache.org/licenses/LICENSE-2.0> or the MIT 5 // license <LICENSE-MIT or https://opensource.org/licenses/MIT>, 6 // at your option. All files in the project carrying such 7 // notice may not be copied, modified, or distributed except 8 // according to those terms. 9 10 //! Stage of a graphics pipeline. 11 //! 12 //! In Vulkan, shaders are grouped in *shader modules*. Each shader module is built from SPIR-V 13 //! code and can contain one or more entry points. Note that for the moment the official 14 //! GLSL-to-SPIR-V compiler does not support multiple entry points. 15 //! 16 //! The vulkano library does not provide any functionality that checks and introspects the SPIR-V 17 //! code, therefore the whole shader-related API is unsafe. You are encouraged to use the 18 //! `vulkano-shaders` crate that will generate Rust code that wraps around vulkano's shaders API. 19 20 use crate::check_errors; 21 use crate::descriptor_set::layout::DescriptorSetDesc; 22 use crate::device::Device; 23 use crate::format::Format; 24 use crate::pipeline::input_assembly::PrimitiveTopology; 25 use crate::pipeline::layout::PipelineLayoutPcRange; 26 use crate::sync::PipelineStages; 27 use crate::OomError; 28 use crate::VulkanObject; 29 use smallvec::SmallVec; 30 use std::borrow::Cow; 31 use std::error; 32 use std::ffi::CStr; 33 use std::fmt; 34 use std::mem; 35 use std::mem::MaybeUninit; 36 use std::ops::BitOr; 37 use std::ops::Range; 38 use std::ptr; 39 use std::sync::Arc; 40 41 /// Contains SPIR-V code with one or more entry points. 42 /// 43 /// Note that it is advised to wrap around a `ShaderModule` with a struct that is different for 44 /// each shader. 45 #[derive(Debug)] 46 pub struct ShaderModule { 47 // The module. 48 module: ash::vk::ShaderModule, 49 // Pointer to the device. 
50 device: Arc<Device>, 51 } 52 53 impl ShaderModule { 54 /// Builds a new shader module from SPIR-V bytes. 55 /// 56 /// # Safety 57 /// 58 /// - The SPIR-V code is not validated. 59 /// - The SPIR-V code may require some features that are not enabled. This isn't checked by 60 /// this function either. 61 /// new(device: Arc<Device>, spirv: &[u8]) -> Result<Arc<ShaderModule>, OomError>62 pub unsafe fn new(device: Arc<Device>, spirv: &[u8]) -> Result<Arc<ShaderModule>, OomError> { 63 debug_assert!((spirv.len() % 4) == 0); 64 Self::from_ptr(device, spirv.as_ptr() as *const _, spirv.len()) 65 } 66 67 /// Builds a new shader module from SPIR-V 32-bit words. 68 /// 69 /// # Safety 70 /// 71 /// - The SPIR-V code is not validated. 72 /// - The SPIR-V code may require some features that are not enabled. This isn't checked by 73 /// this function either. 74 /// from_words( device: Arc<Device>, spirv: &[u32], ) -> Result<Arc<ShaderModule>, OomError>75 pub unsafe fn from_words( 76 device: Arc<Device>, 77 spirv: &[u32], 78 ) -> Result<Arc<ShaderModule>, OomError> { 79 Self::from_ptr(device, spirv.as_ptr(), spirv.len() * mem::size_of::<u32>()) 80 } 81 82 /// Builds a new shader module from SPIR-V. 83 /// 84 /// # Safety 85 /// 86 /// - The SPIR-V code is not validated. 87 /// - The SPIR-V code may require some features that are not enabled. This isn't checked by 88 /// this function either. 
89 /// from_ptr( device: Arc<Device>, spirv: *const u32, spirv_len: usize, ) -> Result<Arc<ShaderModule>, OomError>90 unsafe fn from_ptr( 91 device: Arc<Device>, 92 spirv: *const u32, 93 spirv_len: usize, 94 ) -> Result<Arc<ShaderModule>, OomError> { 95 let module = { 96 let infos = ash::vk::ShaderModuleCreateInfo { 97 flags: ash::vk::ShaderModuleCreateFlags::empty(), 98 code_size: spirv_len, 99 p_code: spirv, 100 ..Default::default() 101 }; 102 103 let fns = device.fns(); 104 let mut output = MaybeUninit::uninit(); 105 check_errors(fns.v1_0.create_shader_module( 106 device.internal_object(), 107 &infos, 108 ptr::null(), 109 output.as_mut_ptr(), 110 ))?; 111 output.assume_init() 112 }; 113 114 Ok(Arc::new(ShaderModule { 115 module: module, 116 device: device, 117 })) 118 } 119 120 /// Gets access to an entry point contained in this module. 121 /// 122 /// This is purely a *logical* operation. It returns a struct that *represents* the entry 123 /// point but doesn't actually do anything. 124 /// 125 /// # Safety 126 /// 127 /// - The user must check that the entry point exists in the module, as this is not checked 128 /// by Vulkan. 129 /// - The input, output and layout must correctly describe the input, output and layout used 130 /// by this stage. 
131 /// graphics_entry_point<'a, D>( &'a self, name: &'a CStr, descriptor_set_layout_descs: D, push_constant_range: Option<PipelineLayoutPcRange>, spec_constants: &'static [SpecializationMapEntry], input: ShaderInterface, output: ShaderInterface, ty: GraphicsShaderType, ) -> GraphicsEntryPoint<'a> where D: IntoIterator<Item = DescriptorSetDesc>,132 pub unsafe fn graphics_entry_point<'a, D>( 133 &'a self, 134 name: &'a CStr, 135 descriptor_set_layout_descs: D, 136 push_constant_range: Option<PipelineLayoutPcRange>, 137 spec_constants: &'static [SpecializationMapEntry], 138 input: ShaderInterface, 139 output: ShaderInterface, 140 ty: GraphicsShaderType, 141 ) -> GraphicsEntryPoint<'a> 142 where 143 D: IntoIterator<Item = DescriptorSetDesc>, 144 { 145 GraphicsEntryPoint { 146 module: self, 147 name, 148 descriptor_set_layout_descs: descriptor_set_layout_descs.into_iter().collect(), 149 push_constant_range, 150 spec_constants, 151 input, 152 output, 153 ty, 154 } 155 } 156 157 /// Gets access to an entry point contained in this module. 158 /// 159 /// This is purely a *logical* operation. It returns a struct that *represents* the entry 160 /// point but doesn't actually do anything. 161 /// 162 /// # Safety 163 /// 164 /// - The user must check that the entry point exists in the module, as this is not checked 165 /// by Vulkan. 166 /// - The layout must correctly describe the layout used by this stage. 
167 /// 168 #[inline] compute_entry_point<'a, D>( &'a self, name: &'a CStr, descriptor_set_layout_descs: D, push_constant_range: Option<PipelineLayoutPcRange>, spec_constants: &'static [SpecializationMapEntry], ) -> ComputeEntryPoint<'a> where D: IntoIterator<Item = DescriptorSetDesc>,169 pub unsafe fn compute_entry_point<'a, D>( 170 &'a self, 171 name: &'a CStr, 172 descriptor_set_layout_descs: D, 173 push_constant_range: Option<PipelineLayoutPcRange>, 174 spec_constants: &'static [SpecializationMapEntry], 175 ) -> ComputeEntryPoint<'a> 176 where 177 D: IntoIterator<Item = DescriptorSetDesc>, 178 { 179 ComputeEntryPoint { 180 module: self, 181 name, 182 descriptor_set_layout_descs: descriptor_set_layout_descs.into_iter().collect(), 183 push_constant_range, 184 spec_constants, 185 } 186 } 187 } 188 189 unsafe impl VulkanObject for ShaderModule { 190 type Object = ash::vk::ShaderModule; 191 192 #[inline] internal_object(&self) -> ash::vk::ShaderModule193 fn internal_object(&self) -> ash::vk::ShaderModule { 194 self.module 195 } 196 } 197 198 impl Drop for ShaderModule { 199 #[inline] drop(&mut self)200 fn drop(&mut self) { 201 unsafe { 202 let fns = self.device.fns(); 203 fns.v1_0 204 .destroy_shader_module(self.device.internal_object(), self.module, ptr::null()); 205 } 206 } 207 } 208 209 pub unsafe trait EntryPointAbstract { 210 /// Returns the module this entry point comes from. module(&self) -> &ShaderModule211 fn module(&self) -> &ShaderModule; 212 213 /// Returns the name of the entry point. name(&self) -> &CStr214 fn name(&self) -> &CStr; 215 216 /// Returns a description of the descriptor set layouts. descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc]217 fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc]; 218 219 /// Returns the push constant ranges. push_constant_range(&self) -> &Option<PipelineLayoutPcRange>220 fn push_constant_range(&self) -> &Option<PipelineLayoutPcRange>; 221 222 /// Returns the layout of the specialization constants. 
spec_constants(&self) -> &[SpecializationMapEntry]223 fn spec_constants(&self) -> &[SpecializationMapEntry]; 224 } 225 226 /// Represents a shader entry point in a shader module. 227 /// 228 /// Can be obtained by calling `entry_point()` on the shader module. 229 #[derive(Clone, Debug)] 230 pub struct GraphicsEntryPoint<'a> { 231 module: &'a ShaderModule, 232 name: &'a CStr, 233 234 descriptor_set_layout_descs: SmallVec<[DescriptorSetDesc; 16]>, 235 push_constant_range: Option<PipelineLayoutPcRange>, 236 spec_constants: &'static [SpecializationMapEntry], 237 input: ShaderInterface, 238 output: ShaderInterface, 239 ty: GraphicsShaderType, 240 } 241 242 impl<'a> GraphicsEntryPoint<'a> { 243 /// Returns the input attributes used by the shader stage. 244 #[inline] input(&self) -> &ShaderInterface245 pub fn input(&self) -> &ShaderInterface { 246 &self.input 247 } 248 249 /// Returns the output attributes used by the shader stage. 250 #[inline] output(&self) -> &ShaderInterface251 pub fn output(&self) -> &ShaderInterface { 252 &self.output 253 } 254 255 /// Returns the type of shader. 
256 #[inline] ty(&self) -> GraphicsShaderType257 pub fn ty(&self) -> GraphicsShaderType { 258 self.ty 259 } 260 } 261 262 unsafe impl<'a> EntryPointAbstract for GraphicsEntryPoint<'a> { 263 #[inline] module(&self) -> &ShaderModule264 fn module(&self) -> &ShaderModule { 265 self.module 266 } 267 268 #[inline] name(&self) -> &CStr269 fn name(&self) -> &CStr { 270 self.name 271 } 272 273 #[inline] descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc]274 fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc] { 275 &self.descriptor_set_layout_descs 276 } 277 278 #[inline] push_constant_range(&self) -> &Option<PipelineLayoutPcRange>279 fn push_constant_range(&self) -> &Option<PipelineLayoutPcRange> { 280 &self.push_constant_range 281 } 282 283 #[inline] spec_constants(&self) -> &[SpecializationMapEntry]284 fn spec_constants(&self) -> &[SpecializationMapEntry] { 285 self.spec_constants 286 } 287 } 288 289 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 290 pub enum GraphicsShaderType { 291 Vertex, 292 TessellationControl, 293 TessellationEvaluation, 294 Geometry(GeometryShaderExecutionMode), 295 Fragment, 296 } 297 298 /// Declares which type of primitives are expected by the geometry shader. 299 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 300 pub enum GeometryShaderExecutionMode { 301 Points, 302 Lines, 303 LinesWithAdjacency, 304 Triangles, 305 TrianglesWithAdjacency, 306 } 307 308 impl GeometryShaderExecutionMode { 309 /// Returns true if the given primitive topology can be used with this execution mode. 
310 #[inline] matches(&self, input: PrimitiveTopology) -> bool311 pub fn matches(&self, input: PrimitiveTopology) -> bool { 312 match (*self, input) { 313 (GeometryShaderExecutionMode::Points, PrimitiveTopology::PointList) => true, 314 (GeometryShaderExecutionMode::Lines, PrimitiveTopology::LineList) => true, 315 (GeometryShaderExecutionMode::Lines, PrimitiveTopology::LineStrip) => true, 316 ( 317 GeometryShaderExecutionMode::LinesWithAdjacency, 318 PrimitiveTopology::LineListWithAdjacency, 319 ) => true, 320 ( 321 GeometryShaderExecutionMode::LinesWithAdjacency, 322 PrimitiveTopology::LineStripWithAdjacency, 323 ) => true, 324 (GeometryShaderExecutionMode::Triangles, PrimitiveTopology::TriangleList) => true, 325 (GeometryShaderExecutionMode::Triangles, PrimitiveTopology::TriangleStrip) => true, 326 (GeometryShaderExecutionMode::Triangles, PrimitiveTopology::TriangleFan) => true, 327 ( 328 GeometryShaderExecutionMode::TrianglesWithAdjacency, 329 PrimitiveTopology::TriangleListWithAdjacency, 330 ) => true, 331 ( 332 GeometryShaderExecutionMode::TrianglesWithAdjacency, 333 PrimitiveTopology::TriangleStripWithAdjacency, 334 ) => true, 335 _ => false, 336 } 337 } 338 } 339 340 /// Represents the entry point of a compute shader in a shader module. 341 /// 342 /// Can be obtained by calling `compute_shader_entry_point()` on the shader module. 
343 #[derive(Debug, Clone)] 344 pub struct ComputeEntryPoint<'a> { 345 module: &'a ShaderModule, 346 name: &'a CStr, 347 descriptor_set_layout_descs: SmallVec<[DescriptorSetDesc; 16]>, 348 push_constant_range: Option<PipelineLayoutPcRange>, 349 spec_constants: &'static [SpecializationMapEntry], 350 } 351 352 unsafe impl<'a> EntryPointAbstract for ComputeEntryPoint<'a> { 353 #[inline] module(&self) -> &ShaderModule354 fn module(&self) -> &ShaderModule { 355 self.module 356 } 357 358 #[inline] name(&self) -> &CStr359 fn name(&self) -> &CStr { 360 self.name 361 } 362 363 #[inline] descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc]364 fn descriptor_set_layout_descs(&self) -> &[DescriptorSetDesc] { 365 &self.descriptor_set_layout_descs 366 } 367 368 #[inline] push_constant_range(&self) -> &Option<PipelineLayoutPcRange>369 fn push_constant_range(&self) -> &Option<PipelineLayoutPcRange> { 370 &self.push_constant_range 371 } 372 373 #[inline] spec_constants(&self) -> &[SpecializationMapEntry]374 fn spec_constants(&self) -> &[SpecializationMapEntry] { 375 self.spec_constants 376 } 377 } 378 379 /// Type that contains the definition of an interface between two shader stages, or between 380 /// the outside and a shader stage. 381 #[derive(Clone, Debug)] 382 pub struct ShaderInterface { 383 elements: Vec<ShaderInterfaceEntry>, 384 } 385 386 impl ShaderInterface { 387 /// Constructs a new `ShaderInterface`. 388 /// 389 /// # Safety 390 /// 391 /// - Must only provide one entry per location. 392 /// - The format of each element must not be larger than 128 bits. 393 // TODO: could this be made safe? 394 #[inline] new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface395 pub unsafe fn new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface { 396 ShaderInterface { elements } 397 } 398 399 /// Creates a description of an empty shader interface. 
empty() -> ShaderInterface400 pub const fn empty() -> ShaderInterface { 401 ShaderInterface { 402 elements: Vec::new(), 403 } 404 } 405 406 /// Returns a slice containing the elements of the interface. 407 #[inline] elements(&self) -> &[ShaderInterfaceEntry]408 pub fn elements(&self) -> &[ShaderInterfaceEntry] { 409 self.elements.as_ref() 410 } 411 412 /// Checks whether the interface is potentially compatible with another one. 413 /// 414 /// Returns `Ok` if the two interfaces are compatible. matches(&self, other: &ShaderInterface) -> Result<(), ShaderInterfaceMismatchError>415 pub fn matches(&self, other: &ShaderInterface) -> Result<(), ShaderInterfaceMismatchError> { 416 if self.elements().len() != other.elements().len() { 417 return Err(ShaderInterfaceMismatchError::ElementsCountMismatch { 418 self_elements: self.elements().len() as u32, 419 other_elements: other.elements().len() as u32, 420 }); 421 } 422 423 for a in self.elements() { 424 for loc in a.location.clone() { 425 let b = match other 426 .elements() 427 .iter() 428 .find(|e| loc >= e.location.start && loc < e.location.end) 429 { 430 None => { 431 return Err(ShaderInterfaceMismatchError::MissingElement { location: loc }) 432 } 433 Some(b) => b, 434 }; 435 436 if a.format != b.format { 437 return Err(ShaderInterfaceMismatchError::FormatMismatch { 438 location: loc, 439 self_format: a.format, 440 other_format: b.format, 441 }); 442 } 443 444 // TODO: enforce this? 445 /*match (a.name, b.name) { 446 (Some(ref an), Some(ref bn)) => if an != bn { return false }, 447 _ => () 448 };*/ 449 } 450 } 451 452 // Note: since we check that the number of elements is the same, we don't need to iterate 453 // over b's elements. 454 455 Ok(()) 456 } 457 } 458 459 /// Entry of a shader interface definition. 460 #[derive(Debug, Clone)] 461 pub struct ShaderInterfaceEntry { 462 /// Range of locations covered by the element. 463 pub location: Range<u32>, 464 /// Format of a each location of the element. 
465 pub format: Format, 466 /// Name of the element, or `None` if the name is unknown. 467 pub name: Option<Cow<'static, str>>, 468 } 469 470 /// Error that can happen when the interface mismatches between two shader stages. 471 #[derive(Clone, Debug, PartialEq, Eq)] 472 pub enum ShaderInterfaceMismatchError { 473 /// The number of elements is not the same between the two shader interfaces. 474 ElementsCountMismatch { 475 /// Number of elements in the first interface. 476 self_elements: u32, 477 /// Number of elements in the second interface. 478 other_elements: u32, 479 }, 480 481 /// An element is missing from one of the interfaces. 482 MissingElement { 483 /// Location of the missing element. 484 location: u32, 485 }, 486 487 /// The format of an element does not match. 488 FormatMismatch { 489 /// Location of the element that mismatches. 490 location: u32, 491 /// Format in the first interface. 492 self_format: Format, 493 /// Format in the second interface. 494 other_format: Format, 495 }, 496 } 497 498 impl error::Error for ShaderInterfaceMismatchError {} 499 500 impl fmt::Display for ShaderInterfaceMismatchError { 501 #[inline] fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error>502 fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> { 503 write!( 504 fmt, 505 "{}", 506 match *self { 507 ShaderInterfaceMismatchError::ElementsCountMismatch { .. } => { 508 "the number of elements mismatches" 509 } 510 ShaderInterfaceMismatchError::MissingElement { .. } => "an element is missing", 511 ShaderInterfaceMismatchError::FormatMismatch { .. } => { 512 "the format of an element does not match" 513 } 514 } 515 ) 516 } 517 } 518 519 /// Trait for types that contain specialization data for shaders. 520 /// 521 /// Shader modules can contain what is called *specialization constants*. They are the same as 522 /// constants except that their values can be defined when you create a compute pipeline or a 523 /// graphics pipeline. 
Doing so is done by passing a type that implements the 524 /// `SpecializationConstants` trait and that stores the values in question. The `descriptors()` 525 /// method of this trait indicates how to grab them. 526 /// 527 /// Boolean specialization constants must be stored as 32bits integers, where `0` means `false` and 528 /// any non-zero value means `true`. Integer and floating-point specialization constants are 529 /// stored as their Rust equivalent. 530 /// 531 /// This trait is implemented on `()` for shaders that don't have any specialization constant. 532 /// 533 /// Note that it is the shader module that chooses which type that implements 534 /// `SpecializationConstants` it is possible to pass when creating the pipeline, through [the 535 /// `EntryPointAbstract` trait](trait.EntryPointAbstract.html). Therefore there is generally no 536 /// point to implement this trait yourself, unless you are also writing your own implementation of 537 /// `EntryPointAbstract`. 538 /// 539 /// # Example 540 /// 541 /// ```rust 542 /// use vulkano::pipeline::shader::SpecializationConstants; 543 /// use vulkano::pipeline::shader::SpecializationMapEntry; 544 /// 545 /// #[repr(C)] // `#[repr(C)]` guarantees that the struct has a specific layout 546 /// struct MySpecConstants { 547 /// my_integer_constant: i32, 548 /// a_boolean: u32, 549 /// floating_point: f32, 550 /// } 551 /// 552 /// unsafe impl SpecializationConstants for MySpecConstants { 553 /// fn descriptors() -> &'static [SpecializationMapEntry] { 554 /// static DESCRIPTORS: [SpecializationMapEntry; 3] = [ 555 /// SpecializationMapEntry { 556 /// constant_id: 0, 557 /// offset: 0, 558 /// size: 4, 559 /// }, 560 /// SpecializationMapEntry { 561 /// constant_id: 1, 562 /// offset: 4, 563 /// size: 4, 564 /// }, 565 /// SpecializationMapEntry { 566 /// constant_id: 2, 567 /// offset: 8, 568 /// size: 4, 569 /// }, 570 /// ]; 571 /// 572 /// &DESCRIPTORS 573 /// } 574 /// } 575 /// ``` 576 /// 577 /// # Safety 578 
/// 579 /// - The `SpecializationMapEntry` returned must contain valid offsets and sizes. 580 /// - The size of each `SpecializationMapEntry` must match the size of the corresponding constant 581 /// (`4` for booleans). 582 /// 583 pub unsafe trait SpecializationConstants { 584 /// Returns descriptors of the struct's layout. descriptors() -> &'static [SpecializationMapEntry]585 fn descriptors() -> &'static [SpecializationMapEntry]; 586 } 587 588 unsafe impl SpecializationConstants for () { 589 #[inline] descriptors() -> &'static [SpecializationMapEntry]590 fn descriptors() -> &'static [SpecializationMapEntry] { 591 &[] 592 } 593 } 594 595 /// Describes an individual constant to set in the shader. Also a field in the struct. 596 // Implementation note: has the same memory representation as a `VkSpecializationMapEntry`. 597 #[derive(Clone, Copy, Debug, PartialEq, Eq)] 598 #[repr(C)] 599 pub struct SpecializationMapEntry { 600 /// Identifier of the constant in the shader that corresponds to this field. 601 /// 602 /// For SPIR-V, this must be the value of the `SpecId` decoration applied to the specialization 603 /// constant. 604 /// For GLSL, this must be the value of `N` in the `layout(constant_id = N)` attribute applied 605 /// to a constant. 606 pub constant_id: u32, 607 608 /// Offset within the struct where the data can be found. 609 pub offset: u32, 610 611 /// Size of the data in bytes. Must match the size of the constant (`4` for booleans). 612 pub size: usize, 613 } 614 615 /// Describes a set of shader stages. 616 // TODO: add example with BitOr 617 #[derive(Debug, Copy, Clone, PartialEq, Eq)] 618 pub struct ShaderStages { 619 pub vertex: bool, 620 pub tessellation_control: bool, 621 pub tessellation_evaluation: bool, 622 pub geometry: bool, 623 pub fragment: bool, 624 pub compute: bool, 625 } 626 627 impl ShaderStages { 628 /// Creates a `ShaderStages` struct will all stages set to `true`. 
629 // TODO: add example 630 #[inline] all() -> ShaderStages631 pub const fn all() -> ShaderStages { 632 ShaderStages { 633 vertex: true, 634 tessellation_control: true, 635 tessellation_evaluation: true, 636 geometry: true, 637 fragment: true, 638 compute: true, 639 } 640 } 641 642 /// Creates a `ShaderStages` struct will all stages set to `false`. 643 // TODO: add example 644 #[inline] none() -> ShaderStages645 pub const fn none() -> ShaderStages { 646 ShaderStages { 647 vertex: false, 648 tessellation_control: false, 649 tessellation_evaluation: false, 650 geometry: false, 651 fragment: false, 652 compute: false, 653 } 654 } 655 656 /// Creates a `ShaderStages` struct with all graphics stages set to `true`. 657 // TODO: add example 658 #[inline] all_graphics() -> ShaderStages659 pub const fn all_graphics() -> ShaderStages { 660 ShaderStages { 661 vertex: true, 662 tessellation_control: true, 663 tessellation_evaluation: true, 664 geometry: true, 665 fragment: true, 666 compute: false, 667 } 668 } 669 670 /// Creates a `ShaderStages` struct with the compute stage set to `true`. 671 // TODO: add example 672 #[inline] compute() -> ShaderStages673 pub const fn compute() -> ShaderStages { 674 ShaderStages { 675 vertex: false, 676 tessellation_control: false, 677 tessellation_evaluation: false, 678 geometry: false, 679 fragment: false, 680 compute: true, 681 } 682 } 683 684 /// Checks whether we have more stages enabled than `other`. 
685 // TODO: add example 686 #[inline] ensure_superset_of( &self, other: &ShaderStages, ) -> Result<(), ShaderStagesSupersetError>687 pub const fn ensure_superset_of( 688 &self, 689 other: &ShaderStages, 690 ) -> Result<(), ShaderStagesSupersetError> { 691 if (self.vertex || !other.vertex) 692 && (self.tessellation_control || !other.tessellation_control) 693 && (self.tessellation_evaluation || !other.tessellation_evaluation) 694 && (self.geometry || !other.geometry) 695 && (self.fragment || !other.fragment) 696 && (self.compute || !other.compute) 697 { 698 Ok(()) 699 } else { 700 Err(ShaderStagesSupersetError::NotSuperset) 701 } 702 } 703 704 /// Checks whether any of the stages in `self` are also present in `other`. 705 // TODO: add example 706 #[inline] intersects(&self, other: &ShaderStages) -> bool707 pub const fn intersects(&self, other: &ShaderStages) -> bool { 708 (self.vertex && other.vertex) 709 || (self.tessellation_control && other.tessellation_control) 710 || (self.tessellation_evaluation && other.tessellation_evaluation) 711 || (self.geometry && other.geometry) 712 || (self.fragment && other.fragment) 713 || (self.compute && other.compute) 714 } 715 } 716 717 impl From<ShaderStages> for ash::vk::ShaderStageFlags { 718 #[inline] from(val: ShaderStages) -> Self719 fn from(val: ShaderStages) -> Self { 720 let mut result = ash::vk::ShaderStageFlags::empty(); 721 if val.vertex { 722 result |= ash::vk::ShaderStageFlags::VERTEX; 723 } 724 if val.tessellation_control { 725 result |= ash::vk::ShaderStageFlags::TESSELLATION_CONTROL; 726 } 727 if val.tessellation_evaluation { 728 result |= ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION; 729 } 730 if val.geometry { 731 result |= ash::vk::ShaderStageFlags::GEOMETRY; 732 } 733 if val.fragment { 734 result |= ash::vk::ShaderStageFlags::FRAGMENT; 735 } 736 if val.compute { 737 result |= ash::vk::ShaderStageFlags::COMPUTE; 738 } 739 result 740 } 741 } 742 743 impl From<ash::vk::ShaderStageFlags> for ShaderStages { 
744 #[inline] from(val: ash::vk::ShaderStageFlags) -> Self745 fn from(val: ash::vk::ShaderStageFlags) -> Self { 746 Self { 747 vertex: val.intersects(ash::vk::ShaderStageFlags::VERTEX), 748 tessellation_control: val.intersects(ash::vk::ShaderStageFlags::TESSELLATION_CONTROL), 749 tessellation_evaluation: val 750 .intersects(ash::vk::ShaderStageFlags::TESSELLATION_EVALUATION), 751 geometry: val.intersects(ash::vk::ShaderStageFlags::GEOMETRY), 752 fragment: val.intersects(ash::vk::ShaderStageFlags::FRAGMENT), 753 compute: val.intersects(ash::vk::ShaderStageFlags::COMPUTE), 754 } 755 } 756 } 757 758 impl BitOr for ShaderStages { 759 type Output = ShaderStages; 760 761 #[inline] bitor(self, other: ShaderStages) -> ShaderStages762 fn bitor(self, other: ShaderStages) -> ShaderStages { 763 ShaderStages { 764 vertex: self.vertex || other.vertex, 765 tessellation_control: self.tessellation_control || other.tessellation_control, 766 tessellation_evaluation: self.tessellation_evaluation || other.tessellation_evaluation, 767 geometry: self.geometry || other.geometry, 768 fragment: self.fragment || other.fragment, 769 compute: self.compute || other.compute, 770 } 771 } 772 } 773 774 impl From<ShaderStages> for PipelineStages { 775 #[inline] from(stages: ShaderStages) -> PipelineStages776 fn from(stages: ShaderStages) -> PipelineStages { 777 PipelineStages { 778 vertex_shader: stages.vertex, 779 tessellation_control_shader: stages.tessellation_control, 780 tessellation_evaluation_shader: stages.tessellation_evaluation, 781 geometry_shader: stages.geometry, 782 fragment_shader: stages.fragment, 783 compute_shader: stages.compute, 784 ..PipelineStages::none() 785 } 786 } 787 } 788 789 /// Error when checking that a `ShaderStages` object is a superset of another. 
// `PartialEq`/`Eq` match `ShaderInterfaceMismatchError` and let callers compare results of
// `ShaderStages::ensure_superset_of` directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ShaderStagesSupersetError {
    /// At least one stage enabled in the other `ShaderStages` is not enabled in this one.
    NotSuperset,
}

impl error::Error for ShaderStagesSupersetError {}

impl fmt::Display for ShaderStagesSupersetError {
    #[inline]
    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
        write!(
            fmt,
            "{}",
            match *self {
                ShaderStagesSupersetError::NotSuperset => "shader stages not a superset",
            }
        )
    }
}