1 /****************************************************************************
2 * Copyright (C) 2014-2015 Intel Corporation. All Rights Reserved.
3 *
4 * Permission is hereby granted, free of charge, to any person obtaining a
5 * copy of this software and associated documentation files (the "Software"),
6 * to deal in the Software without restriction, including without limitation
7 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8 * and/or sell copies of the Software, and to permit persons to whom the
9 * Software is furnished to do so, subject to the following conditions:
10 *
11 * The above copyright notice and this permission notice (including the next
12 * paragraph) shall be included in all copies or substantial portions of the
13 * Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
18 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21 * IN THE SOFTWARE.
22 *
23 * @file streamout_jit.cpp
24 *
25 * @brief Implementation of the streamout jitter
26 *
27 * Notes:
28 *
29 ******************************************************************************/
30 #include "jit_pch.hpp"
31 #include "builder_gfx_mem.h"
32 #include "jit_api.h"
33 #include "streamout_jit.h"
34 #include "gen_state_llvm.h"
35 #include "functionpasses/passes.h"
36
37 using namespace llvm;
38 using namespace SwrJit;
39
40 //////////////////////////////////////////////////////////////////////////
41 /// Interface to Jitting a fetch shader
42 //////////////////////////////////////////////////////////////////////////
43 struct StreamOutJit : public BuilderGfxMem
44 {
StreamOutJitStreamOutJit45 StreamOutJit(JitManager* pJitMgr) : BuilderGfxMem(pJitMgr){};
46
47 // returns pointer to SWR_STREAMOUT_BUFFER
getSOBufferStreamOutJit48 Value* getSOBuffer(Value* pSoCtx, uint32_t buffer)
49 {
50 return LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_pBuffer, buffer});
51 }
52
53 //////////////////////////////////////////////////////////////////////////
54 // @brief checks if streamout buffer is oob
55 // @return <i1> true/false
oobStreamOutJit56 Value* oob(const STREAMOUT_COMPILE_STATE& state, Value* pSoCtx, uint32_t buffer)
57 {
58 Value* returnMask = C(false);
59
60 Value* pBuf = getSOBuffer(pSoCtx, buffer);
61
62 // load enable
63 // @todo bool data types should generate <i1> llvm type
64 Value* enabled = TRUNC(LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_enable}), IRB()->getInt1Ty());
65
66 // load buffer size
67 Value* bufferSize = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_bufferSize});
68
69 // load current streamOffset
70 Value* streamOffset = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
71
72 // load buffer pitch
73 Value* pitch = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pitch});
74
75 // buffer is considered oob if in use in a decl but not enabled
76 returnMask = OR(returnMask, NOT(enabled));
77
78 // buffer is oob if cannot fit a prims worth of verts
79 Value* newOffset = ADD(streamOffset, MUL(pitch, C(state.numVertsPerPrim)));
80 returnMask = OR(returnMask, ICMP_SGT(newOffset, bufferSize));
81
82 return returnMask;
83 }
84
85 //////////////////////////////////////////////////////////////////////////
86 // @brief converts scalar bitmask to <4 x i32> suitable for shuffle vector,
87 // packing the active mask bits
88 // ex. bitmask 0011 -> (0, 1, 0, 0)
89 // bitmask 1000 -> (3, 0, 0, 0)
90 // bitmask 1100 -> (2, 3, 0, 0)
PackMaskStreamOutJit91 Value* PackMask(uint32_t bitmask)
92 {
93 std::vector<Constant*> indices(4, C(0));
94 unsigned long index;
95 uint32_t elem = 0;
96 while (_BitScanForward(&index, bitmask))
97 {
98 indices[elem++] = C((int)index);
99 bitmask &= ~(1 << index);
100 }
101
102 return ConstantVector::get(indices);
103 }
104
105 //////////////////////////////////////////////////////////////////////////
106 // @brief convert scalar bitmask to <4xfloat> bitmask
ToMaskStreamOutJit107 Value* ToMask(uint32_t bitmask)
108 {
109 std::vector<Constant*> indices;
110 for (uint32_t i = 0; i < 4; ++i)
111 {
112 if (bitmask & (1 << i))
113 {
114 indices.push_back(C(true));
115 }
116 else
117 {
118 indices.push_back(C(false));
119 }
120 }
121 return ConstantVector::get(indices);
122 }
123
124 //////////////////////////////////////////////////////////////////////////
125 // @brief processes a single decl from the streamout stream. Reads 4 components from the input
126 // stream and writes N components to the output buffer given the componentMask or if
127 // a hole, just increments the buffer pointer
128 // @param pStream - pointer to current attribute
129 // @param pOutBuffers - pointers to the current location of each output buffer
130 // @param decl - input decl
buildDeclStreamOutJit131 void buildDecl(Value* pStream, Value* pOutBuffers[4], const STREAMOUT_DECL& decl)
132 {
133 uint32_t numComponents = _mm_popcnt_u32(decl.componentMask);
134 uint32_t packedMask = (1 << numComponents) - 1;
135 if (!decl.hole)
136 {
137 // increment stream pointer to correct slot
138 Value* pAttrib = GEP(pStream, C(4 * decl.attribSlot));
139
140 // load 4 components from stream
141 Type* simd4Ty = getVectorType(IRB()->getFloatTy(), 4);
142 Type* simd4PtrTy = PointerType::get(simd4Ty, 0);
143 pAttrib = BITCAST(pAttrib, simd4PtrTy);
144 Value* vattrib = LOAD(pAttrib);
145
146 // shuffle/pack enabled components
147 Value* vpackedAttrib = VSHUFFLE(vattrib, vattrib, PackMask(decl.componentMask));
148
149 // store to output buffer
150 // cast SO buffer to i8*, needed by maskstore
151 Value* pOut = BITCAST(pOutBuffers[decl.bufferIndex], PointerType::get(simd4Ty, 0));
152
153 // cast input to <4xfloat>
154 Value* src = BITCAST(vpackedAttrib, simd4Ty);
155
156 // cast mask to <4xi1>
157 Value* mask = ToMask(packedMask);
158 MASKED_STORE(src, pOut, 4, mask, PointerType::get(simd4Ty, 0), MEM_CLIENT::GFX_MEM_CLIENT_STREAMOUT);
159 }
160
161 // increment SO buffer
162 pOutBuffers[decl.bufferIndex] = GEP(pOutBuffers[decl.bufferIndex], C(numComponents));
163 }
164
165 //////////////////////////////////////////////////////////////////////////
166 // @brief builds a single vertex worth of data for the given stream
167 // @param streamState - state for this stream
168 // @param pCurVertex - pointer to src stream vertex data
169 // @param pOutBuffer - pointers to up to 4 SO buffers
buildVertexStreamOutJit170 void buildVertex(const STREAMOUT_STREAM& streamState, Value* pCurVertex, Value* pOutBuffer[4])
171 {
172 for (uint32_t d = 0; d < streamState.numDecls; ++d)
173 {
174 const STREAMOUT_DECL& decl = streamState.decl[d];
175 buildDecl(pCurVertex, pOutBuffer, decl);
176 }
177 }
178
buildStreamStreamOutJit179 void buildStream(const STREAMOUT_COMPILE_STATE& state,
180 const STREAMOUT_STREAM& streamState,
181 Value* pSoCtx,
182 BasicBlock* returnBB,
183 Function* soFunc)
184 {
185 // get list of active SO buffers
186 std::unordered_set<uint32_t> activeSOBuffers;
187 for (uint32_t d = 0; d < streamState.numDecls; ++d)
188 {
189 const STREAMOUT_DECL& decl = streamState.decl[d];
190 activeSOBuffers.insert(decl.bufferIndex);
191 }
192
193 // always increment numPrimStorageNeeded
194 Value* numPrimStorageNeeded = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded});
195 numPrimStorageNeeded = ADD(numPrimStorageNeeded, C(1));
196 STORE(numPrimStorageNeeded, pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimStorageNeeded});
197
198 // check OOB on active SO buffers. If any buffer is out of bound, don't write
199 // the primitive to any buffer
200 Value* oobMask = C(false);
201 for (uint32_t buffer : activeSOBuffers)
202 {
203 oobMask = OR(oobMask, oob(state, pSoCtx, buffer));
204 }
205
206 BasicBlock* validBB = BasicBlock::Create(JM()->mContext, "valid", soFunc);
207
208 // early out if OOB
209 COND_BR(oobMask, returnBB, validBB);
210
211 IRB()->SetInsertPoint(validBB);
212
213 Value* numPrimsWritten = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimsWritten});
214 numPrimsWritten = ADD(numPrimsWritten, C(1));
215 STORE(numPrimsWritten, pSoCtx, {0, SWR_STREAMOUT_CONTEXT_numPrimsWritten});
216
217 // compute start pointer for each output buffer
218 Value* pOutBuffer[4];
219 Value* pOutBufferStartVertex[4];
220 Value* outBufferPitch[4];
221 for (uint32_t b : activeSOBuffers)
222 {
223 Value* pBuf = getSOBuffer(pSoCtx, b);
224 Value* pData = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pBuffer});
225 Value* streamOffset = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
226 pOutBuffer[b] = GEP(pData, streamOffset, PointerType::get(IRB()->getInt32Ty(), 0));
227 pOutBufferStartVertex[b] = pOutBuffer[b];
228
229 outBufferPitch[b] = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_pitch});
230 }
231
232 // loop over the vertices of the prim
233 Value* pStreamData = LOAD(pSoCtx, {0, SWR_STREAMOUT_CONTEXT_pPrimData});
234 for (uint32_t v = 0; v < state.numVertsPerPrim; ++v)
235 {
236 buildVertex(streamState, pStreamData, pOutBuffer);
237
238 // increment stream and output buffer pointers
239 // stream verts are always 32*4 dwords apart
240 pStreamData = GEP(pStreamData, C(SWR_VTX_NUM_SLOTS * 4));
241
242 // output buffers offset using pitch in buffer state
243 for (uint32_t b : activeSOBuffers)
244 {
245 pOutBufferStartVertex[b] = GEP(pOutBufferStartVertex[b], outBufferPitch[b]);
246 pOutBuffer[b] = pOutBufferStartVertex[b];
247 }
248 }
249
250 // update each active buffer's streamOffset
251 for (uint32_t b : activeSOBuffers)
252 {
253 Value* pBuf = getSOBuffer(pSoCtx, b);
254 Value* streamOffset = LOAD(pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
255 streamOffset = ADD(streamOffset, MUL(C(state.numVertsPerPrim), outBufferPitch[b]));
256 STORE(streamOffset, pBuf, {0, SWR_STREAMOUT_BUFFER_streamOffset});
257 }
258 }
259
CreateStreamOutJit260 Function* Create(const STREAMOUT_COMPILE_STATE& state)
261 {
262 std::stringstream fnName("SO_",
263 std::ios_base::in | std::ios_base::out | std::ios_base::ate);
264 fnName << ComputeCRC(0, &state, sizeof(state));
265
266 std::vector<Type*> args{
267 mInt8PtrTy,
268 mInt8PtrTy,
269 PointerType::get(Gen_SWR_STREAMOUT_CONTEXT(JM()), 0), // SWR_STREAMOUT_CONTEXT*
270 };
271
272 FunctionType* fTy = FunctionType::get(IRB()->getVoidTy(), args, false);
273 Function* soFunc = Function::Create(
274 fTy, GlobalValue::ExternalLinkage, fnName.str(), JM()->mpCurrentModule);
275
276 soFunc->getParent()->setModuleIdentifier(soFunc->getName());
277
278 // create return basic block
279 BasicBlock* entry = BasicBlock::Create(JM()->mContext, "entry", soFunc);
280 BasicBlock* returnBB = BasicBlock::Create(JM()->mContext, "return", soFunc);
281
282 IRB()->SetInsertPoint(entry);
283
284 // arguments
285 auto argitr = soFunc->arg_begin();
286
287 Value* privateContext = &*argitr++;
288 privateContext->setName("privateContext");
289 SetPrivateContext(privateContext);
290
291 mpWorkerData = &*argitr;
292 ++argitr;
293 mpWorkerData->setName("pWorkerData");
294
295 Value* pSoCtx = &*argitr++;
296 pSoCtx->setName("pSoCtx");
297
298 const STREAMOUT_STREAM& streamState = state.stream;
299 buildStream(state, streamState, pSoCtx, returnBB, soFunc);
300
301 BR(returnBB);
302
303 IRB()->SetInsertPoint(returnBB);
304 RET_VOID();
305
306 JitManager::DumpToFile(soFunc, "SoFunc");
307
308 ::FunctionPassManager passes(JM()->mpCurrentModule);
309
310 passes.add(createBreakCriticalEdgesPass());
311 passes.add(createCFGSimplificationPass());
312 passes.add(createEarlyCSEPass());
313 passes.add(createPromoteMemoryToRegisterPass());
314 passes.add(createCFGSimplificationPass());
315 passes.add(createEarlyCSEPass());
316 passes.add(createInstructionCombiningPass());
317 #if LLVM_VERSION_MAJOR <= 11
318 passes.add(createConstantPropagationPass());
319 #endif
320 passes.add(createSCCPPass());
321 passes.add(createAggressiveDCEPass());
322
323 passes.add(createLowerX86Pass(this));
324
325 passes.run(*soFunc);
326
327 JitManager::DumpToFile(soFunc, "SoFunc_optimized");
328
329
330 return soFunc;
331 }
332 };
333
334 //////////////////////////////////////////////////////////////////////////
335 /// @brief JITs from streamout shader IR
336 /// @param hJitMgr - JitManager handle
337 /// @param func - LLVM function IR
338 /// @return PFN_SO_FUNC - pointer to SOS function
JitStreamoutFunc(HANDLE hJitMgr,const HANDLE hFunc)339 PFN_SO_FUNC JitStreamoutFunc(HANDLE hJitMgr, const HANDLE hFunc)
340 {
341 llvm::Function* func = (llvm::Function*)hFunc;
342 JitManager* pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
343 PFN_SO_FUNC pfnStreamOut;
344 pfnStreamOut = (PFN_SO_FUNC)(pJitMgr->mpExec->getFunctionAddress(func->getName().str()));
345 // MCJIT finalizes modules the first time you JIT code from them. After finalized, you cannot
346 // add new IR to the module
347 pJitMgr->mIsModuleFinalized = true;
348
349 pJitMgr->DumpAsm(func, "SoFunc_optimized");
350
351
352 return pfnStreamOut;
353 }
354
355 //////////////////////////////////////////////////////////////////////////
356 /// @brief JIT compiles streamout shader
357 /// @param hJitMgr - JitManager handle
358 /// @param state - SO state to build function from
JitCompileStreamout(HANDLE hJitMgr,const STREAMOUT_COMPILE_STATE & state)359 extern "C" PFN_SO_FUNC JITCALL JitCompileStreamout(HANDLE hJitMgr,
360 const STREAMOUT_COMPILE_STATE& state)
361 {
362 JitManager* pJitMgr = reinterpret_cast<JitManager*>(hJitMgr);
363
364 STREAMOUT_COMPILE_STATE soState = state;
365 if (soState.offsetAttribs)
366 {
367 for (uint32_t i = 0; i < soState.stream.numDecls; ++i)
368 {
369 soState.stream.decl[i].attribSlot -= soState.offsetAttribs;
370 }
371 }
372
373 pJitMgr->SetupNewModule();
374
375 StreamOutJit theJit(pJitMgr);
376 HANDLE hFunc = theJit.Create(soState);
377
378 return JitStreamoutFunc(hJitMgr, hFunc);
379 }
380