• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #ifndef _VKRAYTRACINGUTIL_HPP
2 #define _VKRAYTRACINGUTIL_HPP
3 /*-------------------------------------------------------------------------
4  * Vulkan CTS Framework
5  * --------------------
6  *
7  * Copyright (c) 2020 The Khronos Group Inc.
8  *
9  * Licensed under the Apache License, Version 2.0 (the "License");
10  * you may not use this file except in compliance with the License.
11  * You may obtain a copy of the License at
12  *
13  *      http://www.apache.org/licenses/LICENSE-2.0
14  *
15  * Unless required by applicable law or agreed to in writing, software
16  * distributed under the License is distributed on an "AS IS" BASIS,
17  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18  * See the License for the specific language governing permissions and
19  * limitations under the License.
20  *
21  *//*!
22  * \file
23  * \brief Vulkan ray tracing utility.
24  *//*--------------------------------------------------------------------*/
25 
26 #include "vkDefs.hpp"
27 #include "vkRef.hpp"
28 #include "vkMemUtil.hpp"
29 #include "vkBufferWithMemory.hpp"
30 
31 #include "deFloat16.h"
32 
33 #include "tcuVector.hpp"
34 #include "tcuVectorType.hpp"
35 
36 #include <vector>
37 #include <limits>
38 
39 namespace vk
40 {
41 constexpr VkShaderStageFlags	SHADER_STAGE_ALL_RAY_TRACING	= VK_SHADER_STAGE_RAYGEN_BIT_KHR
42 																| VK_SHADER_STAGE_ANY_HIT_BIT_KHR
43 																| VK_SHADER_STAGE_CLOSEST_HIT_BIT_KHR
44 																| VK_SHADER_STAGE_MISS_BIT_KHR
45 																| VK_SHADER_STAGE_INTERSECTION_BIT_KHR
46 																| VK_SHADER_STAGE_CALLABLE_BIT_KHR;
47 
48 const VkTransformMatrixKHR identityMatrix3x4 = { { { 1.0f, 0.0f, 0.0f, 0.0f }, { 0.0f, 1.0f, 0.0f, 0.0f }, { 0.0f, 0.0f, 1.0f, 0.0f } } };
49 
50 template<typename T>
makeVkSharedPtr(Move<T> move)51 inline de::SharedPtr<Move<T>> makeVkSharedPtr(Move<T> move)
52 {
53 	return de::SharedPtr<Move<T>>(new Move<T>(move));
54 }
55 
56 template<typename T>
makeVkSharedPtr(de::MovePtr<T> movePtr)57 inline de::SharedPtr<de::MovePtr<T> > makeVkSharedPtr(de::MovePtr<T> movePtr)
58 {
59 	return de::SharedPtr<de::MovePtr<T> >(new de::MovePtr<T>(movePtr));
60 }
61 
62 template<typename T>
dataOrNullPtr(const std::vector<T> & v)63 inline const T* dataOrNullPtr(const std::vector<T>& v)
64 {
65 	return (v.empty() ? DE_NULL : v.data());
66 }
67 
68 template<typename T>
dataOrNullPtr(std::vector<T> & v)69 inline T* dataOrNullPtr(std::vector<T>& v)
70 {
71 	return (v.empty() ? DE_NULL : v.data());
72 }
73 
// Returns a copy of the given ray tracing GLSL source. The text is currently returned unchanged.
inline std::string updateRayTracingGLSL (const std::string& str)
{
	std::string result(str);
	return result;
}
78 
79 std::string getCommonRayGenerationShader (void);
80 
81 // Get lowercase version of the format name with no VK_FORMAT_ prefix.
82 std::string getFormatSimpleName (vk::VkFormat format);
83 
84 // Checks the given vertex buffer format is valid for acceleration structures.
85 // Note: VK_KHR_get_physical_device_properties2 and VK_KHR_acceleration_structure are supposed to be supported.
86 void checkAccelerationStructureVertexBufferFormat (const vk::InstanceInterface &vki, vk::VkPhysicalDevice physicalDevice, vk::VkFormat format);
87 
88 class RaytracedGeometryBase
89 {
90 public:
91 								RaytracedGeometryBase			()										= delete;
92 								RaytracedGeometryBase			(const RaytracedGeometryBase& geometry)	= delete;
93 								RaytracedGeometryBase			(VkGeometryTypeKHR geometryType, VkFormat vertexFormat, VkIndexType indexType);
94 								virtual ~RaytracedGeometryBase	();
95 
getGeometryType(void) const96 	inline VkGeometryTypeKHR	getGeometryType					(void) const								{ return m_geometryType; }
isTrianglesType(void) const97 	inline bool					isTrianglesType					(void) const								{ return m_geometryType == VK_GEOMETRY_TYPE_TRIANGLES_KHR; }
getVertexFormat(void) const98 	inline VkFormat				getVertexFormat					(void) const								{ return m_vertexFormat; }
getIndexType(void) const99 	inline VkIndexType			getIndexType					(void) const								{ return m_indexType; }
usesIndices(void) const100 	inline bool					usesIndices						(void) const								{ return m_indexType != VK_INDEX_TYPE_NONE_KHR; }
getGeometryFlags(void) const101 	inline VkGeometryFlagsKHR	getGeometryFlags				(void) const								{ return m_geometryFlags; }
setGeometryFlags(const VkGeometryFlagsKHR geometryFlags)102 	inline void					setGeometryFlags				(const VkGeometryFlagsKHR geometryFlags)	{ m_geometryFlags = geometryFlags; }
103 	virtual deUint32			getVertexCount					(void) const								= 0;
104 	virtual const deUint8*		getVertexPointer				(void) const								= 0;
105 	virtual VkDeviceSize		getVertexStride					(void) const								= 0;
106 	virtual VkDeviceSize		getAABBStride					(void) const								= 0;
107 	virtual size_t				getVertexByteSize				(void) const								= 0;
108 	virtual deUint32			getIndexCount					(void) const								= 0;
109 	virtual const deUint8*		getIndexPointer					(void) const								= 0;
110 	virtual VkDeviceSize		getIndexStride					(void) const								= 0;
111 	virtual size_t				getIndexByteSize				(void) const								= 0;
112 	virtual deUint32			getPrimitiveCount				(void) const								= 0;
113 	virtual void				addVertex						(const tcu::Vec3& vertex)					= 0;
114 	virtual void				addIndex						(const deUint32& index)						= 0;
115 private:
116 	VkGeometryTypeKHR			m_geometryType;
117 	VkFormat					m_vertexFormat;
118 	VkIndexType					m_indexType;
119 	VkGeometryFlagsKHR			m_geometryFlags;
120 };
121 
122 template <typename T>
convertSatRte(float f)123 inline T convertSatRte (float f)
124 {
125 	// \note Doesn't work for 64-bit types
126 	DE_STATIC_ASSERT(sizeof(T) < sizeof(deUint64));
127 	DE_STATIC_ASSERT((-3 % 2 != 0) && (-4 % 2 == 0));
128 
129 	deInt64	minVal	= std::numeric_limits<T>::min();
130 	deInt64 maxVal	= std::numeric_limits<T>::max();
131 	float	q		= deFloatFrac(f);
132 	deInt64 intVal	= (deInt64)(f-q);
133 
134 	// Rounding.
135 	if (q == 0.5f)
136 	{
137 		if (intVal % 2 != 0)
138 			intVal++;
139 	}
140 	else if (q > 0.5f)
141 		intVal++;
142 	// else Don't add anything
143 
144 	// Saturate.
145 	intVal = de::max(minVal, de::min(maxVal, intVal));
146 
147 	return (T)intVal;
148 }
149 
150 // Converts float to signed integer with variable width.
151 // Source float is assumed to be in the [-1, 1] range.
152 template <typename T>
deFloat32ToSNorm(float src)153 inline T deFloat32ToSNorm (float src)
154 {
155 	DE_STATIC_ASSERT(std::numeric_limits<T>::is_integer && std::numeric_limits<T>::is_signed);
156 	const T range	= std::numeric_limits<T>::max();
157 	const T intVal	= convertSatRte<T>(src * static_cast<float>(range));
158 	return de::clamp<T>(intVal, -range, range);
159 }
160 
161 typedef tcu::Vector<deFloat16, 2>			Vec2_16;
162 typedef tcu::Vector<deFloat16, 3>			Vec3_16;
163 typedef tcu::Vector<deFloat16, 4>			Vec4_16;
164 typedef tcu::Vector<deInt16, 2>				Vec2_16SNorm;
165 typedef tcu::Vector<deInt16, 3>				Vec3_16SNorm;
166 typedef tcu::Vector<deInt16, 4>				Vec4_16SNorm;
167 typedef tcu::Vector<deInt8, 2>				Vec2_8SNorm;
168 typedef tcu::Vector<deInt8, 3>				Vec3_8SNorm;
169 typedef tcu::Vector<deInt8, 4>				Vec4_8SNorm;
170 
171 template<typename V>	VkFormat			vertexFormatFromType				();
vertexFormatFromType()172 template<>				inline VkFormat		vertexFormatFromType<tcu::Vec2>		()							{ return VK_FORMAT_R32G32_SFLOAT; }
vertexFormatFromType()173 template<>				inline VkFormat		vertexFormatFromType<tcu::Vec3>		()							{ return VK_FORMAT_R32G32B32_SFLOAT; }
vertexFormatFromType()174 template<>				inline VkFormat		vertexFormatFromType<tcu::Vec4>		()							{ return VK_FORMAT_R32G32B32A32_SFLOAT; }
vertexFormatFromType()175 template<>				inline VkFormat		vertexFormatFromType<Vec2_16>		()							{ return VK_FORMAT_R16G16_SFLOAT; }
vertexFormatFromType()176 template<>				inline VkFormat		vertexFormatFromType<Vec3_16>		()							{ return VK_FORMAT_R16G16B16_SFLOAT; }
vertexFormatFromType()177 template<>				inline VkFormat		vertexFormatFromType<Vec4_16>		()							{ return VK_FORMAT_R16G16B16A16_SFLOAT; }
vertexFormatFromType()178 template<>				inline VkFormat		vertexFormatFromType<Vec2_16SNorm>	()							{ return VK_FORMAT_R16G16_SNORM; }
vertexFormatFromType()179 template<>				inline VkFormat		vertexFormatFromType<Vec3_16SNorm>	()							{ return VK_FORMAT_R16G16B16_SNORM; }
vertexFormatFromType()180 template<>				inline VkFormat		vertexFormatFromType<Vec4_16SNorm>	()							{ return VK_FORMAT_R16G16B16A16_SNORM; }
vertexFormatFromType()181 template<>				inline VkFormat		vertexFormatFromType<tcu::DVec2>	()							{ return VK_FORMAT_R64G64_SFLOAT; }
vertexFormatFromType()182 template<>				inline VkFormat		vertexFormatFromType<tcu::DVec3>	()							{ return VK_FORMAT_R64G64B64_SFLOAT; }
vertexFormatFromType()183 template<>				inline VkFormat		vertexFormatFromType<tcu::DVec4>	()							{ return VK_FORMAT_R64G64B64A64_SFLOAT; }
vertexFormatFromType()184 template<>				inline VkFormat		vertexFormatFromType<Vec2_8SNorm>	()							{ return VK_FORMAT_R8G8_SNORM; }
vertexFormatFromType()185 template<>				inline VkFormat		vertexFormatFromType<Vec3_8SNorm>	()							{ return VK_FORMAT_R8G8B8_SNORM; }
vertexFormatFromType()186 template<>				inline VkFormat		vertexFormatFromType<Vec4_8SNorm>	()							{ return VK_FORMAT_R8G8B8A8_SNORM; }
187 
188 struct EmptyIndex {};
189 template<typename I>	VkIndexType			indexTypeFromType					();
indexTypeFromType()190 template<>				inline VkIndexType	indexTypeFromType<deUint16>			()							{ return VK_INDEX_TYPE_UINT16; }
indexTypeFromType()191 template<>				inline VkIndexType	indexTypeFromType<deUint32>			()							{ return VK_INDEX_TYPE_UINT32; }
indexTypeFromType()192 template<>				inline VkIndexType	indexTypeFromType<EmptyIndex>		()							{ return VK_INDEX_TYPE_NONE_KHR; }
193 
194 template<typename V>	V					convertFloatTo						(const tcu::Vec3& vertex);
convertFloatTo(const tcu::Vec3 & vertex)195 template<>				inline tcu::Vec2	convertFloatTo<tcu::Vec2>			(const tcu::Vec3& vertex)	{ return tcu::Vec2(vertex.x(), vertex.y()); }
convertFloatTo(const tcu::Vec3 & vertex)196 template<>				inline tcu::Vec3	convertFloatTo<tcu::Vec3>			(const tcu::Vec3& vertex)	{ return vertex; }
convertFloatTo(const tcu::Vec3 & vertex)197 template<>				inline tcu::Vec4	convertFloatTo<tcu::Vec4>			(const tcu::Vec3& vertex)	{ return tcu::Vec4(vertex.x(), vertex.y(), vertex.z(), 0.0f); }
convertFloatTo(const tcu::Vec3 & vertex)198 template<>				inline Vec2_16		convertFloatTo<Vec2_16>				(const tcu::Vec3& vertex)	{ return Vec2_16(deFloat32To16(vertex.x()), deFloat32To16(vertex.y())); }
convertFloatTo(const tcu::Vec3 & vertex)199 template<>				inline Vec3_16		convertFloatTo<Vec3_16>				(const tcu::Vec3& vertex)	{ return Vec3_16(deFloat32To16(vertex.x()), deFloat32To16(vertex.y()), deFloat32To16(vertex.z())); }
convertFloatTo(const tcu::Vec3 & vertex)200 template<>				inline Vec4_16		convertFloatTo<Vec4_16>				(const tcu::Vec3& vertex)	{ return Vec4_16(deFloat32To16(vertex.x()), deFloat32To16(vertex.y()), deFloat32To16(vertex.z()), deFloat32To16(0.0f)); }
convertFloatTo(const tcu::Vec3 & vertex)201 template<>				inline Vec2_16SNorm	convertFloatTo<Vec2_16SNorm>		(const tcu::Vec3& vertex)	{ return Vec2_16SNorm(deFloat32ToSNorm<deInt16>(vertex.x()), deFloat32ToSNorm<deInt16>(vertex.y())); }
convertFloatTo(const tcu::Vec3 & vertex)202 template<>				inline Vec3_16SNorm	convertFloatTo<Vec3_16SNorm>		(const tcu::Vec3& vertex)	{ return Vec3_16SNorm(deFloat32ToSNorm<deInt16>(vertex.x()), deFloat32ToSNorm<deInt16>(vertex.y()), deFloat32ToSNorm<deInt16>(vertex.z())); }
convertFloatTo(const tcu::Vec3 & vertex)203 template<>				inline Vec4_16SNorm	convertFloatTo<Vec4_16SNorm>		(const tcu::Vec3& vertex)	{ return Vec4_16SNorm(deFloat32ToSNorm<deInt16>(vertex.x()), deFloat32ToSNorm<deInt16>(vertex.y()), deFloat32ToSNorm<deInt16>(vertex.z()), deFloat32ToSNorm<deInt16>(0.0f)); }
convertFloatTo(const tcu::Vec3 & vertex)204 template<>				inline tcu::DVec2	convertFloatTo<tcu::DVec2>			(const tcu::Vec3& vertex)	{ return tcu::DVec2(static_cast<double>(vertex.x()), static_cast<double>(vertex.y())); }
convertFloatTo(const tcu::Vec3 & vertex)205 template<>				inline tcu::DVec3	convertFloatTo<tcu::DVec3>			(const tcu::Vec3& vertex)	{ return tcu::DVec3(static_cast<double>(vertex.x()), static_cast<double>(vertex.y()), static_cast<double>(vertex.z())); }
convertFloatTo(const tcu::Vec3 & vertex)206 template<>				inline tcu::DVec4	convertFloatTo<tcu::DVec4>			(const tcu::Vec3& vertex)	{ return tcu::DVec4(static_cast<double>(vertex.x()), static_cast<double>(vertex.y()), static_cast<double>(vertex.z()), 0.0); }
convertFloatTo(const tcu::Vec3 & vertex)207 template<>				inline Vec2_8SNorm	convertFloatTo<Vec2_8SNorm>			(const tcu::Vec3& vertex)	{ return Vec2_8SNorm(deFloat32ToSNorm<deInt8>(vertex.x()), deFloat32ToSNorm<deInt8>(vertex.y())); }
convertFloatTo(const tcu::Vec3 & vertex)208 template<>				inline Vec3_8SNorm	convertFloatTo<Vec3_8SNorm>			(const tcu::Vec3& vertex)	{ return Vec3_8SNorm(deFloat32ToSNorm<deInt8>(vertex.x()), deFloat32ToSNorm<deInt8>(vertex.y()), deFloat32ToSNorm<deInt8>(vertex.z())); }
convertFloatTo(const tcu::Vec3 & vertex)209 template<>				inline Vec4_8SNorm	convertFloatTo<Vec4_8SNorm>			(const tcu::Vec3& vertex)	{ return Vec4_8SNorm(deFloat32ToSNorm<deInt8>(vertex.x()), deFloat32ToSNorm<deInt8>(vertex.y()), deFloat32ToSNorm<deInt8>(vertex.z()), deFloat32ToSNorm<deInt8>(0.0f)); }
210 
211 template<typename V>	V					convertIndexTo						(deUint32 index);
convertIndexTo(deUint32 index)212 template<>				inline EmptyIndex	convertIndexTo<EmptyIndex>			(deUint32 index)			{ DE_UNREF(index); TCU_THROW(TestError, "Cannot add empty index"); }
convertIndexTo(deUint32 index)213 template<>				inline deUint16		convertIndexTo<deUint16>			(deUint32 index)			{ return static_cast<deUint16>(index); }
convertIndexTo(deUint32 index)214 template<>				inline deUint32		convertIndexTo<deUint32>			(deUint32 index)			{ return index; }
215 
216 template<typename V, typename I>
217 class RaytracedGeometry : public RaytracedGeometryBase
218 {
219 public:
220 						RaytracedGeometry			()									= delete;
221 						RaytracedGeometry			(const RaytracedGeometry& geometry)	= delete;
222 						RaytracedGeometry			(VkGeometryTypeKHR geometryType, deUint32 paddingBlocks = 0u);
223 						RaytracedGeometry			(VkGeometryTypeKHR geometryType, const std::vector<V>& vertices, const std::vector<I>& indices = std::vector<I>(), deUint32 paddingBlocks = 0u);
224 
225 	deUint32			getVertexCount				(void) const override;
226 	const deUint8*		getVertexPointer			(void) const override;
227 	VkDeviceSize		getVertexStride				(void) const override;
228 	VkDeviceSize		getAABBStride				(void) const override;
229 	size_t				getVertexByteSize			(void) const override;
230 	deUint32			getIndexCount				(void) const override;
231 	const deUint8*		getIndexPointer				(void) const override;
232 	VkDeviceSize		getIndexStride				(void) const override;
233 	size_t				getIndexByteSize			(void) const override;
234 	deUint32			getPrimitiveCount			(void) const override;
235 
236 	void				addVertex					(const tcu::Vec3& vertex) override;
237 	void				addIndex					(const deUint32& index) override;
238 
239 private:
240 	void				init						();					// To be run in constructors.
241 	void				checkGeometryType			() const;			// Checks geometry type is valid.
242 	void				calcBlockSize				();					// Calculates and saves vertex buffer block size.
243 	size_t				getBlockSize				() const;			// Return stored vertex buffer block size.
244 	void				addNativeVertex				(const V& vertex);	// Adds new vertex in native format.
245 
246 	// The implementation below stores vertices as byte blocks to take the requested padding into account. m_vertices is the array
247 	// of bytes containing vertex data.
248 	//
249 	// For triangles, the padding block has a size that is a multiple of the vertex size and each vertex is stored in a byte block
250 	// equivalent to:
251 	//
252 	//	struct Vertex
253 	//	{
254 	//		V		vertex;
255 	//		deUint8	padding[m_paddingBlocks * sizeof(V)];
256 	//	};
257 	//
258 	// For AABBs, the padding block has a size that is a multiple of kAABBPadBaseSize (see below) and vertices are stored in pairs
259 	// before the padding block. This is equivalent to:
260 	//
261 	//		struct VertexPair
262 	//		{
263 	//			V		vertices[2];
264 	//			deUint8	padding[m_paddingBlocks * kAABBPadBaseSize];
265 	//		};
266 	//
267 	// The size of each pseudo-structure above is saved to one of the correspoding union members below.
268 	union BlockSize
269 	{
270 		size_t trianglesBlockSize;
271 		size_t aabbsBlockSize;
272 	};
273 
274 	const deUint32			m_paddingBlocks;
275 	size_t					m_vertexCount;
276 	std::vector<deUint8>	m_vertices;			// Vertices are stored as byte blocks.
277 	std::vector<I>			m_indices;			// Indices are stored natively.
278 	BlockSize				m_blockSize;		// For m_vertices.
279 
280 	// Data sizes.
281 	static constexpr size_t	kVertexSize			= sizeof(V);
282 	static constexpr size_t	kIndexSize			= sizeof(I);
283 	static constexpr size_t	kAABBPadBaseSize	= 8; // As required by the spec.
284 };
285 
286 template<typename V, typename I>
RaytracedGeometry(VkGeometryTypeKHR geometryType,deUint32 paddingBlocks)287 RaytracedGeometry<V, I>::RaytracedGeometry (VkGeometryTypeKHR geometryType, deUint32 paddingBlocks)
288 	: RaytracedGeometryBase(geometryType, vertexFormatFromType<V>(), indexTypeFromType<I>())
289 	, m_paddingBlocks(paddingBlocks)
290 	, m_vertexCount(0)
291 {
292 	init();
293 }
294 
295 template<typename V, typename I>
RaytracedGeometry(VkGeometryTypeKHR geometryType,const std::vector<V> & vertices,const std::vector<I> & indices,deUint32 paddingBlocks)296 RaytracedGeometry<V,I>::RaytracedGeometry (VkGeometryTypeKHR geometryType, const std::vector<V>& vertices, const std::vector<I>& indices, deUint32 paddingBlocks)
297 	: RaytracedGeometryBase(geometryType, vertexFormatFromType<V>(), indexTypeFromType<I>())
298 	, m_paddingBlocks(paddingBlocks)
299 	, m_vertexCount(0)
300 	, m_vertices()
301 	, m_indices(indices)
302 {
303 	init();
304 	for (const auto& vertex : vertices)
305 		addNativeVertex(vertex);
306 }
307 
308 template<typename V, typename I>
getVertexCount(void) const309 deUint32 RaytracedGeometry<V,I>::getVertexCount (void) const
310 {
311 	return (isTrianglesType() ? static_cast<deUint32>(m_vertexCount) : 0u);
312 }
313 
314 template<typename V, typename I>
getVertexPointer(void) const315 const deUint8* RaytracedGeometry<V, I>::getVertexPointer (void) const
316 {
317 	DE_ASSERT(!m_vertices.empty());
318 	return reinterpret_cast<const deUint8*>(m_vertices.data());
319 }
320 
321 template<typename V, typename I>
getVertexStride(void) const322 VkDeviceSize RaytracedGeometry<V,I>::getVertexStride (void) const
323 {
324 	return ((!isTrianglesType()) ? 0ull : static_cast<VkDeviceSize>(getBlockSize()));
325 }
326 
327 template<typename V, typename I>
getAABBStride(void) const328 VkDeviceSize RaytracedGeometry<V, I>::getAABBStride (void) const
329 {
330 	return (isTrianglesType() ? 0ull : static_cast<VkDeviceSize>(getBlockSize()));
331 }
332 
333 template<typename V, typename I>
getVertexByteSize(void) const334 size_t RaytracedGeometry<V, I>::getVertexByteSize (void) const
335 {
336 	return m_vertices.size();
337 }
338 
339 template<typename V, typename I>
getIndexCount(void) const340 deUint32 RaytracedGeometry<V, I>::getIndexCount (void) const
341 {
342 	return static_cast<deUint32>(isTrianglesType() ? m_indices.size() : 0);
343 }
344 
345 template<typename V, typename I>
getIndexPointer(void) const346 const deUint8* RaytracedGeometry<V, I>::getIndexPointer (void) const
347 {
348 	const auto indexCount = getIndexCount();
349 	DE_UNREF(indexCount); // For release builds.
350 	DE_ASSERT(indexCount > 0u);
351 
352 	return reinterpret_cast<const deUint8*>(m_indices.data());
353 }
354 
355 template<typename V, typename I>
getIndexStride(void) const356 VkDeviceSize RaytracedGeometry<V, I>::getIndexStride (void) const
357 {
358 	return static_cast<VkDeviceSize>(kIndexSize);
359 }
360 
361 template<typename V, typename I>
getIndexByteSize(void) const362 size_t RaytracedGeometry<V, I>::getIndexByteSize (void) const
363 {
364 	const auto indexCount = getIndexCount();
365 	DE_ASSERT(indexCount > 0u);
366 
367 	return (indexCount * kIndexSize);
368 }
369 
370 template<typename V, typename I>
getPrimitiveCount(void) const371 deUint32 RaytracedGeometry<V,I>::getPrimitiveCount (void) const
372 {
373 	return static_cast<deUint32>(isTrianglesType() ? (usesIndices() ? m_indices.size() / 3 : m_vertexCount / 3) : (m_vertexCount / 2));
374 }
375 
376 template<typename V, typename I>
addVertex(const tcu::Vec3 & vertex)377 void RaytracedGeometry<V, I>::addVertex (const tcu::Vec3& vertex)
378 {
379 	addNativeVertex(convertFloatTo<V>(vertex));
380 }
381 
382 template<typename V, typename I>
addNativeVertex(const V & vertex)383 void RaytracedGeometry<V, I>::addNativeVertex (const V& vertex)
384 {
385 	const auto oldSize			= m_vertices.size();
386 	const auto blockSize		= getBlockSize();
387 
388 	if (isTrianglesType())
389 	{
390 		// Reserve new block, copy vertex at the beginning of the new block.
391 		m_vertices.resize(oldSize + blockSize, deUint8{0});
392 		deMemcpy(&m_vertices[oldSize], &vertex, kVertexSize);
393 	}
394 	else // AABB
395 	{
396 		if (m_vertexCount % 2 == 0)
397 		{
398 			// New block needed.
399 			m_vertices.resize(oldSize + blockSize, deUint8{0});
400 			deMemcpy(&m_vertices[oldSize], &vertex, kVertexSize);
401 		}
402 		else
403 		{
404 			// Insert in the second position of last existing block.
405 			//
406 			//												Vertex Size
407 			//												+-------+
408 			//	+-------------+------------+----------------------------------------+
409 			//	|             |            |      ...       | vertex vertex padding |
410 			//	+-------------+------------+----------------+-----------------------+
411 			//												+-----------------------+
412 			//														Block Size
413 			//	+-------------------------------------------------------------------+
414 			//							Old Size
415 			//
416 			deMemcpy(&m_vertices[oldSize - blockSize + kVertexSize], &vertex, kVertexSize);
417 		}
418 	}
419 
420 	++m_vertexCount;
421 }
422 
423 template<typename V, typename I>
addIndex(const deUint32 & index)424 void RaytracedGeometry<V, I>::addIndex (const deUint32& index)
425 {
426 	m_indices.push_back(convertIndexTo<I>(index));
427 }
428 
429 template<typename V, typename I>
init()430 void RaytracedGeometry<V, I>::init ()
431 {
432 	checkGeometryType();
433 	calcBlockSize();
434 }
435 
436 template<typename V, typename I>
checkGeometryType() const437 void RaytracedGeometry<V, I>::checkGeometryType () const
438 {
439 	const auto geometryType = getGeometryType();
440 	DE_UNREF(geometryType); // For release builds.
441 	DE_ASSERT(geometryType == VK_GEOMETRY_TYPE_TRIANGLES_KHR || geometryType == VK_GEOMETRY_TYPE_AABBS_KHR);
442 }
443 
444 template<typename V, typename I>
calcBlockSize()445 void RaytracedGeometry<V, I>::calcBlockSize ()
446 {
447 	if (isTrianglesType())
448 		m_blockSize.trianglesBlockSize = kVertexSize * static_cast<size_t>(1u + m_paddingBlocks);
449 	else
450 		m_blockSize.aabbsBlockSize = 2 * kVertexSize + m_paddingBlocks * kAABBPadBaseSize;
451 }
452 
453 template<typename V, typename I>
getBlockSize() const454 size_t RaytracedGeometry<V, I>::getBlockSize () const
455 {
456 	return (isTrianglesType() ? m_blockSize.trianglesBlockSize : m_blockSize.aabbsBlockSize);
457 }
458 
459 de::SharedPtr<RaytracedGeometryBase> makeRaytracedGeometry (VkGeometryTypeKHR geometryType, VkFormat vertexFormat, VkIndexType indexType, bool padVertices = false);
460 
461 VkDeviceAddress getBufferDeviceAddress ( const DeviceInterface&	vkd,
462 										 const VkDevice			device,
463 										 const VkBuffer			buffer,
464 										 VkDeviceSize			offset );
465 
466 // type used for creating a deep serialization/deserialization of top-level acceleration structures
467 class SerialInfo
468 {
469 	std::vector<deUint64>		m_addresses;
470 	std::vector<VkDeviceSize>	m_sizes;
471 public:
472 
473 	SerialInfo() = default;
474 
475 	// addresses: { (owner-top-level AS address) [, (first bottom_level AS address), (second bottom_level AS address), ...] }
476 	// sizes:     { (owner-top-level AS serial size) [, (first bottom_level AS serial size), (second bottom_level AS serial size), ...] }
SerialInfo(const std::vector<deUint64> & addresses,const std::vector<VkDeviceSize> & sizes)477 	SerialInfo(const std::vector<deUint64>& addresses, const std::vector<VkDeviceSize>& sizes)
478 		: m_addresses(addresses), m_sizes(sizes)
479 	{
480 		DE_ASSERT(!addresses.empty() && addresses.size() == sizes.size());
481 	}
482 
addresses() const483 	const std::vector<deUint64>&		addresses			() const	{ return m_addresses; }
sizes() const484 	const std::vector<VkDeviceSize>&	sizes				() const	{ return m_sizes; }
485 };
486 
487 class SerialStorage
488 {
489 public:
490 	enum
491 	{
492 		DE_SERIALIZED_FIELD(DRIVER_UUID,		VK_UUID_SIZE),		// VK_UUID_SIZE bytes of data matching VkPhysicalDeviceIDProperties::driverUUID
493 		DE_SERIALIZED_FIELD(COMPAT_UUID,		VK_UUID_SIZE),		// VK_UUID_SIZE bytes of data identifying the compatibility for comparison using vkGetDeviceAccelerationStructureCompatibilityKHR
494 		DE_SERIALIZED_FIELD(SERIALIZED_SIZE,	sizeof(deUint64)),	// A 64-bit integer of the total size matching the value queried using VK_QUERY_TYPE_ACCELERATION_STRUCTURE_SERIALIZATION_SIZE_KHR
495 		DE_SERIALIZED_FIELD(DESERIALIZED_SIZE,	sizeof(deUint64)),	// A 64-bit integer of the deserialized size to be passed in to VkAccelerationStructureCreateInfoKHR::size
496 		DE_SERIALIZED_FIELD(HANDLES_COUNT,		sizeof(deUint64)),	// A 64-bit integer of the count of the number of acceleration structure handles following. This will be zero for a bottom-level acceleration structure.
497 		SERIAL_STORAGE_SIZE_MIN
498 	};
499 
500 	// An old fashion C-style structure that simplifies an access to the AS header
501 	struct alignas(16) AccelerationStructureHeader
502 	{
503 		union {
504 			struct {
505 				deUint8	driverUUID[VK_UUID_SIZE];
506 				deUint8	compactUUID[VK_UUID_SIZE];
507 			};
508 			deUint8		uuids[VK_UUID_SIZE * 2];
509 		};
510 		deUint64		serializedSize;
511 		deUint64		deserializedSize;
512 		deUint64		handleCount;
513 		VkDeviceAddress	handleArray[1];
514 	};
515 
516 											SerialStorage		() = delete;
517 											SerialStorage		(const DeviceInterface&						vk,
518 																 const VkDevice								device,
519 																 Allocator&									allocator,
520 																 const VkAccelerationStructureBuildTypeKHR	buildType,
521 																 const VkDeviceSize							storageSize);
522 	// An additional constructor for creating a deep copy of top-level AS's.
523 											SerialStorage		(const DeviceInterface&						vk,
524 																 const VkDevice								device,
525 																 Allocator&									allocator,
526 																 const VkAccelerationStructureBuildTypeKHR	buildType,
527 																 const SerialInfo&							SerialInfo);
528 
529 	// below methods will return host addres if AS was build on cpu and device addres when it was build on gpu
530 	VkDeviceOrHostAddressKHR				getAddress			(const DeviceInterface&						vk,
531 																 const VkDevice								device,
532 																 const VkAccelerationStructureBuildTypeKHR	buildType);
533 	VkDeviceOrHostAddressConstKHR			getAddressConst		(const DeviceInterface&						vk,
534 																 const VkDevice								device,
535 																 const VkAccelerationStructureBuildTypeKHR	buildType);
536 
537 	// this methods retun host address regardless of where AS was built
538 	VkDeviceOrHostAddressKHR				getHostAddress		(VkDeviceSize			offset = 0);
539 	VkDeviceOrHostAddressConstKHR			getHostAddressConst	(VkDeviceSize			offset = 0);
540 
541 	// works the similar way as getHostAddressConst() but returns more readable/intuitive object
542 	AccelerationStructureHeader*			getASHeader			();
543 	bool									hasDeepFormat		() const;
544 	de::SharedPtr<SerialStorage>			getBottomStorage	(deUint32			index) const;
545 
546 	VkDeviceSize							getStorageSize		() const;
547 	const SerialInfo&						getSerialInfo		() const;
548 	deUint64								getDeserializedSize	();
549 
550 protected:
551 	const VkAccelerationStructureBuildTypeKHR	m_buildType;
552 	const VkDeviceSize							m_storageSize;
553 	const SerialInfo							m_serialInfo;
554 	de::MovePtr<BufferWithMemory>				m_buffer;
555 	std::vector<de::SharedPtr<SerialStorage>>	m_bottoms;
556 };
557 
558 class BottomLevelAccelerationStructure
559 {
560 public:
561 	static deUint32										getRequiredAllocationCount				(void);
562 
563 														BottomLevelAccelerationStructure		();
564 														BottomLevelAccelerationStructure		(const BottomLevelAccelerationStructure&		other) = delete;
565 	virtual												~BottomLevelAccelerationStructure		();
566 
567 	virtual void										setGeometryData							(const std::vector<tcu::Vec3>&					geometryData,
568 																								 const bool										triangles,
569 																								 const VkGeometryFlagsKHR						geometryFlags			= 0u );
570 	virtual void										setDefaultGeometryData					(const VkShaderStageFlagBits					testStage,
571 																								 const VkGeometryFlagsKHR						geometryFlags			= 0u );
572 	virtual void										setGeometryCount						(const size_t									geometryCount);
573 	virtual void										addGeometry								(de::SharedPtr<RaytracedGeometryBase>&			raytracedGeometry);
574 	virtual void										addGeometry								(const std::vector<tcu::Vec3>&					geometryData,
575 																								 const bool										triangles,
576 																								 const VkGeometryFlagsKHR						geometryFlags			= 0u );
577 
578 	virtual void										setBuildType							(const VkAccelerationStructureBuildTypeKHR		buildType) = DE_NULL;
579 	virtual void										setCreateFlags							(const VkAccelerationStructureCreateFlagsKHR	createFlags) = DE_NULL;
580 	virtual void										setCreateGeneric						(bool											createGeneric) = 0;
581 	virtual void										setBuildFlags							(const VkBuildAccelerationStructureFlagsKHR		buildFlags) = DE_NULL;
582 	virtual void										setBuildWithoutGeometries				(bool											buildWithoutGeometries) = 0;
583 	virtual void										setBuildWithoutPrimitives				(bool											buildWithoutPrimitives) = 0;
584 	virtual void										setDeferredOperation					(const bool										deferredOperation,
585 																								 const deUint32									workerThreadCount		= 0u ) = DE_NULL;
586 	virtual void										setUseArrayOfPointers					(const bool										useArrayOfPointers) = DE_NULL;
587 	virtual void										setIndirectBuildParameters				(const VkBuffer									indirectBuffer,
588 																								 const VkDeviceSize								indirectBufferOffset,
589 																								 const deUint32									indirectBufferStride) = DE_NULL;
590 	virtual VkBuildAccelerationStructureFlagsKHR		getBuildFlags							() const = DE_NULL;
591 	VkDeviceSize										getStructureSize						() const;
592 
593 	// methods specific for each acceleration structure
594 	virtual void										create									(const DeviceInterface&							vk,
595 																								 const VkDevice									device,
596 																								 Allocator&										allocator,
597 																								 VkDeviceSize									structureSize,
598 																								 VkDeviceAddress								deviceAddress			= 0u) = DE_NULL;
599 	virtual void										build									(const DeviceInterface&							vk,
600 																								 const VkDevice									device,
601 																								 const VkCommandBuffer							cmdBuffer) = DE_NULL;
602 	virtual void										copyFrom								(const DeviceInterface&							vk,
603 																								 const VkDevice									device,
604 																								 const VkCommandBuffer							cmdBuffer,
605 																								 BottomLevelAccelerationStructure*				accelerationStructure,
606 																								 bool											compactCopy) = DE_NULL;
607 
608 	virtual void										serialize								(const DeviceInterface&							vk,
609 																								 const VkDevice									device,
610 																								 const VkCommandBuffer							cmdBuffer,
611 																								 SerialStorage*									storage) = DE_NULL;
612 	virtual void										deserialize								(const DeviceInterface&							vk,
613 																								 const VkDevice									device,
614 																								 const VkCommandBuffer							cmdBuffer,
615 																								 SerialStorage*									storage) = DE_NULL;
616 
617 	// helper methods for typical acceleration structure creation tasks
618 	void												createAndBuild							(const DeviceInterface&							vk,
619 																								 const VkDevice									device,
620 																								 const VkCommandBuffer							cmdBuffer,
621 																								 Allocator&										allocator,
622 																								 VkDeviceAddress								deviceAddress			= 0u );
623 	void												createAndCopyFrom						(const DeviceInterface&							vk,
624 																								 const VkDevice									device,
625 																								 const VkCommandBuffer							cmdBuffer,
626 																								 Allocator&										allocator,
627 																								 BottomLevelAccelerationStructure*				accelerationStructure,
628 																								 VkDeviceSize									compactCopySize			= 0u,
629 																								 VkDeviceAddress								deviceAddress			= 0u);
630 	void												createAndDeserializeFrom				(const DeviceInterface&							vk,
631 																								 const VkDevice									device,
632 																								 const VkCommandBuffer							cmdBuffer,
633 																								 Allocator&										allocator,
634 																								 SerialStorage*									storage,
635 																								 VkDeviceAddress								deviceAddress			= 0u);
636 
637 	virtual const VkAccelerationStructureKHR*			getPtr									(void) const = DE_NULL;
638 protected:
639 	std::vector<de::SharedPtr<RaytracedGeometryBase>>	m_geometriesData;
640 	VkDeviceSize										m_structureSize;
641 	VkDeviceSize										m_updateScratchSize;
642 	VkDeviceSize										m_buildScratchSize;
643 };
644 
645 de::MovePtr<BottomLevelAccelerationStructure> makeBottomLevelAccelerationStructure ();
646 
647 struct InstanceData
648 {
InstanceDatavk::InstanceData649 								InstanceData (VkTransformMatrixKHR							matrix_,
650 											  deUint32										instanceCustomIndex_,
651 											  deUint32										mask_,
652 											  deUint32										instanceShaderBindingTableRecordOffset_,
653 											  VkGeometryInstanceFlagsKHR					flags_)
654 									: matrix(matrix_), instanceCustomIndex(instanceCustomIndex_), mask(mask_), instanceShaderBindingTableRecordOffset(instanceShaderBindingTableRecordOffset_), flags(flags_)
655 								{
656 								}
657 	VkTransformMatrixKHR		matrix;
658 	deUint32					instanceCustomIndex;
659 	deUint32					mask;
660 	deUint32					instanceShaderBindingTableRecordOffset;
661 	VkGeometryInstanceFlagsKHR	flags;
662 };
663 
664 class TopLevelAccelerationStructure
665 {
666 public:
667 	static deUint32													getRequiredAllocationCount			(void);
668 
669 																	TopLevelAccelerationStructure		();
670 																	TopLevelAccelerationStructure		(const TopLevelAccelerationStructure&				other) = delete;
671 	virtual															~TopLevelAccelerationStructure		();
672 
673 	virtual void													setInstanceCount					(const size_t										instanceCount);
674 	virtual void													addInstance							(de::SharedPtr<BottomLevelAccelerationStructure>	bottomLevelStructure,
675 																										 const VkTransformMatrixKHR&						matrix									= identityMatrix3x4,
676 																										 deUint32											instanceCustomIndex						= 0,
677 																										 deUint32											mask									= 0xFF,
678 																										 deUint32											instanceShaderBindingTableRecordOffset	= 0,
679 																										 VkGeometryInstanceFlagsKHR							flags									= VkGeometryInstanceFlagBitsKHR(0u)	);
680 
681 	virtual void													setBuildType						(const VkAccelerationStructureBuildTypeKHR			buildType) = DE_NULL;
682 	virtual void													setCreateFlags						(const VkAccelerationStructureCreateFlagsKHR		createFlags) = DE_NULL;
683 	virtual void													setCreateGeneric					(bool												createGeneric) = 0;
684 	virtual void													setBuildFlags						(const VkBuildAccelerationStructureFlagsKHR			buildFlags) = DE_NULL;
685 	virtual void													setBuildWithoutPrimitives			(bool												buildWithoutPrimitives) = 0;
686 	virtual void													setInactiveInstances				(bool												inactiveInstances) = 0;
687 	virtual void													setDeferredOperation				(const bool											deferredOperation,
688 																										 const deUint32										workerThreadCount = 0u) = DE_NULL;
689 	virtual void													setUseArrayOfPointers				(const bool											useArrayOfPointers) = DE_NULL;
690 	virtual void													setIndirectBuildParameters			(const VkBuffer										indirectBuffer,
691 																										 const VkDeviceSize									indirectBufferOffset,
692 																										 const deUint32										indirectBufferStride) = DE_NULL;
693 	virtual void													setUsePPGeometries					(const bool											usePPGeometries) = 0;
694 	virtual VkBuildAccelerationStructureFlagsKHR					getBuildFlags						() const = DE_NULL;
695 	VkDeviceSize													getStructureSize					() const;
696 
697 	// methods specific for each acceleration structure
698 	virtual void													create								(const DeviceInterface&						vk,
699 																										 const VkDevice								device,
700 																										 Allocator&									allocator,
701 																										 VkDeviceSize								structureSize			= 0u,
702 																										 VkDeviceAddress							deviceAddress			= 0u ) = DE_NULL;
703 	virtual void													build								(const DeviceInterface&						vk,
704 																										 const VkDevice								device,
705 																										 const VkCommandBuffer						cmdBuffer) = DE_NULL;
706 	virtual void													copyFrom							(const DeviceInterface&						vk,
707 																										 const VkDevice								device,
708 																										 const VkCommandBuffer						cmdBuffer,
709 																										 TopLevelAccelerationStructure*				accelerationStructure,
710 																										 bool										compactCopy) = DE_NULL;
711 
712 	virtual void													serialize							(const DeviceInterface&						vk,
713 																										 const VkDevice								device,
714 																										 const VkCommandBuffer						cmdBuffer,
715 																										 SerialStorage*								storage) = DE_NULL;
716 	virtual void													deserialize							(const DeviceInterface&						vk,
717 																										 const VkDevice								device,
718 																										 const VkCommandBuffer						cmdBuffer,
719 																										 SerialStorage*								storage) = DE_NULL;
720 
721 	virtual std::vector<VkDeviceSize>								getSerializingSizes					(const DeviceInterface&						vk,
722 																										 const VkDevice								device,
723 																										 const VkQueue								queue,
724 																										 const deUint32								queueFamilyIndex) = DE_NULL;
725 
726 	virtual std::vector<deUint64>									getSerializingAddresses				(const DeviceInterface&						vk,
727 																										 const VkDevice								device) const = DE_NULL;
728 
729 	// helper methods for typical acceleration structure creation tasks
730 	void															createAndBuild						(const DeviceInterface&						vk,
731 																										 const VkDevice								device,
732 																										 const VkCommandBuffer						cmdBuffer,
733 																										 Allocator&									allocator,
734 																										 VkDeviceAddress							deviceAddress			= 0u );
735 	void															createAndCopyFrom					(const DeviceInterface&						vk,
736 																										 const VkDevice								device,
737 																										 const VkCommandBuffer						cmdBuffer,
738 																										 Allocator&									allocator,
739 																										 TopLevelAccelerationStructure*				accelerationStructure,
740 																										 VkDeviceSize								compactCopySize			= 0u,
741 																										 VkDeviceAddress							deviceAddress			= 0u);
742 	void															createAndDeserializeFrom			(const DeviceInterface&						vk,
743 																										 const VkDevice								device,
744 																										 const VkCommandBuffer						cmdBuffer,
745 																										 Allocator&									allocator,
746 																										 SerialStorage*								storage,
747 																										 VkDeviceAddress							deviceAddress			= 0u);
748 
749 	virtual const VkAccelerationStructureKHR*						getPtr								(void) const = DE_NULL;
750 
751 	virtual void													updateInstanceMatrix				(const DeviceInterface&						vk,
752 																										 const VkDevice								device,
753 																										 size_t										instanceIndex,
754 																										 const VkTransformMatrixKHR&				matrix) = 0;
755 
756 protected:
757 	std::vector<de::SharedPtr<BottomLevelAccelerationStructure> >	m_bottomLevelInstances;
758 	std::vector<InstanceData>										m_instanceData;
759 	VkDeviceSize													m_structureSize;
760 	VkDeviceSize													m_updateScratchSize;
761 	VkDeviceSize													m_buildScratchSize;
762 
763 	virtual void													createAndDeserializeBottoms			(const DeviceInterface&						vk,
764 																										 const VkDevice								device,
765 																										 const VkCommandBuffer						cmdBuffer,
766 																										 Allocator&									allocator,
767 																										 SerialStorage*								storage) = DE_NULL;
768 };
769 
770 de::MovePtr<TopLevelAccelerationStructure> makeTopLevelAccelerationStructure ();
771 
772 template<class ASType> de::MovePtr<ASType> makeAccelerationStructure ();
makeAccelerationStructure()773 template<> inline de::MovePtr<BottomLevelAccelerationStructure>	makeAccelerationStructure () { return makeBottomLevelAccelerationStructure(); }
makeAccelerationStructure()774 template<> inline de::MovePtr<TopLevelAccelerationStructure>	makeAccelerationStructure () { return makeTopLevelAccelerationStructure(); }
775 
776 bool queryAccelerationStructureSize (const DeviceInterface&							vk,
777 									 const VkDevice									device,
778 									 const VkCommandBuffer							cmdBuffer,
779 									 const std::vector<VkAccelerationStructureKHR>&	accelerationStructureHandles,
780 									 VkAccelerationStructureBuildTypeKHR			buildType,
781 									 const VkQueryPool								queryPool,
782 									 VkQueryType									queryType,
783 									 deUint32										firstQuery,
784 									 std::vector<VkDeviceSize>&						results);
785 
786 class RayTracingPipeline
787 {
788 public:
789 																RayTracingPipeline			();
790 																~RayTracingPipeline			();
791 
792 	void														addShader					(VkShaderStageFlagBits									shaderStage,
793 																							 Move<VkShaderModule>									shaderModule,
794 																							 deUint32												group,
795 																							 const VkSpecializationInfo*							specializationInfo = nullptr,
796 																							 const VkPipelineShaderStageCreateFlags					pipelineShaderStageCreateFlags = static_cast<VkPipelineShaderStageCreateFlags>(0),
797 																							 const void*											pipelineShaderStageCreateInfopNext = nullptr);
798 	void														addShader					(VkShaderStageFlagBits									shaderStage,
799 																							 de::SharedPtr<Move<VkShaderModule>>					shaderModule,
800 																							 deUint32												group,
801 																							 const VkSpecializationInfo*							specializationInfoPtr = nullptr,
802 																							 const VkPipelineShaderStageCreateFlags					pipelineShaderStageCreateFlags = static_cast<VkPipelineShaderStageCreateFlags>(0),
803 																							 const void*											pipelineShaderStageCreateInfopNext = nullptr);
804 	void														addShader					(VkShaderStageFlagBits									shaderStage,
805 																							 VkShaderModule									        shaderModule,
806 																							 deUint32												group,
807 																							 const VkSpecializationInfo*							specializationInfo = nullptr,
808 																							 const VkPipelineShaderStageCreateFlags					pipelineShaderStageCreateFlags = static_cast<VkPipelineShaderStageCreateFlags>(0),
809 																							 const void*											pipelineShaderStageCreateInfopNext = nullptr);
810 	void														addLibrary					(de::SharedPtr<de::MovePtr<RayTracingPipeline>>			pipelineLibrary);
811 	Move<VkPipeline>											createPipeline				(const DeviceInterface&									vk,
812 																							 const VkDevice											device,
813 																							 const VkPipelineLayout									pipelineLayout,
814 																							 const std::vector<de::SharedPtr<Move<VkPipeline>>>&	pipelineLibraries			= std::vector<de::SharedPtr<Move<VkPipeline>>>());
815 	std::vector<de::SharedPtr<Move<VkPipeline>>>				createPipelineWithLibraries	(const DeviceInterface&									vk,
816 																							 const VkDevice											device,
817 																							 const VkPipelineLayout									pipelineLayout);
818 	de::MovePtr<BufferWithMemory>								createShaderBindingTable	(const DeviceInterface&									vk,
819 																							 const VkDevice											device,
820 																							 const VkPipeline										pipeline,
821 																							 Allocator&												allocator,
822 																							 const deUint32&										shaderGroupHandleSize,
823 																							 const deUint32											shaderGroupBaseAlignment,
824 																							 const deUint32&										firstGroup,
825 																							 const deUint32&										groupCount,
826 																							 const VkBufferCreateFlags&								additionalBufferCreateFlags	= VkBufferCreateFlags(0u),
827 																							 const VkBufferUsageFlags&								additionalBufferUsageFlags	= VkBufferUsageFlags(0u),
828 																							 const MemoryRequirement&								additionalMemoryRequirement	= MemoryRequirement::Any,
829 																							 const VkDeviceAddress&									opaqueCaptureAddress		= 0u,
830 																							 const deUint32											shaderBindingTableOffset	= 0u,
831 																							 const deUint32											shaderRecordSize			= 0u,
832 																							 const void**											shaderGroupDataPtrPerGroup	= nullptr);
833 	void														setCreateFlags				(const VkPipelineCreateFlags&							pipelineCreateFlags);
834 	void														setMaxRecursionDepth		(const deUint32&										maxRecursionDepth);
835 	void														setMaxPayloadSize			(const deUint32&										maxPayloadSize);
836 	void														setMaxAttributeSize			(const deUint32&										maxAttributeSize);
837 	void														setDeferredOperation		(const bool												deferredOperation,
838 																							 const deUint32											workerThreadCount = 0);
839 	void														addDynamicState				(const VkDynamicState&									dynamicState);
840 
841 
842 protected:
843 	Move<VkPipeline>											createPipelineKHR			(const DeviceInterface&									vk,
844 																							 const VkDevice											device,
845 																							 const VkPipelineLayout									pipelineLayout,
846 																							 const std::vector<de::SharedPtr<Move<VkPipeline>>>&	pipelineLibraries);
847 
848 	std::vector<de::SharedPtr<Move<VkShaderModule> > >			m_shadersModules;
849 	std::vector<de::SharedPtr<de::MovePtr<RayTracingPipeline>>>	m_pipelineLibraries;
850 	std::vector<VkPipelineShaderStageCreateInfo>				m_shaderCreateInfos;
851 	std::vector<VkRayTracingShaderGroupCreateInfoKHR>			m_shadersGroupCreateInfos;
852 	VkPipelineCreateFlags										m_pipelineCreateFlags;
853 	deUint32													m_maxRecursionDepth;
854 	deUint32													m_maxPayloadSize;
855 	deUint32													m_maxAttributeSize;
856 	bool														m_deferredOperation;
857 	deUint32													m_workerThreadCount;
858 	std::vector<VkDynamicState>									m_dynamicStates;
859 };
860 
861 class RayTracingProperties
862 {
863 protected:
RayTracingProperties()864 									RayTracingProperties						() {}
865 
866 public:
RayTracingProperties(const InstanceInterface & vki,const VkPhysicalDevice physicalDevice)867 									RayTracingProperties						(const InstanceInterface&	vki,
868 																				 const VkPhysicalDevice		physicalDevice) { DE_UNREF(vki); DE_UNREF(physicalDevice); }
~RayTracingProperties()869 	virtual							~RayTracingProperties						() {}
870 
871 	virtual deUint32				getShaderGroupHandleSize					(void)	= DE_NULL;
872 	virtual deUint32				getMaxRecursionDepth						(void)	= DE_NULL;
873 	virtual deUint32				getMaxShaderGroupStride						(void)	= DE_NULL;
874 	virtual deUint32				getShaderGroupBaseAlignment					(void)	= DE_NULL;
875 	virtual deUint64				getMaxGeometryCount							(void)	= DE_NULL;
876 	virtual deUint64				getMaxInstanceCount							(void)	= DE_NULL;
877 	virtual deUint64				getMaxPrimitiveCount						(void)	= DE_NULL;
878 	virtual deUint32				getMaxDescriptorSetAccelerationStructures	(void)	= DE_NULL;
879 	virtual deUint32				getMaxRayDispatchInvocationCount			(void)	= DE_NULL;
880 	virtual deUint32				getMaxRayHitAttributeSize					(void)	= DE_NULL;
881 };
882 
883 de::MovePtr<RayTracingProperties> makeRayTracingProperties (const InstanceInterface&	vki,
884 															const VkPhysicalDevice		physicalDevice);
885 
886 void cmdTraceRays	(const DeviceInterface&					vk,
887 					 VkCommandBuffer						commandBuffer,
888 					 const VkStridedDeviceAddressRegionKHR*	raygenShaderBindingTableRegion,
889 					 const VkStridedDeviceAddressRegionKHR*	missShaderBindingTableRegion,
890 					 const VkStridedDeviceAddressRegionKHR*	hitShaderBindingTableRegion,
891 					 const VkStridedDeviceAddressRegionKHR*	callableShaderBindingTableRegion,
892 					 deUint32								width,
893 					 deUint32								height,
894 					 deUint32								depth);
895 
896 void cmdTraceRaysIndirect	(const DeviceInterface&					vk,
897 							 VkCommandBuffer						commandBuffer,
898 							 const VkStridedDeviceAddressRegionKHR*	raygenShaderBindingTableRegion,
899 							 const VkStridedDeviceAddressRegionKHR*	missShaderBindingTableRegion,
900 							 const VkStridedDeviceAddressRegionKHR*	hitShaderBindingTableRegion,
901 							 const VkStridedDeviceAddressRegionKHR*	callableShaderBindingTableRegion,
902 							 VkDeviceAddress						indirectDeviceAddress);
903 } // vk
904 
905 #endif // _VKRAYTRACINGUTIL_HPP
906