• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
The MIT License (MIT)

Copyright (c) 2022 Google LLC
Copyright (c) 2022 Sascha Willems

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
*/
25
// State of a single cloth particle; the particle buffers hold one of these
// per grid point, indexed row-major (y * particleCount.x + x).
// NOTE(review): the application-side struct must use the same field order and
// element stride — confirm the padding that follows `pinned`.
struct Particle
{
	float4 pos;    // Position; the shader reads .xyz and writes w = 1.0.
	float4 vel;    // Velocity; the shader reads .xyz and writes w = 0.0.
	float4 uv;     // Not accessed by this shader (presumably texture coordinates).
	float4 normal; // Surface normal, written only when normal calculation is enabled.
	float pinned;  // Compared against 1.0: matching particles are held in place.
};
33
// Particle state consumed by this pass (read-only), descriptor binding 0.
[[vk::binding(0)]] StructuredBuffer<Particle> particleIn;

// Particle state produced by this pass (read/write), descriptor binding 1.
[[vk::binding(1)]] RWStructuredBuffer<Particle> particleOut;
38
// Simulation parameters supplied through the uniform buffer.
// Field order defines the buffer layout and must match the application side.
struct UBO {
	float deltaT;          // Time step used by the integration.
	float particleMass;    // Mass of one particle; scales gravity and converts force to acceleration.
	float springStiffness; // Factor applied to every spring's extension.
	float damping;         // Factor of the velocity-proportional damping force.
	float restDistH;       // Spring rest length towards the left/right neighbours.
	float restDistV;       // Spring rest length towards the upper/lower neighbours.
	float restDistD;       // Spring rest length towards the four diagonal neighbours.
	float sphereRadius;    // Radius of the collision sphere.
	float4 spherePos;      // Centre of the collision sphere (.xyz is used).
	float4 gravity;        // Gravity vector (.xyz is used), multiplied by particleMass.
	int2 particleCount;    // Grid dimensions: x = particles per row, y = number of rows.
};
53
// Constant buffer in register b2 exposing the simulation parameters as `params`.
cbuffer ubo : register(b2) { UBO params; };
58
// Push-constant block holding the per-dispatch `calculateNormals` switch.
// Two declarations are provided because the syntax depends on whether the
// GLSLANG macro is defined by the compiling front end.
#ifndef GLSLANG
// Default path: attribute-based push-constant declaration.
struct PushConstants
{
	uint calculateNormals; // 1 = also compute per-particle normals in this pass
};

[[vk::push_constant]] PushConstants pushConstants;
#else
// GLSLANG path: layout-qualifier based declaration.
layout ( push_constant ) cbuffer PushConstants
{
	uint calculateNormals; // 1 = also compute per-particle normals in this pass
} pushConstants;
#endif
73
// Computes the spring force acting between two particle positions.
//
// p0, p1   : the two end points of the spring.
// restDist : rest length of the spring.
//
// Returns a vector along (p0 - p1), scaled by params.springStiffness and by
// how far the current length deviates from restDist. When the spring is
// stretched the result points from p1 towards p0; when compressed it points
// the other way. Coincident end points yield a zero force.
float3 springForce(float3 p0, float3 p1, float restDist)
{
	float3 dist = p0 - p1;
	float len = length(dist);

	// normalize() divides by the vector length; with coincident points that
	// is a division by zero and the resulting NaN would spread through every
	// connected particle. There is no defined direction in that case, so
	// contribute no force instead.
	if (len <= 0.0) {
		return float3(0, 0, 0);
	}

	return normalize(dist) * params.springStiffness * (len - restDist);
}
79
// Compute entry point: advances one cloth particle by one time step.
//
// The thread with dispatch ID (id.x, id.y) handles the particle at that grid
// coordinate (row-major index id.y * particleCount.x + id.x). It reads the
// previous state from particleIn and writes the new state to particleOut:
//   1. pinned particles keep their position and get zero velocity;
//   2. otherwise gravity, spring forces from up to eight neighbours and
//      velocity damping are accumulated and integrated over params.deltaT;
//   3. particles ending up inside the collision sphere are pushed back to
//      its surface and stopped;
//   4. if pushConstants.calculateNormals is 1, a normal is derived from the
//      neighbouring positions in particleIn.
[numthreads(10, 10, 1)]
void main(uint3 id : SV_DispatchThreadID)
{
	// Reject every thread outside the particle grid. The test is done per
	// axis (not on the flattened index) so that surplus threads in x cannot
	// alias a particle of the next row, and the last valid index + 1 is
	// rejected as well.
	const uint2 gridSize = uint2(params.particleCount);
	if (id.x >= gridSize.x || id.y >= gridSize.y)
		return;

	const uint index = id.y * gridSize.x + id.x;

	// Pinned particles do not move: carry the input position over and make
	// sure they have no velocity.
	if (particleIn[index].pinned == 1.0) {
		particleOut[index].pos = particleIn[index].pos;
		particleOut[index].vel = float4(0, 0, 0, 0);
		return;
	}

	const float3 pos = particleIn[index].pos.xyz;
	const float3 vel = particleIn[index].vel.xyz;

	// Which direct neighbours exist for this grid coordinate.
	const bool hasLeft  = id.x > 0;
	const bool hasRight = id.x < gridSize.x - 1;
	const bool hasUpper = id.y < gridSize.y - 1;
	const bool hasLower = id.y > 0;

	// Indices of the same column in the adjacent rows. They are only
	// dereferenced when the corresponding neighbour exists.
	const uint upper = index + gridSize.x;
	const uint lower = index - gridSize.x;

	// Initial force from gravity
	float3 force = params.gravity.xyz * params.particleMass;

	// Spring forces from the horizontal and vertical neighbours
	if (hasLeft) {
		force += springForce(particleIn[index - 1].pos.xyz, pos, params.restDistH);
	}
	if (hasRight) {
		force += springForce(particleIn[index + 1].pos.xyz, pos, params.restDistH);
	}
	if (hasUpper) {
		force += springForce(particleIn[upper].pos.xyz, pos, params.restDistV);
	}
	if (hasLower) {
		force += springForce(particleIn[lower].pos.xyz, pos, params.restDistV);
	}

	// Spring forces from the diagonal neighbours
	if (hasLeft && hasUpper) {
		force += springForce(particleIn[upper - 1].pos.xyz, pos, params.restDistD);
	}
	if (hasLeft && hasLower) {
		force += springForce(particleIn[lower - 1].pos.xyz, pos, params.restDistD);
	}
	if (hasRight && hasUpper) {
		force += springForce(particleIn[upper + 1].pos.xyz, pos, params.restDistD);
	}
	if (hasRight && hasLower) {
		force += springForce(particleIn[lower + 1].pos.xyz, pos, params.restDistD);
	}

	// Velocity-proportional damping
	force += (-params.damping * vel);

	// Integrate: a = F / m, then constant-acceleration update over deltaT.
	const float3 accel = force * (1.0 / params.particleMass);
	float3 newPos = pos + vel * params.deltaT + 0.5 * accel * params.deltaT * params.deltaT;
	float3 newVel = vel + accel * params.deltaT;

	// Sphere collision: the 0.01 margin keeps the particle slightly outside
	// the sphere surface.
	const float3 sphereDist = newPos - params.spherePos.xyz;
	if (length(sphereDist) < params.sphereRadius + 0.01) {
		// If the particle is inside the sphere, push it to the outer radius
		newPos = params.spherePos.xyz + normalize(sphereDist) * (params.sphereRadius + 0.01);
		// Cancel out velocity
		newVel = float3(0, 0, 0);
	}

	particleOut[index].pos = float4(newPos, 1.0);
	particleOut[index].vel = float4(newVel, 0.0);

	// Normals: sum the cross products of the edges to the neighbours in each
	// existing quadrant around the particle (using the input positions), then
	// normalize the sum.
	if (pushConstants.calculateNormals == 1) {
		float3 normal = float3(0, 0, 0);
		float3 a, b, c;
		if (hasLower) {
			if (hasLeft) {
				a = particleIn[index - 1].pos.xyz - pos;
				b = particleIn[lower - 1].pos.xyz - pos;
				c = particleIn[lower].pos.xyz - pos;
				normal += cross(a, b) + cross(b, c);
			}
			if (hasRight) {
				a = particleIn[lower].pos.xyz - pos;
				b = particleIn[lower + 1].pos.xyz - pos;
				c = particleIn[index + 1].pos.xyz - pos;
				normal += cross(a, b) + cross(b, c);
			}
		}
		if (hasUpper) {
			if (hasLeft) {
				a = particleIn[upper].pos.xyz - pos;
				b = particleIn[upper - 1].pos.xyz - pos;
				c = particleIn[index - 1].pos.xyz - pos;
				normal += cross(a, b) + cross(b, c);
			}
			if (hasRight) {
				a = particleIn[index + 1].pos.xyz - pos;
				b = particleIn[upper + 1].pos.xyz - pos;
				c = particleIn[upper].pos.xyz - pos;
				normal += cross(a, b) + cross(b, c);
			}
		}
		// NOTE(review): the sum is zero if no quadrant exists (grid with a
		// single row or column) or the neighbours are degenerate; normalize()
		// is then undefined — confirm such grids are never dispatched.
		particleOut[index].normal = float4(normalize(normal), 0.0f);
	}
}
185