// Copyright 2018 The SwiftShader Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "VkPipeline.hpp"

#include "VkDestroy.hpp"
#include "VkDevice.hpp"
#include "VkPipelineCache.hpp"
#include "VkPipelineLayout.hpp"
#include "VkRenderPass.hpp"
#include "VkShaderModule.hpp"
#include "VkStringify.hpp"
#include "Pipeline/ComputeProgram.hpp"
#include "Pipeline/SpirvShader.hpp"

#include "marl/trace.h"

#include "spirv-tools/optimizer.hpp"

#include <iostream>

namespace {

// optimizeSpirv() applies and freezes specializations into constants, and runs spirv-opt.
sw::SpirvBinary optimizeSpirv(const vk::PipelineCache::SpirvBinaryKey &key)
{
	const sw::SpirvBinary &code = key.getBinary();
	const VkSpecializationInfo *specializationInfo = key.getSpecializationInfo();
	bool optimize = key.getOptimization();

	spvtools::Optimizer opt{ vk::SPIRV_VERSION };

	opt.SetMessageConsumer([](spv_message_level_t level, const char *source, const spv_position_t &position, const char *message) {
		switch(level)
		{
		case SPV_MSG_FATAL: sw::warn("SPIR-V FATAL: %d:%d %s\n", int(position.line), int(position.column), message); break;
		case SPV_MSG_INTERNAL_ERROR: sw::warn("SPIR-V INTERNAL_ERROR: %d:%d %s\n", int(position.line), int(position.column), message); break;
		case SPV_MSG_ERROR: sw::warn("SPIR-V ERROR: %d:%d %s\n", int(position.line), int(position.column), message); break;
		case SPV_MSG_WARNING: sw::warn("SPIR-V WARNING: %d:%d %s\n", int(position.line), int(position.column), message); break;
		case SPV_MSG_INFO: sw::trace("SPIR-V INFO: %d:%d %s\n", int(position.line), int(position.column), message); break;
		case SPV_MSG_DEBUG: sw::trace("SPIR-V DEBUG: %d:%d %s\n", int(position.line), int(position.column), message); break;
		default: sw::trace("SPIR-V MESSAGE: %d:%d %s\n", int(position.line), int(position.column), message); break;
		}
	});

	// If the pipeline uses specialization, apply the specializations before freezing
	if(specializationInfo)
	{
		std::unordered_map<uint32_t, std::vector<uint32_t>> specializations;
		const uint8_t *specializationData = static_cast<const uint8_t *>(specializationInfo->pData);

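		// Each map entry names a specialization constant by ID and gives the byte offset and size
		// of its value within pData. Gather those values as 32-bit words so the
		// SetSpecConstantDefaultValue pass registered below can apply them.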
		for(uint32_t i = 0; i < specializationInfo->mapEntryCount; i++)
		{
			const VkSpecializationMapEntry &entry = specializationInfo->pMapEntries[i];
			const uint8_t *value_ptr = specializationData + entry.offset;
			std::vector<uint32_t> value(reinterpret_cast<const uint32_t *>(value_ptr),
			                            reinterpret_cast<const uint32_t *>(value_ptr + entry.size));
			specializations.emplace(entry.constantID, std::move(value));
		}

		opt.RegisterPass(spvtools::CreateSetSpecConstantDefaultValuePass(specializations));
	}

	if(optimize)
	{
		// Remove DontInline flags so the optimizer force-inlines all functions,
		// as we currently don't support OpFunctionCall (b/141246700).
		opt.RegisterPass(spvtools::CreateRemoveDontInlinePass());

		// Full optimization list taken from spirv-opt.
		opt.RegisterPerformancePasses();
	}

	spvtools::OptimizerOptions optimizerOptions = {};
#if defined(NDEBUG)
	optimizerOptions.set_run_validator(false);
#else
	optimizerOptions.set_run_validator(true);
	spvtools::ValidatorOptions validatorOptions = {};
	validatorOptions.SetScalarBlockLayout(true);            // VK_EXT_scalar_block_layout
	validatorOptions.SetUniformBufferStandardLayout(true);  // VK_KHR_uniform_buffer_standard_layout
	validatorOptions.SetAllowLocalSizeId(true);             // VK_KHR_maintenance4
	optimizerOptions.set_validator_options(validatorOptions);
#endif

	sw::SpirvBinary optimized;
	opt.Run(code.data(), code.size(), &optimized, optimizerOptions);
	ASSERT(optimized.size() > 0);

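	// Debugging aid: change the condition below to true to print the SPIR-V disassembly
	// before and after optimization.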
	if(false)
	{
		spvtools::SpirvTools core(vk::SPIRV_VERSION);
		std::string preOpt;
		core.Disassemble(code, &preOpt, SPV_BINARY_TO_TEXT_OPTION_NONE);
		std::string postOpt;
		core.Disassemble(optimized, &postOpt, SPV_BINARY_TO_TEXT_OPTION_NONE);
		std::cout << "PRE-OPT: " << preOpt << std::endl
		          << "POST-OPT: " << postOpt << std::endl;
	}

	return optimized;
}

std::shared_ptr<sw::ComputeProgram> createProgram(vk::Device *device, std::shared_ptr<sw::SpirvShader> shader, const vk::PipelineLayout *layout)
{
	MARL_SCOPED_EVENT("createProgram");

	vk::DescriptorSet::Bindings descriptorSets;  // TODO(b/129523279): Delay code generation until dispatch time.
	// TODO(b/119409619): use allocator.
	auto program = std::make_shared<sw::ComputeProgram>(device, shader, layout, descriptorSets);
	program->generate();
	program->finalize("ComputeProgram");

	return program;
}

class PipelineCreationFeedback
{
public:
	PipelineCreationFeedback(const VkGraphicsPipelineCreateInfo *pCreateInfo)
	    : pipelineCreationFeedback(GetPipelineCreationFeedback(pCreateInfo->pNext))
	{
		pipelineCreationBegins();
	}

	PipelineCreationFeedback(const VkComputePipelineCreateInfo *pCreateInfo)
	    : pipelineCreationFeedback(GetPipelineCreationFeedback(pCreateInfo->pNext))
	{
		pipelineCreationBegins();
	}

	~PipelineCreationFeedback()
	{
		pipelineCreationEnds();
	}

	void stageCreationBegins(uint32_t stage)
	{
		if(pipelineCreationFeedback && (stage < pipelineCreationFeedback->pipelineStageCreationFeedbackCount))
		{
			// Record stage creation begin time
			pipelineCreationFeedback->pPipelineStageCreationFeedbacks[stage].duration = now();
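			// The begin timestamp is stashed in the duration field; stageCreationEnds()
			// replaces it with the elapsed time.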
		}
	}

	void cacheHit(uint32_t stage)
	{
		if(pipelineCreationFeedback)
		{
			pipelineCreationFeedback->pPipelineCreationFeedback->flags |=
			    VK_PIPELINE_CREATION_FEEDBACK_APPLICATION_PIPELINE_CACHE_HIT_BIT;
			if(stage < pipelineCreationFeedback->pipelineStageCreationFeedbackCount)
			{
				pipelineCreationFeedback->pPipelineStageCreationFeedbacks[stage].flags |=
				    VK_PIPELINE_CREATION_FEEDBACK_APPLICATION_PIPELINE_CACHE_HIT_BIT;
			}
		}
	}

	void stageCreationEnds(uint32_t stage)
	{
		if(pipelineCreationFeedback && (stage < pipelineCreationFeedback->pipelineStageCreationFeedbackCount))
		{
			pipelineCreationFeedback->pPipelineStageCreationFeedbacks[stage].flags |=
			    VK_PIPELINE_CREATION_FEEDBACK_VALID_BIT;
			pipelineCreationFeedback->pPipelineStageCreationFeedbacks[stage].duration =
			    now() - pipelineCreationFeedback->pPipelineStageCreationFeedbacks[stage].duration;
		}
	}

	void pipelineCreationError()
	{
		clear();
		pipelineCreationFeedback = nullptr;
	}

private:
	static const VkPipelineCreationFeedbackCreateInfo *GetPipelineCreationFeedback(const void *pNext)
	{
		return vk::GetExtendedStruct<VkPipelineCreationFeedbackCreateInfo>(pNext, VK_STRUCTURE_TYPE_PIPELINE_CREATION_FEEDBACK_CREATE_INFO);
	}

	void pipelineCreationBegins()
	{
		if(pipelineCreationFeedback)
		{
			clear();

			// Record pipeline creation begin time
			pipelineCreationFeedback->pPipelineCreationFeedback->duration = now();
		}
	}

	void pipelineCreationEnds()
	{
		if(pipelineCreationFeedback)
		{
			pipelineCreationFeedback->pPipelineCreationFeedback->flags |=
			    VK_PIPELINE_CREATION_FEEDBACK_VALID_BIT;
			pipelineCreationFeedback->pPipelineCreationFeedback->duration =
			    now() - pipelineCreationFeedback->pPipelineCreationFeedback->duration;
		}
	}

	void clear()
	{
		if(pipelineCreationFeedback)
		{
			// Clear all flags and durations
			pipelineCreationFeedback->pPipelineCreationFeedback->flags = 0;
			pipelineCreationFeedback->pPipelineCreationFeedback->duration = 0;
			for(uint32_t i = 0; i < pipelineCreationFeedback->pipelineStageCreationFeedbackCount; i++)
			{
				pipelineCreationFeedback->pPipelineStageCreationFeedbacks[i].flags = 0;
				pipelineCreationFeedback->pPipelineStageCreationFeedbacks[i].duration = 0;
			}
		}
	}

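	// Returns the current time in nanoseconds, the unit in which
	// VkPipelineCreationFeedback::duration is expressed.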
	uint64_t now()
	{
		return std::chrono::time_point_cast<std::chrono::nanoseconds>(std::chrono::system_clock::now()).time_since_epoch().count();
	}

	const VkPipelineCreationFeedbackCreateInfo *pipelineCreationFeedback = nullptr;
};

bool getRobustBufferAccess(VkPipelineRobustnessBufferBehaviorEXT behavior, bool inheritRobustBufferAccess)
{
	// Based on behavior:
	// - <not provided>:
	//   * For pipelines, use device's robustBufferAccess
	//   * For shaders, use pipeline's robustBufferAccess
	//     Note that pipeline's robustBufferAccess is already set to device's if not overridden.
	// - Default: Use device's robustBufferAccess
	// - Disabled / Enabled: Override to disabled or enabled
	//
	// This function is passed "DEFAULT" when no override is provided, and
	// inheritRobustBufferAccess is appropriately set to the device's or pipeline's
	// robustBufferAccess.
	switch(behavior)
	{
	case VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DEVICE_DEFAULT_EXT:
		return inheritRobustBufferAccess;
	case VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DISABLED_EXT:
		return false;
	case VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_ROBUST_BUFFER_ACCESS_EXT:
		return true;
	default:
		UNSUPPORTED("Unsupported robustness behavior");
		return true;
	}
}

bool getRobustBufferAccess(const VkPipelineRobustnessCreateInfoEXT *overrideRobustness, bool deviceRobustBufferAccess, bool inheritRobustBufferAccess)
{
	VkPipelineRobustnessBufferBehaviorEXT storageBehavior = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DEVICE_DEFAULT_EXT;
	VkPipelineRobustnessBufferBehaviorEXT uniformBehavior = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DEVICE_DEFAULT_EXT;
	VkPipelineRobustnessBufferBehaviorEXT vertexBehavior = VK_PIPELINE_ROBUSTNESS_BUFFER_BEHAVIOR_DEVICE_DEFAULT_EXT;

	if(overrideRobustness)
	{
		storageBehavior = overrideRobustness->storageBuffers;
		uniformBehavior = overrideRobustness->uniformBuffers;
		vertexBehavior = overrideRobustness->vertexInputs;
		inheritRobustBufferAccess = deviceRobustBufferAccess;
	}

	bool storageRobustBufferAccess = getRobustBufferAccess(storageBehavior, inheritRobustBufferAccess);
	bool uniformRobustBufferAccess = getRobustBufferAccess(uniformBehavior, inheritRobustBufferAccess);
	bool vertexRobustBufferAccess = getRobustBufferAccess(vertexBehavior, inheritRobustBufferAccess);

	// Note: in the initial implementation, enabling robust access for any buffer enables it for
	// all.  TODO(b/185122256) split robustBufferAccess in the pipeline and shaders into three
	// categories and provide robustness for storage, uniform and vertex buffers accordingly.
	return storageRobustBufferAccess || uniformRobustBufferAccess || vertexRobustBufferAccess;
}

bool getPipelineRobustBufferAccess(const void *pNext, vk::Device *device)
{
	const VkPipelineRobustnessCreateInfoEXT *overrideRobustness = vk::GetExtendedStruct<VkPipelineRobustnessCreateInfoEXT>(pNext, VK_STRUCTURE_TYPE_PIPELINE_ROBUSTNESS_CREATE_INFO_EXT);
	const bool deviceRobustBufferAccess = device->getEnabledFeatures().robustBufferAccess;

	// For pipelines, there's no robustBufferAccess to inherit from.  Default and no-override
	// both lead to using the device's robustBufferAccess.
	return getRobustBufferAccess(overrideRobustness, deviceRobustBufferAccess, deviceRobustBufferAccess);
}

bool getPipelineStageRobustBufferAccess(const void *pNext, vk::Device *device, bool pipelineRobustBufferAccess)
{
	const VkPipelineRobustnessCreateInfoEXT *overrideRobustness = vk::GetExtendedStruct<VkPipelineRobustnessCreateInfoEXT>(pNext, VK_STRUCTURE_TYPE_PIPELINE_ROBUSTNESS_CREATE_INFO_EXT);
	const bool deviceRobustBufferAccess = device->getEnabledFeatures().robustBufferAccess;

	return getRobustBufferAccess(overrideRobustness, deviceRobustBufferAccess, pipelineRobustBufferAccess);
}

}  // anonymous namespace


namespace vk {
Pipeline::Pipeline(PipelineLayout *layout, Device *device, bool robustBufferAccess)
    : layout(layout)
    , device(device)
    , robustBufferAccess(robustBufferAccess)
{
	if(layout)
	{
		layout->incRefCount();
	}
}

void Pipeline::destroy(const VkAllocationCallbacks *pAllocator)
{
	destroyPipeline(pAllocator);

	if(layout)
	{
		vk::release(static_cast<VkPipelineLayout>(*layout), pAllocator);
	}
}

GraphicsPipeline::GraphicsPipeline(const VkGraphicsPipelineCreateInfo *pCreateInfo, void *mem, Device *device)
    : Pipeline(vk::Cast(pCreateInfo->layout), device, getPipelineRobustBufferAccess(pCreateInfo->pNext, device))
    , state(device, pCreateInfo, layout)
{
	// The vertex input interface comes either from a pipeline library or from the
	// VkGraphicsPipelineCreateInfo itself.  The same applies to the shaders.
	const auto *libraryCreateInfo = GetExtendedStruct<VkPipelineLibraryCreateInfoKHR>(pCreateInfo->pNext, VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR);
	bool vertexInputInterfaceInLibraries = false;
	if(libraryCreateInfo)
	{
		for(uint32_t i = 0; i < libraryCreateInfo->libraryCount; ++i)
		{
			const auto *library = static_cast<const vk::GraphicsPipeline *>(vk::Cast(libraryCreateInfo->pLibraries[i]));
			if(library->state.hasVertexInputInterfaceState())
			{
				inputs = library->inputs;
				vertexInputInterfaceInLibraries = true;
			}
			if(library->state.hasPreRasterizationState())
			{
				vertexShader = library->vertexShader;
			}
			if(library->state.hasFragmentState())
			{
				fragmentShader = library->fragmentShader;
			}
		}
	}
	if(state.hasVertexInputInterfaceState() && !vertexInputInterfaceInLibraries)
	{
		inputs.initialize(pCreateInfo->pVertexInputState);
	}
}

void GraphicsPipeline::destroyPipeline(const VkAllocationCallbacks *pAllocator)
{
	vertexShader.reset();
	fragmentShader.reset();
}

size_t GraphicsPipeline::ComputeRequiredAllocationSize(const VkGraphicsPipelineCreateInfo *pCreateInfo)
{
	return 0;
}


VkGraphicsPipelineLibraryFlagsEXT GraphicsPipeline::GetGraphicsPipelineSubset(const VkGraphicsPipelineCreateInfo *pCreateInfo)
{
	const auto *libraryCreateInfo = vk::GetExtendedStruct<VkPipelineLibraryCreateInfoKHR>(pCreateInfo->pNext, VK_STRUCTURE_TYPE_PIPELINE_LIBRARY_CREATE_INFO_KHR);
	const auto *graphicsLibraryCreateInfo = vk::GetExtendedStruct<VkGraphicsPipelineLibraryCreateInfoEXT>(pCreateInfo->pNext, VK_STRUCTURE_TYPE_GRAPHICS_PIPELINE_LIBRARY_CREATE_INFO_EXT);

	if(graphicsLibraryCreateInfo)
	{
		return graphicsLibraryCreateInfo->flags;
	}

	// > If this structure is omitted, and either VkGraphicsPipelineCreateInfo::flags
	// > includes VK_PIPELINE_CREATE_LIBRARY_BIT_KHR or the
	// > VkGraphicsPipelineCreateInfo::pNext chain includes a VkPipelineLibraryCreateInfoKHR
	// > structure with a libraryCount greater than 0, it is as if flags is 0. Otherwise if
	// > this structure is omitted, it is as if flags includes all possible subsets of the
	// > graphics pipeline (i.e. a complete graphics pipeline).
	//
	// The above basically says that when a pipeline is created:
	// - If not a library and not created from libraries, it's a complete pipeline (i.e.
	//   Vulkan 1.0 pipelines)
	// - If only created from other libraries, no state is taken from
	//   VkGraphicsPipelineCreateInfo.
	//
	// Otherwise the behavior when creating a library from other libraries is that some
	// state is taken from VkGraphicsPipelineCreateInfo and some from the libraries.
	const bool isLibrary = (pCreateInfo->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR) != 0;
	if(isLibrary || (libraryCreateInfo && libraryCreateInfo->libraryCount > 0))
	{
		return 0;
	}

	return VK_GRAPHICS_PIPELINE_LIBRARY_VERTEX_INPUT_INTERFACE_BIT_EXT |
	       VK_GRAPHICS_PIPELINE_LIBRARY_PRE_RASTERIZATION_SHADERS_BIT_EXT |
	       VK_GRAPHICS_PIPELINE_LIBRARY_FRAGMENT_SHADER_BIT_EXT |
	       VK_GRAPHICS_PIPELINE_LIBRARY_FRAGMENT_OUTPUT_INTERFACE_BIT_EXT;
}

void GraphicsPipeline::getIndexBuffers(const vk::DynamicState &dynamicState, uint32_t count, uint32_t first, bool indexed, std::vector<std::pair<uint32_t, void *>> *indexBuffers) const
{
	const vk::VertexInputInterfaceState &vertexInputInterfaceState = state.getVertexInputInterfaceState();

	const VkPrimitiveTopology topology = vertexInputInterfaceState.hasDynamicTopology() ? dynamicState.primitiveTopology : vertexInputInterfaceState.getTopology();
	const bool hasPrimitiveRestartEnable = vertexInputInterfaceState.hasDynamicPrimitiveRestartEnable() ? dynamicState.primitiveRestartEnable : vertexInputInterfaceState.hasPrimitiveRestartEnable();
	indexBuffer.getIndexBuffers(topology, count, first, indexed, hasPrimitiveRestartEnable, indexBuffers);
}

bool GraphicsPipeline::preRasterizationContainsImageWrite() const
{
	return vertexShader.get() && vertexShader->containsImageWrite();
}

bool GraphicsPipeline::fragmentContainsImageWrite() const
{
	return fragmentShader.get() && fragmentShader->containsImageWrite();
}

void GraphicsPipeline::setShader(const VkShaderStageFlagBits &stage, const std::shared_ptr<sw::SpirvShader> spirvShader)
{
	switch(stage)
	{
	case VK_SHADER_STAGE_VERTEX_BIT:
		ASSERT(vertexShader.get() == nullptr);
		vertexShader = spirvShader;
		break;

	case VK_SHADER_STAGE_FRAGMENT_BIT:
		ASSERT(fragmentShader.get() == nullptr);
		fragmentShader = spirvShader;
		break;

	default:
		UNSUPPORTED("Unsupported stage");
		break;
	}
}

const std::shared_ptr<sw::SpirvShader> GraphicsPipeline::getShader(const VkShaderStageFlagBits &stage) const
{
	switch(stage)
	{
	case VK_SHADER_STAGE_VERTEX_BIT:
		return vertexShader;
	case VK_SHADER_STAGE_FRAGMENT_BIT:
		return fragmentShader;
	default:
		UNSUPPORTED("Unsupported stage");
		return fragmentShader;
	}
}

VkResult GraphicsPipeline::compileShaders(const VkAllocationCallbacks *pAllocator, const VkGraphicsPipelineCreateInfo *pCreateInfo, PipelineCache *pPipelineCache)
{
	PipelineCreationFeedback pipelineCreationFeedback(pCreateInfo);
	VkGraphicsPipelineLibraryFlagsEXT pipelineSubset = GetGraphicsPipelineSubset(pCreateInfo);
	const bool expectVertexShader = (pipelineSubset & VK_GRAPHICS_PIPELINE_LIBRARY_PRE_RASTERIZATION_SHADERS_BIT_EXT) != 0;
	const bool expectFragmentShader = (pipelineSubset & VK_GRAPHICS_PIPELINE_LIBRARY_FRAGMENT_SHADER_BIT_EXT) != 0;

	for(uint32_t stageIndex = 0; stageIndex < pCreateInfo->stageCount; stageIndex++)
	{
		const VkPipelineShaderStageCreateInfo &stageInfo = pCreateInfo->pStages[stageIndex];

		// Ignore stages that don't exist in the pipeline library.
		if((stageInfo.stage == VK_SHADER_STAGE_VERTEX_BIT && !expectVertexShader) ||
		   (stageInfo.stage == VK_SHADER_STAGE_FRAGMENT_BIT && !expectFragmentShader))
		{
			continue;
		}

		pipelineCreationFeedback.stageCreationBegins(stageIndex);

		if((stageInfo.flags &
		    ~(VK_PIPELINE_SHADER_STAGE_CREATE_ALLOW_VARYING_SUBGROUP_SIZE_BIT |
		      VK_PIPELINE_SHADER_STAGE_CREATE_REQUIRE_FULL_SUBGROUPS_BIT)) != 0)
		{
			UNSUPPORTED("pStage->flags 0x%08X", int(stageInfo.flags));
		}

		const bool optimize = true;  // TODO(b/251802301): Don't optimize when debugging shaders.

		const ShaderModule *module = vk::Cast(stageInfo.module);

		// VK_EXT_graphics_pipeline_library allows VkShaderModuleCreateInfo to be chained to
		// VkPipelineShaderStageCreateInfo, which is used if stageInfo.module is
		// VK_NULL_HANDLE.
		VkShaderModule tempModule = {};
		if(stageInfo.module == VK_NULL_HANDLE)
		{
			const auto *moduleCreateInfo = vk::GetExtendedStruct<VkShaderModuleCreateInfo>(stageInfo.pNext,
			                                                                               VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO);
			ASSERT(moduleCreateInfo);
			VkResult createResult = vk::ShaderModule::Create(nullptr, moduleCreateInfo, &tempModule);
			if(createResult != VK_SUCCESS)
			{
				return createResult;
			}

			module = vk::Cast(tempModule);
		}

		const PipelineCache::SpirvBinaryKey key(module->getBinary(), stageInfo.pSpecializationInfo, robustBufferAccess, optimize);

		if((pCreateInfo->flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT_EXT) &&
		   (!pPipelineCache || !pPipelineCache->contains(key)))
		{
			pipelineCreationFeedback.pipelineCreationError();
			return VK_PIPELINE_COMPILE_REQUIRED_EXT;
		}

		sw::SpirvBinary spirv;

		if(pPipelineCache)
		{
			auto onCacheMiss = [&] { return optimizeSpirv(key); };
			auto onCacheHit = [&] { pipelineCreationFeedback.cacheHit(stageIndex); };
			spirv = pPipelineCache->getOrOptimizeSpirv(key, onCacheMiss, onCacheHit);
		}
		else
		{
			spirv = optimizeSpirv(key);

			// If the pipeline does not have specialization constants, there's a 1-to-1 mapping between the unoptimized and optimized SPIR-V,
			// so we should use a 1-to-1 mapping of the identifiers to avoid JIT routine recompiles.
			if(!key.getSpecializationInfo())
			{
				spirv.mapOptimizedIdentifier(key.getBinary());
			}
		}

		const bool stageRobustBufferAccess = getPipelineStageRobustBufferAccess(stageInfo.pNext, device, robustBufferAccess);

		// TODO(b/201798871): use allocator.
		auto shader = std::make_shared<sw::SpirvShader>(stageInfo.stage, stageInfo.pName, spirv,
		                                                vk::Cast(pCreateInfo->renderPass), pCreateInfo->subpass, stageRobustBufferAccess);

		setShader(stageInfo.stage, shader);

		pipelineCreationFeedback.stageCreationEnds(stageIndex);

		if(tempModule != VK_NULL_HANDLE)
		{
			vk::destroy(tempModule, nullptr);
		}
	}

	return VK_SUCCESS;
}

ComputePipeline::ComputePipeline(const VkComputePipelineCreateInfo *pCreateInfo, void *mem, Device *device)
    : Pipeline(vk::Cast(pCreateInfo->layout), device, getPipelineRobustBufferAccess(pCreateInfo->pNext, device))
{
}

void ComputePipeline::destroyPipeline(const VkAllocationCallbacks *pAllocator)
{
	shader.reset();
	program.reset();
}

size_t ComputePipeline::ComputeRequiredAllocationSize(const VkComputePipelineCreateInfo *pCreateInfo)
{
	return 0;
}

VkResult ComputePipeline::compileShaders(const VkAllocationCallbacks *pAllocator, const VkComputePipelineCreateInfo *pCreateInfo, PipelineCache *pPipelineCache)
{
	PipelineCreationFeedback pipelineCreationFeedback(pCreateInfo);
	pipelineCreationFeedback.stageCreationBegins(0);

	auto &stage = pCreateInfo->stage;
	const ShaderModule *module = vk::Cast(stage.module);

	// VK_EXT_graphics_pipeline_library allows VkShaderModuleCreateInfo to be chained to
	// VkPipelineShaderStageCreateInfo, which is used if stage.module is
	// VK_NULL_HANDLE.
	VkShaderModule tempModule = {};
	if(stage.module == VK_NULL_HANDLE)
	{
		const auto *moduleCreateInfo = vk::GetExtendedStruct<VkShaderModuleCreateInfo>(stage.pNext,
		                                                                               VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO);
		ASSERT(moduleCreateInfo);
		VkResult createResult = vk::ShaderModule::Create(nullptr, moduleCreateInfo, &tempModule);
		if(createResult != VK_SUCCESS)
		{
			return createResult;
		}

		module = vk::Cast(tempModule);
	}

	ASSERT(shader.get() == nullptr);
	ASSERT(program.get() == nullptr);

	const bool optimize = true;  // TODO(b/251802301): Don't optimize when debugging shaders.

	const PipelineCache::SpirvBinaryKey shaderKey(module->getBinary(), stage.pSpecializationInfo, robustBufferAccess, optimize);

	if((pCreateInfo->flags & VK_PIPELINE_CREATE_FAIL_ON_PIPELINE_COMPILE_REQUIRED_BIT_EXT) &&
	   (!pPipelineCache || !pPipelineCache->contains(shaderKey)))
	{
		pipelineCreationFeedback.pipelineCreationError();
		return VK_PIPELINE_COMPILE_REQUIRED_EXT;
	}

	sw::SpirvBinary spirv;

	if(pPipelineCache)
	{
		auto onCacheMiss = [&] { return optimizeSpirv(shaderKey); };
		auto onCacheHit = [&] { pipelineCreationFeedback.cacheHit(0); };
		spirv = pPipelineCache->getOrOptimizeSpirv(shaderKey, onCacheMiss, onCacheHit);
	}
	else
	{
		spirv = optimizeSpirv(shaderKey);

		// If the pipeline does not have specialization constants, there's a 1-to-1 mapping between the unoptimized and optimized SPIR-V,
		// so we should use a 1-to-1 mapping of the identifiers to avoid JIT routine recompiles.
		if(!shaderKey.getSpecializationInfo())
		{
			spirv.mapOptimizedIdentifier(shaderKey.getBinary());
		}
	}

	const bool stageRobustBufferAccess = getPipelineStageRobustBufferAccess(stage.pNext, device, robustBufferAccess);

	// TODO(b/201798871): use allocator.
	shader = std::make_shared<sw::SpirvShader>(stage.stage, stage.pName, spirv,
	                                           nullptr, 0, stageRobustBufferAccess);

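	// The compute program is cached by the combination of the shader and pipeline layout
	// identifiers, since the generated routine depends on both.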
	const PipelineCache::ComputeProgramKey programKey(shader->getIdentifier(), layout->identifier);

	if(pPipelineCache)
	{
		program = pPipelineCache->getOrCreateComputeProgram(programKey, [&] {
			return createProgram(device, shader, layout);
		});
	}
	else
	{
		program = createProgram(device, shader, layout);
	}

	pipelineCreationFeedback.stageCreationEnds(0);

	return VK_SUCCESS;
}

void ComputePipeline::run(uint32_t baseGroupX, uint32_t baseGroupY, uint32_t baseGroupZ,
                          uint32_t groupCountX, uint32_t groupCountY, uint32_t groupCountZ,
                          const vk::DescriptorSet::Array &descriptorSetObjects,
                          const vk::DescriptorSet::Bindings &descriptorSets,
                          const vk::DescriptorSet::DynamicOffsets &descriptorDynamicOffsets,
                          const vk::Pipeline::PushConstantStorage &pushConstants)
{
	ASSERT_OR_RETURN(program != nullptr);
	program->run(
	    descriptorSetObjects, descriptorSets, descriptorDynamicOffsets, pushConstants,
	    baseGroupX, baseGroupY, baseGroupZ,
	    groupCountX, groupCountY, groupCountZ);
}

}  // namespace vk