1 /*
2 * Copyright 2018 Google Inc.
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8 #ifndef SKSL_STANDALONE
9
10 #ifdef SK_LLVM_AVAILABLE
11
12 #include "SkSLJIT.h"
13
14 #include "SkCpu.h"
15 #include "SkRasterPipeline.h"
16 #include "ir/SkSLAppendStage.h"
17 #include "ir/SkSLExpressionStatement.h"
18 #include "ir/SkSLFunctionCall.h"
19 #include "ir/SkSLFunctionReference.h"
20 #include "ir/SkSLIndexExpression.h"
21 #include "ir/SkSLProgram.h"
22 #include "ir/SkSLUnresolvedFunction.h"
23 #include "llvm/ExecutionEngine/RTDyldMemoryManager.h"
24
25 static constexpr int MAX_VECTOR_COUNT = 16;
26
sksl_pipeline_append(SkRasterPipeline * p,int stage,void * ctx)27 extern "C" void sksl_pipeline_append(SkRasterPipeline* p, int stage, void* ctx) {
28 p->append((SkRasterPipeline::StockStage) stage, ctx);
29 }
30
31 #define PTR_SIZE sizeof(void*)
32
sksl_pipeline_append_callback(SkRasterPipeline * p,void * fn)33 extern "C" void sksl_pipeline_append_callback(SkRasterPipeline* p, void* fn) {
34 p->append(fn, nullptr);
35 }
36
// Backs the SkSL builtin 'print(float)': writes the value to stdout with a "Debug: " prefix.
extern "C" void sksl_debug_print(float f) {
    const double promoted = f;  // printf's %f consumes a double
    printf("Debug: %f\n", promoted);
}
40
sksl_clamp1(float f,float min,float max)41 extern "C" float sksl_clamp1(float f, float min, float max) {
42 return SkTPin(f, min, max);
43 }
44
// GCC/Clang vector-extension types used as the C-side signatures of sksl_clamp2/3/4.
// float3 is deliberately 16 bytes wide, giving it the same storage as float4.
typedef float float2 __attribute__((vector_size(8)));
typedef float float3 __attribute__((vector_size(16)));
typedef float float4 __attribute__((vector_size(16)));
48
sksl_clamp2(float2 f,float min,float max)49 extern "C" float2 sksl_clamp2(float2 f, float min, float max) {
50 return float2 { SkTPin(f[0], min, max), SkTPin(f[1], min, max) };
51 }
52
sksl_clamp3(float3 f,float min,float max)53 extern "C" float3 sksl_clamp3(float3 f, float min, float max) {
54 return float3 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max) };
55 }
56
sksl_clamp4(float4 f,float min,float max)57 extern "C" float4 sksl_clamp4(float4 f, float min, float max) {
58 return float4 { SkTPin(f[0], min, max), SkTPin(f[1], min, max), SkTPin(f[2], min, max),
59 SkTPin(f[3], min, max) };
60 }
61
62 namespace SkSL {
63
64 static constexpr int STAGE_PARAM_COUNT = 12;
65
ends_with_branch(const Statement & stmt)66 static bool ends_with_branch(const Statement& stmt) {
67 switch (stmt.fKind) {
68 case Statement::kBlock_Kind: {
69 const Block& b = (const Block&) stmt;
70 if (b.fStatements.size()) {
71 return ends_with_branch(*b.fStatements.back());
72 }
73 return false;
74 }
75 case Statement::kBreak_Kind: // fall through
76 case Statement::kContinue_Kind: // fall through
77 case Statement::kReturn_Kind: // fall through
78 return true;
79 default:
80 return false;
81 }
82 }
83
JIT(Compiler * compiler)84 JIT::JIT(Compiler* compiler)
85 : fCompiler(*compiler) {
86 LLVMInitializeNativeTarget();
87 LLVMInitializeNativeAsmPrinter();
88 LLVMLinkInMCJIT();
89 SkASSERT(!SkCpu::Supports(SkCpu::SKX)); // not yet supported
90 if (SkCpu::Supports(SkCpu::HSW)) {
91 fVectorCount = 8;
92 fCPU = "haswell";
93 } else if (SkCpu::Supports(SkCpu::AVX)) {
94 fVectorCount = 8;
95 fCPU = "ivybridge";
96 } else {
97 fVectorCount = 4;
98 fCPU = nullptr;
99 }
100 fContext = LLVMContextCreate();
101 fVoidType = LLVMVoidTypeInContext(fContext);
102 fInt1Type = LLVMInt1TypeInContext(fContext);
103 fInt1VectorType = LLVMVectorType(fInt1Type, fVectorCount);
104 fInt1Vector2Type = LLVMVectorType(fInt1Type, 2);
105 fInt1Vector3Type = LLVMVectorType(fInt1Type, 3);
106 fInt1Vector4Type = LLVMVectorType(fInt1Type, 4);
107 fInt8Type = LLVMInt8TypeInContext(fContext);
108 fInt8PtrType = LLVMPointerType(fInt8Type, 0);
109 fInt32Type = LLVMInt32TypeInContext(fContext);
110 fInt64Type = LLVMInt64TypeInContext(fContext);
111 fSizeTType = LLVMInt64TypeInContext(fContext);
112 fInt32VectorType = LLVMVectorType(fInt32Type, fVectorCount);
113 fInt32Vector2Type = LLVMVectorType(fInt32Type, 2);
114 fInt32Vector3Type = LLVMVectorType(fInt32Type, 3);
115 fInt32Vector4Type = LLVMVectorType(fInt32Type, 4);
116 fFloat32Type = LLVMFloatTypeInContext(fContext);
117 fFloat32VectorType = LLVMVectorType(fFloat32Type, fVectorCount);
118 fFloat32Vector2Type = LLVMVectorType(fFloat32Type, 2);
119 fFloat32Vector3Type = LLVMVectorType(fFloat32Type, 3);
120 fFloat32Vector4Type = LLVMVectorType(fFloat32Type, 4);
121 }
122
~JIT()123 JIT::~JIT() {
124 LLVMOrcDisposeInstance(fJITStack);
125 LLVMContextDispose(fContext);
126 }
127
addBuiltinFunction(const char * ourName,const char * realName,LLVMTypeRef returnType,std::vector<LLVMTypeRef> parameters)128 void JIT::addBuiltinFunction(const char* ourName, const char* realName, LLVMTypeRef returnType,
129 std::vector<LLVMTypeRef> parameters) {
130 bool found = false;
131 for (const auto& pair : *fProgram->fSymbols) {
132 if (Symbol::kFunctionDeclaration_Kind == pair.second->fKind) {
133 const FunctionDeclaration& f = (const FunctionDeclaration&) *pair.second;
134 if (pair.first != ourName || returnType != this->getType(f.fReturnType) ||
135 parameters.size() != f.fParameters.size()) {
136 continue;
137 }
138 for (size_t i = 0; i < parameters.size(); ++i) {
139 if (parameters[i] != this->getType(f.fParameters[i]->fType)) {
140 goto next;
141 }
142 }
143 fFunctions[&f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(returnType,
144 parameters.data(),
145 parameters.size(),
146 false));
147 found = true;
148 }
149 if (Symbol::kUnresolvedFunction_Kind == pair.second->fKind) {
150 // FIXME consolidate this with the code above
151 for (const auto& f : ((const UnresolvedFunction&) *pair.second).fFunctions) {
152 if (pair.first != ourName || returnType != this->getType(f->fReturnType) ||
153 parameters.size() != f->fParameters.size()) {
154 continue;
155 }
156 for (size_t i = 0; i < parameters.size(); ++i) {
157 if (parameters[i] != this->getType(f->fParameters[i]->fType)) {
158 goto next;
159 }
160 }
161 fFunctions[f] = LLVMAddFunction(fModule, realName, LLVMFunctionType(
162 returnType,
163 parameters.data(),
164 parameters.size(),
165 false));
166 found = true;
167 }
168 }
169 next:;
170 }
171 SkASSERT(found);
172 }
173
loadBuiltinFunctions()174 void JIT::loadBuiltinFunctions() {
175 this->addBuiltinFunction("abs", "fabs", fFloat32Type, { fFloat32Type });
176 this->addBuiltinFunction("sin", "sinf", fFloat32Type, { fFloat32Type });
177 this->addBuiltinFunction("cos", "cosf", fFloat32Type, { fFloat32Type });
178 this->addBuiltinFunction("tan", "tanf", fFloat32Type, { fFloat32Type });
179 this->addBuiltinFunction("sqrt", "sqrtf", fFloat32Type, { fFloat32Type });
180 this->addBuiltinFunction("clamp", "sksl_clamp1", fFloat32Type, { fFloat32Type,
181 fFloat32Type,
182 fFloat32Type });
183 this->addBuiltinFunction("clamp", "sksl_clamp2", fFloat32Vector2Type, { fFloat32Vector2Type,
184 fFloat32Type,
185 fFloat32Type });
186 this->addBuiltinFunction("clamp", "sksl_clamp3", fFloat32Vector3Type, { fFloat32Vector3Type,
187 fFloat32Type,
188 fFloat32Type });
189 this->addBuiltinFunction("clamp", "sksl_clamp4", fFloat32Vector4Type, { fFloat32Vector4Type,
190 fFloat32Type,
191 fFloat32Type });
192 this->addBuiltinFunction("print", "sksl_debug_print", fVoidType, { fFloat32Type });
193 }
194
resolveSymbol(const char * name,JIT * jit)195 uint64_t JIT::resolveSymbol(const char* name, JIT* jit) {
196 LLVMOrcTargetAddress result;
197 if (!LLVMOrcGetSymbolAddress(jit->fJITStack, &result, name)) {
198 if (!strcmp(name, "_sksl_pipeline_append")) {
199 result = (uint64_t) &sksl_pipeline_append;
200 } else if (!strcmp(name, "_sksl_pipeline_append_callback")) {
201 result = (uint64_t) &sksl_pipeline_append_callback;
202 } else if (!strcmp(name, "_sksl_clamp1")) {
203 result = (uint64_t) &sksl_clamp1;
204 } else if (!strcmp(name, "_sksl_clamp2")) {
205 result = (uint64_t) &sksl_clamp2;
206 } else if (!strcmp(name, "_sksl_clamp3")) {
207 result = (uint64_t) &sksl_clamp3;
208 } else if (!strcmp(name, "_sksl_clamp4")) {
209 result = (uint64_t) &sksl_clamp4;
210 } else if (!strcmp(name, "_sksl_debug_print")) {
211 result = (uint64_t) &sksl_debug_print;
212 } else {
213 result = llvm::RTDyldMemoryManager::getSymbolAddressInProcess(name);
214 }
215 }
216 SkASSERT(result);
217 return result;
218 }
219
compileFunctionCall(LLVMBuilderRef builder,const FunctionCall & fc)220 LLVMValueRef JIT::compileFunctionCall(LLVMBuilderRef builder, const FunctionCall& fc) {
221 LLVMValueRef func = fFunctions[&fc.fFunction];
222 SkASSERT(func);
223 std::vector<LLVMValueRef> parameters;
224 for (const auto& a : fc.fArguments) {
225 parameters.push_back(this->compileExpression(builder, *a));
226 }
227 return LLVMBuildCall(builder, func, parameters.data(), parameters.size(), "");
228 }
229
getType(const Type & type)230 LLVMTypeRef JIT::getType(const Type& type) {
231 switch (type.kind()) {
232 case Type::kOther_Kind:
233 if (type.name() == "void") {
234 return fVoidType;
235 }
236 SkASSERT(type.name() == "SkRasterPipeline");
237 return fInt8PtrType;
238 case Type::kScalar_Kind:
239 if (type.isSigned() || type.isUnsigned()) {
240 return fInt32Type;
241 }
242 if (type.isUnsigned()) {
243 return fInt32Type;
244 }
245 if (type.isFloat()) {
246 return fFloat32Type;
247 }
248 SkASSERT(type.name() == "bool");
249 return fInt1Type;
250 case Type::kArray_Kind:
251 return LLVMPointerType(this->getType(type.componentType()), 0);
252 case Type::kVector_Kind:
253 if (type.name() == "float2" || type.name() == "half2") {
254 return fFloat32Vector2Type;
255 }
256 if (type.name() == "float3" || type.name() == "half3") {
257 return fFloat32Vector3Type;
258 }
259 if (type.name() == "float4" || type.name() == "half4") {
260 return fFloat32Vector4Type;
261 }
262 if (type.name() == "int2" || type.name() == "short2" || type.name == "byte2") {
263 return fInt32Vector2Type;
264 }
265 if (type.name() == "int3" || type.name() == "short3" || type.name == "byte3") {
266 return fInt32Vector3Type;
267 }
268 if (type.name() == "int4" || type.name() == "short4" || type.name == "byte3") {
269 return fInt32Vector4Type;
270 }
271 // fall through
272 default:
273 ABORT("unsupported type");
274 }
275 }
276
setBlock(LLVMBuilderRef builder,LLVMBasicBlockRef block)277 void JIT::setBlock(LLVMBuilderRef builder, LLVMBasicBlockRef block) {
278 fCurrentBlock = block;
279 LLVMPositionBuilderAtEnd(builder, block);
280 }
281
getLValue(LLVMBuilderRef builder,const Expression & expr)282 std::unique_ptr<JIT::LValue> JIT::getLValue(LLVMBuilderRef builder, const Expression& expr) {
283 switch (expr.fKind) {
284 case Expression::kVariableReference_Kind: {
285 class PointerLValue : public LValue {
286 public:
287 PointerLValue(LLVMValueRef ptr)
288 : fPointer(ptr) {}
289
290 LLVMValueRef load(LLVMBuilderRef builder) override {
291 return LLVMBuildLoad(builder, fPointer, "lvalue load");
292 }
293
294 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
295 LLVMBuildStore(builder, value, fPointer);
296 }
297
298 private:
299 LLVMValueRef fPointer;
300 };
301 const Variable* var = &((VariableReference&) expr).fVariable;
302 if (var->fStorage == Variable::kParameter_Storage &&
303 !(var->fModifiers.fFlags & Modifiers::kOut_Flag) &&
304 fPromotedParameters.find(var) == fPromotedParameters.end()) {
305 // promote parameter to variable
306 fPromotedParameters.insert(var);
307 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
308 LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(var->fType),
309 String(var->fName).c_str());
310 LLVMBuildStore(builder, fVariables[var], alloca);
311 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
312 fVariables[var] = alloca;
313 }
314 LLVMValueRef ptr = fVariables[var];
315 return std::unique_ptr<LValue>(new PointerLValue(ptr));
316 }
317 case Expression::kTernary_Kind: {
318 class TernaryLValue : public LValue {
319 public:
320 TernaryLValue(JIT* jit, LLVMValueRef test, std::unique_ptr<LValue> ifTrue,
321 std::unique_ptr<LValue> ifFalse)
322 : fJIT(*jit)
323 , fTest(test)
324 , fIfTrue(std::move(ifTrue))
325 , fIfFalse(std::move(ifFalse)) {}
326
327 LLVMValueRef load(LLVMBuilderRef builder) override {
328 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
329 fJIT.fContext,
330 fJIT.fCurrentFunction,
331 "true ? ...");
332 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
333 fJIT.fContext,
334 fJIT.fCurrentFunction,
335 "false ? ...");
336 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
337 fJIT.fCurrentFunction,
338 "ternary merge");
339 LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
340 fJIT.setBlock(builder, trueBlock);
341 LLVMValueRef ifTrue = fIfTrue->load(builder);
342 LLVMBuildBr(builder, merge);
343 fJIT.setBlock(builder, falseBlock);
344 LLVMValueRef ifFalse = fIfTrue->load(builder);
345 LLVMBuildBr(builder, merge);
346 fJIT.setBlock(builder, merge);
347 LLVMTypeRef type = LLVMPointerType(LLVMTypeOf(ifTrue), 0);
348 LLVMValueRef phi = LLVMBuildPhi(builder, type, "?");
349 LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
350 LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
351 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
352 return phi;
353 }
354
355 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
356 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(
357 fJIT.fContext,
358 fJIT.fCurrentFunction,
359 "true ? ...");
360 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(
361 fJIT.fContext,
362 fJIT.fCurrentFunction,
363 "false ? ...");
364 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fJIT.fContext,
365 fJIT.fCurrentFunction,
366 "ternary merge");
367 LLVMBuildCondBr(builder, fTest, trueBlock, falseBlock);
368 fJIT.setBlock(builder, trueBlock);
369 fIfTrue->store(builder, value);
370 LLVMBuildBr(builder, merge);
371 fJIT.setBlock(builder, falseBlock);
372 fIfTrue->store(builder, value);
373 LLVMBuildBr(builder, merge);
374 fJIT.setBlock(builder, merge);
375 }
376
377 private:
378 JIT& fJIT;
379 LLVMValueRef fTest;
380 std::unique_ptr<LValue> fIfTrue;
381 std::unique_ptr<LValue> fIfFalse;
382 };
383 const TernaryExpression& t = (const TernaryExpression&) expr;
384 LLVMValueRef test = this->compileExpression(builder, *t.fTest);
385 return std::unique_ptr<LValue>(new TernaryLValue(this,
386 test,
387 this->getLValue(builder,
388 *t.fIfTrue),
389 this->getLValue(builder,
390 *t.fIfFalse)));
391 }
392 case Expression::kSwizzle_Kind: {
393 class SwizzleLValue : public LValue {
394 public:
395 SwizzleLValue(JIT* jit, LLVMTypeRef type, std::unique_ptr<LValue> base,
396 std::vector<int> components)
397 : fJIT(*jit)
398 , fType(type)
399 , fBase(std::move(base))
400 , fComponents(components) {}
401
402 LLVMValueRef load(LLVMBuilderRef builder) override {
403 LLVMValueRef base = fBase->load(builder);
404 if (fComponents.size() > 1) {
405 LLVMValueRef result = LLVMGetUndef(fType);
406 for (size_t i = 0; i < fComponents.size(); ++i) {
407 LLVMValueRef element = LLVMBuildExtractElement(
408 builder,
409 base,
410 LLVMConstInt(fJIT.fInt32Type,
411 fComponents[i],
412 false),
413 "swizzle extract");
414 result = LLVMBuildInsertElement(builder, result, element,
415 LLVMConstInt(fJIT.fInt32Type, i, false),
416 "swizzle insert");
417 }
418 return result;
419 }
420 SkASSERT(fComponents.size() == 1);
421 return LLVMBuildExtractElement(builder, base,
422 LLVMConstInt(fJIT.fInt32Type,
423 fComponents[0],
424 false),
425 "swizzle extract");
426 }
427
428 void store(LLVMBuilderRef builder, LLVMValueRef value) override {
429 LLVMValueRef result = fBase->load(builder);
430 if (fComponents.size() > 1) {
431 for (size_t i = 0; i < fComponents.size(); ++i) {
432 LLVMValueRef element = LLVMBuildExtractElement(builder, value,
433 LLVMConstInt(
434 fJIT.fInt32Type,
435 i,
436 false),
437 "swizzle extract");
438 result = LLVMBuildInsertElement(builder, result, element,
439 LLVMConstInt(fJIT.fInt32Type,
440 fComponents[i],
441 false),
442 "swizzle insert");
443 }
444 } else {
445 result = LLVMBuildInsertElement(builder, result, value,
446 LLVMConstInt(fJIT.fInt32Type,
447 fComponents[0],
448 false),
449 "swizzle insert");
450 }
451 fBase->store(builder, result);
452 }
453
454 private:
455 JIT& fJIT;
456 LLVMTypeRef fType;
457 std::unique_ptr<LValue> fBase;
458 std::vector<int> fComponents;
459 };
460 const Swizzle& s = (const Swizzle&) expr;
461 return std::unique_ptr<LValue>(new SwizzleLValue(this, this->getType(s.fType),
462 this->getLValue(builder, *s.fBase),
463 s.fComponents));
464 }
465 default:
466 ABORT("unsupported lvalue");
467 }
468 }
469
typeKind(const Type & type)470 JIT::TypeKind JIT::typeKind(const Type& type) {
471 if (type.kind() == Type::kVector_Kind) {
472 return this->typeKind(type.componentType());
473 }
474 if (type.fName == "int" || type.fName == "short" || type.fName == "byte") {
475 return JIT::kInt_TypeKind;
476 } else if (type.fName == "uint" || type.fName == "ushort" || type.fName == "ubyte") {
477 return JIT::kUInt_TypeKind;
478 } else if (type.fName == "float" || type.fName == "double" || type.fName == "half") {
479 return JIT::kFloat_TypeKind;
480 }
481 ABORT("unsupported type: %s\n", type.description().c_str());
482 }
483
vectorize(LLVMBuilderRef builder,LLVMValueRef * value,int columns)484 void JIT::vectorize(LLVMBuilderRef builder, LLVMValueRef* value, int columns) {
485 LLVMValueRef result = LLVMGetUndef(LLVMVectorType(LLVMTypeOf(*value), columns));
486 for (int i = 0; i < columns; ++i) {
487 result = LLVMBuildInsertElement(builder,
488 result,
489 *value,
490 LLVMConstInt(fInt32Type, i, false),
491 "vectorize");
492 }
493 *value = result;
494 }
495
vectorize(LLVMBuilderRef builder,const BinaryExpression & b,LLVMValueRef * left,LLVMValueRef * right)496 void JIT::vectorize(LLVMBuilderRef builder, const BinaryExpression& b, LLVMValueRef* left,
497 LLVMValueRef* right) {
498 if (b.fLeft->fType.kind() == Type::kScalar_Kind &&
499 b.fRight->fType.kind() == Type::kVector_Kind) {
500 this->vectorize(builder, left, b.fRight->fType.columns());
501 } else if (b.fLeft->fType.kind() == Type::kVector_Kind &&
502 b.fRight->fType.kind() == Type::kScalar_Kind) {
503 this->vectorize(builder, right, b.fLeft->fType.columns());
504 }
505 }
506
507
compileBinary(LLVMBuilderRef builder,const BinaryExpression & b)508 LLVMValueRef JIT::compileBinary(LLVMBuilderRef builder, const BinaryExpression& b) {
509 #define BINARY(SFunc, UFunc, FFunc) { \
510 LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
511 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
512 this->vectorize(builder, b, &left, &right); \
513 switch (this->typeKind(b.fLeft->fType)) { \
514 case kInt_TypeKind: \
515 return SFunc(builder, left, right, "binary"); \
516 case kUInt_TypeKind: \
517 return UFunc(builder, left, right, "binary"); \
518 case kFloat_TypeKind: \
519 return FFunc(builder, left, right, "binary"); \
520 default: \
521 ABORT("unsupported typeKind"); \
522 } \
523 }
524 #define COMPOUND(SFunc, UFunc, FFunc) { \
525 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft); \
526 LLVMValueRef left = lvalue->load(builder); \
527 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
528 this->vectorize(builder, b, &left, &right); \
529 LLVMValueRef result; \
530 switch (this->typeKind(b.fLeft->fType)) { \
531 case kInt_TypeKind: \
532 result = SFunc(builder, left, right, "binary"); \
533 break; \
534 case kUInt_TypeKind: \
535 result = UFunc(builder, left, right, "binary"); \
536 break; \
537 case kFloat_TypeKind: \
538 result = FFunc(builder, left, right, "binary"); \
539 break; \
540 default: \
541 ABORT("unsupported typeKind"); \
542 } \
543 lvalue->store(builder, result); \
544 return result; \
545 }
546 #define COMPARE(SFunc, SOp, UFunc, UOp, FFunc, FOp) { \
547 LLVMValueRef left = this->compileExpression(builder, *b.fLeft); \
548 LLVMValueRef right = this->compileExpression(builder, *b.fRight); \
549 this->vectorize(builder, b, &left, &right); \
550 switch (this->typeKind(b.fLeft->fType)) { \
551 case kInt_TypeKind: \
552 return SFunc(builder, SOp, left, right, "binary"); \
553 case kUInt_TypeKind: \
554 return UFunc(builder, UOp, left, right, "binary"); \
555 case kFloat_TypeKind: \
556 return FFunc(builder, FOp, left, right, "binary"); \
557 default: \
558 ABORT("unsupported typeKind"); \
559 } \
560 }
561 switch (b.fOperator) {
562 case Token::EQ: {
563 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *b.fLeft);
564 LLVMValueRef result = this->compileExpression(builder, *b.fRight);
565 lvalue->store(builder, result);
566 return result;
567 }
568 case Token::PLUS:
569 BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
570 case Token::MINUS:
571 BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
572 case Token::STAR:
573 BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
574 case Token::SLASH:
575 BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
576 case Token::PERCENT:
577 BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
578 case Token::BITWISEAND:
579 BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
580 case Token::BITWISEOR:
581 BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
582 case Token::SHL:
583 BINARY(LLVMBuildShl, LLVMBuildShl, LLVMBuildShl);
584 case Token::SHR:
585 BINARY(LLVMBuildAShr, LLVMBuildLShr, LLVMBuildAShr);
586 case Token::PLUSEQ:
587 COMPOUND(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
588 case Token::MINUSEQ:
589 COMPOUND(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
590 case Token::STAREQ:
591 COMPOUND(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
592 case Token::SLASHEQ:
593 COMPOUND(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
594 case Token::BITWISEANDEQ:
595 COMPOUND(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
596 case Token::BITWISEOREQ:
597 COMPOUND(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
598 case Token::EQEQ:
599 switch (b.fLeft->fType.kind()) {
600 case Type::kScalar_Kind:
601 COMPARE(LLVMBuildICmp, LLVMIntEQ,
602 LLVMBuildICmp, LLVMIntEQ,
603 LLVMBuildFCmp, LLVMRealOEQ);
604 case Type::kVector_Kind: {
605 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
606 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
607 this->vectorize(builder, b, &left, &right);
608 LLVMValueRef value;
609 switch (this->typeKind(b.fLeft->fType)) {
610 case kInt_TypeKind:
611 value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
612 break;
613 case kUInt_TypeKind:
614 value = LLVMBuildICmp(builder, LLVMIntEQ, left, right, "binary");
615 break;
616 case kFloat_TypeKind:
617 value = LLVMBuildFCmp(builder, LLVMRealOEQ, left, right, "binary");
618 break;
619 default:
620 ABORT("unsupported typeKind");
621 }
622 LLVMValueRef args[1] = { value };
623 LLVMValueRef func;
624 switch (b.fLeft->fType.columns()) {
625 case 2: func = fFoldAnd2Func; break;
626 case 3: func = fFoldAnd3Func; break;
627 case 4: func = fFoldAnd4Func; break;
628 default:
629 SkASSERT(false);
630 func = fFoldAnd2Func;
631 }
632 return LLVMBuildCall(builder, func, args, 1, "all");
633 }
634 default:
635 SkASSERT(false);
636 }
637 case Token::NEQ:
638 switch (b.fLeft->fType.kind()) {
639 case Type::kScalar_Kind:
640 COMPARE(LLVMBuildICmp, LLVMIntNE,
641 LLVMBuildICmp, LLVMIntNE,
642 LLVMBuildFCmp, LLVMRealONE);
643 case Type::kVector_Kind: {
644 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
645 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
646 this->vectorize(builder, b, &left, &right);
647 LLVMValueRef value;
648 switch (this->typeKind(b.fLeft->fType)) {
649 case kInt_TypeKind:
650 value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
651 break;
652 case kUInt_TypeKind:
653 value = LLVMBuildICmp(builder, LLVMIntNE, left, right, "binary");
654 break;
655 case kFloat_TypeKind:
656 value = LLVMBuildFCmp(builder, LLVMRealONE, left, right, "binary");
657 break;
658 default:
659 ABORT("unsupported typeKind");
660 }
661 LLVMValueRef args[1] = { value };
662 LLVMValueRef func;
663 switch (b.fLeft->fType.columns()) {
664 case 2: func = fFoldOr2Func; break;
665 case 3: func = fFoldOr3Func; break;
666 case 4: func = fFoldOr4Func; break;
667 default:
668 SkASSERT(false);
669 func = fFoldOr2Func;
670 }
671 return LLVMBuildCall(builder, func, args, 1, "all");
672 }
673 default:
674 SkASSERT(false);
675 }
676 case Token::LT:
677 COMPARE(LLVMBuildICmp, LLVMIntSLT,
678 LLVMBuildICmp, LLVMIntULT,
679 LLVMBuildFCmp, LLVMRealOLT);
680 case Token::LTEQ:
681 COMPARE(LLVMBuildICmp, LLVMIntSLE,
682 LLVMBuildICmp, LLVMIntULE,
683 LLVMBuildFCmp, LLVMRealOLE);
684 case Token::GT:
685 COMPARE(LLVMBuildICmp, LLVMIntSGT,
686 LLVMBuildICmp, LLVMIntUGT,
687 LLVMBuildFCmp, LLVMRealOGT);
688 case Token::GTEQ:
689 COMPARE(LLVMBuildICmp, LLVMIntSGE,
690 LLVMBuildICmp, LLVMIntUGE,
691 LLVMBuildFCmp, LLVMRealOGE);
692 case Token::LOGICALAND: {
693 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
694 LLVMBasicBlockRef ifFalse = fCurrentBlock;
695 LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
696 "true && ...");
697 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
698 "&& merge");
699 LLVMBuildCondBr(builder, left, ifTrue, merge);
700 this->setBlock(builder, ifTrue);
701 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
702 LLVMBuildBr(builder, merge);
703 this->setBlock(builder, merge);
704 LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "&&");
705 LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 0, false) };
706 LLVMBasicBlockRef incomingBlocks[2] = { ifTrue, ifFalse };
707 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
708 return phi;
709 }
710 case Token::LOGICALOR: {
711 LLVMValueRef left = this->compileExpression(builder, *b.fLeft);
712 LLVMBasicBlockRef ifTrue = fCurrentBlock;
713 LLVMBasicBlockRef ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
714 "false || ...");
715 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
716 "|| merge");
717 LLVMBuildCondBr(builder, left, merge, ifFalse);
718 this->setBlock(builder, ifFalse);
719 LLVMValueRef right = this->compileExpression(builder, *b.fRight);
720 LLVMBuildBr(builder, merge);
721 this->setBlock(builder, merge);
722 LLVMValueRef phi = LLVMBuildPhi(builder, fInt1Type, "||");
723 LLVMValueRef incomingValues[2] = { right, LLVMConstInt(fInt1Type, 1, false) };
724 LLVMBasicBlockRef incomingBlocks[2] = { ifFalse, ifTrue };
725 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
726 return phi;
727 }
728 default:
729 printf("%s\n", b.description().c_str());
730 ABORT("unsupported binary operator");
731 }
732 }
733
compileIndex(LLVMBuilderRef builder,const IndexExpression & idx)734 LLVMValueRef JIT::compileIndex(LLVMBuilderRef builder, const IndexExpression& idx) {
735 LLVMValueRef base = this->compileExpression(builder, *idx.fBase);
736 LLVMValueRef index = this->compileExpression(builder, *idx.fIndex);
737 LLVMValueRef ptr = LLVMBuildGEP(builder, base, &index, 1, "index ptr");
738 return LLVMBuildLoad(builder, ptr, "index load");
739 }
740
compilePostfix(LLVMBuilderRef builder,const PostfixExpression & p)741 LLVMValueRef JIT::compilePostfix(LLVMBuilderRef builder, const PostfixExpression& p) {
742 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
743 LLVMValueRef result = lvalue->load(builder);
744 LLVMValueRef mod;
745 LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
746 switch (p.fOperator) {
747 case Token::PLUSPLUS:
748 switch (this->typeKind(p.fType)) {
749 case kInt_TypeKind: // fall through
750 case kUInt_TypeKind:
751 mod = LLVMBuildAdd(builder, result, one, "++");
752 break;
753 case kFloat_TypeKind:
754 mod = LLVMBuildFAdd(builder, result, one, "++");
755 break;
756 default:
757 ABORT("unsupported typeKind");
758 }
759 break;
760 case Token::MINUSMINUS:
761 switch (this->typeKind(p.fType)) {
762 case kInt_TypeKind: // fall through
763 case kUInt_TypeKind:
764 mod = LLVMBuildSub(builder, result, one, "--");
765 break;
766 case kFloat_TypeKind:
767 mod = LLVMBuildFSub(builder, result, one, "--");
768 break;
769 default:
770 ABORT("unsupported typeKind");
771 }
772 break;
773 default:
774 ABORT("unsupported postfix op");
775 }
776 lvalue->store(builder, mod);
777 return result;
778 }
779
compilePrefix(LLVMBuilderRef builder,const PrefixExpression & p)780 LLVMValueRef JIT::compilePrefix(LLVMBuilderRef builder, const PrefixExpression& p) {
781 LLVMValueRef one = LLVMConstInt(this->getType(p.fType), 1, false);
782 if (Token::LOGICALNOT == p.fOperator) {
783 LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
784 return LLVMBuildXor(builder, base, one, "!");
785 }
786 if (Token::MINUS == p.fOperator) {
787 LLVMValueRef base = this->compileExpression(builder, *p.fOperand);
788 return LLVMBuildSub(builder, LLVMConstInt(this->getType(p.fType), 0, false), base, "-");
789 }
790 std::unique_ptr<LValue> lvalue = this->getLValue(builder, *p.fOperand);
791 LLVMValueRef raw = lvalue->load(builder);
792 LLVMValueRef result;
793 switch (p.fOperator) {
794 case Token::PLUSPLUS:
795 switch (this->typeKind(p.fType)) {
796 case kInt_TypeKind: // fall through
797 case kUInt_TypeKind:
798 result = LLVMBuildAdd(builder, raw, one, "++");
799 break;
800 case kFloat_TypeKind:
801 result = LLVMBuildFAdd(builder, raw, one, "++");
802 break;
803 default:
804 ABORT("unsupported typeKind");
805 }
806 break;
807 case Token::MINUSMINUS:
808 switch (this->typeKind(p.fType)) {
809 case kInt_TypeKind: // fall through
810 case kUInt_TypeKind:
811 result = LLVMBuildSub(builder, raw, one, "--");
812 break;
813 case kFloat_TypeKind:
814 result = LLVMBuildFSub(builder, raw, one, "--");
815 break;
816 default:
817 ABORT("unsupported typeKind");
818 }
819 break;
820 default:
821 ABORT("unsupported prefix op");
822 }
823 lvalue->store(builder, result);
824 return result;
825 }
826
compileVariableReference(LLVMBuilderRef builder,const VariableReference & v)827 LLVMValueRef JIT::compileVariableReference(LLVMBuilderRef builder, const VariableReference& v) {
828 const Variable& var = v.fVariable;
829 if (Variable::kParameter_Storage == var.fStorage &&
830 !(var.fModifiers.fFlags & Modifiers::kOut_Flag) &&
831 fPromotedParameters.find(&var) == fPromotedParameters.end()) {
832 return fVariables[&var];
833 }
834 return LLVMBuildLoad(builder, fVariables[&var], String(var.fName).c_str());
835 }
836
appendStage(LLVMBuilderRef builder,const AppendStage & a)837 void JIT::appendStage(LLVMBuilderRef builder, const AppendStage& a) {
838 SkASSERT(a.fArguments.size() >= 1);
839 SkASSERT(a.fArguments[0]->fType == *fCompiler.context().fSkRasterPipeline_Type);
840 LLVMValueRef pipeline = this->compileExpression(builder, *a.fArguments[0]);
841 LLVMValueRef stage = LLVMConstInt(fInt32Type, a.fStage, 0);
842 switch (a.fStage) {
843 case SkRasterPipeline::callback: {
844 SkASSERT(a.fArguments.size() == 2);
845 SkASSERT(a.fArguments[1]->fKind == Expression::kFunctionReference_Kind);
846 const FunctionDeclaration& functionDecl =
847 *((FunctionReference&) *a.fArguments[1]).fFunctions[0];
848 bool found = false;
849 for (const auto& pe : *fProgram) {
850 if (ProgramElement::kFunction_Kind == pe.fKind) {
851 const FunctionDefinition& def = (const FunctionDefinition&) pe;
852 if (&def.fDeclaration == &functionDecl) {
853 LLVMValueRef fn = this->compileStageFunction(def);
854 LLVMValueRef args[2] = {
855 pipeline,
856 LLVMBuildBitCast(builder, fn, fInt8PtrType, "callback cast")
857 };
858 LLVMBuildCall(builder, fAppendCallbackFunc, args, 2, "");
859 found = true;
860 break;
861 }
862 }
863 }
864 SkASSERT(found);
865 break;
866 }
867 default: {
868 LLVMValueRef ctx;
869 if (a.fArguments.size() == 2) {
870 ctx = this->compileExpression(builder, *a.fArguments[1]);
871 ctx = LLVMBuildBitCast(builder, ctx, fInt8PtrType, "context cast");
872 } else {
873 SkASSERT(a.fArguments.size() == 1);
874 ctx = LLVMConstNull(fInt8PtrType);
875 }
876 LLVMValueRef args[3] = {
877 pipeline,
878 stage,
879 ctx
880 };
881 LLVMBuildCall(builder, fAppendFunc, args, 3, "");
882 break;
883 }
884 }
885 }
886
compileConstructor(LLVMBuilderRef builder,const Constructor & c)887 LLVMValueRef JIT::compileConstructor(LLVMBuilderRef builder, const Constructor& c) {
888 switch (c.fType.kind()) {
889 case Type::kScalar_Kind: {
890 SkASSERT(c.fArguments.size() == 1);
891 TypeKind from = this->typeKind(c.fArguments[0]->fType);
892 TypeKind to = this->typeKind(c.fType);
893 LLVMValueRef base = this->compileExpression(builder, *c.fArguments[0]);
894 switch (to) {
895 case kFloat_TypeKind:
896 switch (from) {
897 case kInt_TypeKind:
898 return LLVMBuildSIToFP(builder, base, this->getType(c.fType), "cast");
899 case kUInt_TypeKind:
900 return LLVMBuildUIToFP(builder, base, this->getType(c.fType), "cast");
901 case kFloat_TypeKind:
902 return base;
903 case kBool_TypeKind:
904 SkASSERT(false);
905 }
906 case kInt_TypeKind:
907 switch (from) {
908 case kInt_TypeKind:
909 return base;
910 case kUInt_TypeKind:
911 return base;
912 case kFloat_TypeKind:
913 return LLVMBuildFPToSI(builder, base, this->getType(c.fType), "cast");
914 case kBool_TypeKind:
915 SkASSERT(false);
916 }
917 case kUInt_TypeKind:
918 switch (from) {
919 case kInt_TypeKind:
920 return base;
921 case kUInt_TypeKind:
922 return base;
923 case kFloat_TypeKind:
924 return LLVMBuildFPToUI(builder, base, this->getType(c.fType), "cast");
925 case kBool_TypeKind:
926 SkASSERT(false);
927 }
928 case kBool_TypeKind:
929 SkASSERT(false);
930 }
931 }
932 case Type::kVector_Kind: {
933 LLVMValueRef vec = LLVMGetUndef(this->getType(c.fType));
934 if (c.fArguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
935 LLVMValueRef value = this->compileExpression(builder, *c.fArguments[0]);
936 for (int i = 0; i < c.fType.columns(); ++i) {
937 vec = LLVMBuildInsertElement(builder, vec, value,
938 LLVMConstInt(fInt32Type, i, false),
939 "vec build 1");
940 }
941 } else {
942 int index = 0;
943 for (const auto& arg : c.fArguments) {
944 LLVMValueRef value = this->compileExpression(builder, *arg);
945 if (arg->fType.kind() == Type::kVector_Kind) {
946 for (int i = 0; i < arg->fType.columns(); ++i) {
947 LLVMValueRef column = LLVMBuildExtractElement(builder,
948 vec,
949 LLVMConstInt(fInt32Type,
950 i,
951 false),
952 "construct extract");
953 vec = LLVMBuildInsertElement(builder, vec, column,
954 LLVMConstInt(fInt32Type, index++, false),
955 "vec build 2");
956 }
957 } else {
958 vec = LLVMBuildInsertElement(builder, vec, value,
959 LLVMConstInt(fInt32Type, index++, false),
960 "vec build 3");
961 }
962 }
963 }
964 return vec;
965 }
966 default:
967 break;
968 }
969 ABORT("unsupported constructor");
970 }
971
compileSwizzle(LLVMBuilderRef builder,const Swizzle & s)972 LLVMValueRef JIT::compileSwizzle(LLVMBuilderRef builder, const Swizzle& s) {
973 LLVMValueRef base = this->compileExpression(builder, *s.fBase);
974 if (s.fComponents.size() > 1) {
975 LLVMValueRef result = LLVMGetUndef(this->getType(s.fType));
976 for (size_t i = 0; i < s.fComponents.size(); ++i) {
977 LLVMValueRef element = LLVMBuildExtractElement(
978 builder,
979 base,
980 LLVMConstInt(fInt32Type,
981 s.fComponents[i],
982 false),
983 "swizzle extract");
984 result = LLVMBuildInsertElement(builder, result, element,
985 LLVMConstInt(fInt32Type, i, false),
986 "swizzle insert");
987 }
988 return result;
989 }
990 SkASSERT(s.fComponents.size() == 1);
991 return LLVMBuildExtractElement(builder, base,
992 LLVMConstInt(fInt32Type,
993 s.fComponents[0],
994 false),
995 "swizzle extract");
996 }
997
compileTernary(LLVMBuilderRef builder,const TernaryExpression & t)998 LLVMValueRef JIT::compileTernary(LLVMBuilderRef builder, const TernaryExpression& t) {
999 LLVMValueRef test = this->compileExpression(builder, *t.fTest);
1000 LLVMBasicBlockRef trueBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1001 "if true");
1002 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1003 "if merge");
1004 LLVMBasicBlockRef falseBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1005 "if false");
1006 LLVMBuildCondBr(builder, test, trueBlock, falseBlock);
1007 this->setBlock(builder, trueBlock);
1008 LLVMValueRef ifTrue = this->compileExpression(builder, *t.fIfTrue);
1009 trueBlock = fCurrentBlock;
1010 LLVMBuildBr(builder, merge);
1011 this->setBlock(builder, falseBlock);
1012 LLVMValueRef ifFalse = this->compileExpression(builder, *t.fIfFalse);
1013 falseBlock = fCurrentBlock;
1014 LLVMBuildBr(builder, merge);
1015 this->setBlock(builder, merge);
1016 LLVMValueRef phi = LLVMBuildPhi(builder, this->getType(t.fType), "?");
1017 LLVMValueRef incomingValues[2] = { ifTrue, ifFalse };
1018 LLVMBasicBlockRef incomingBlocks[2] = { trueBlock, falseBlock };
1019 LLVMAddIncoming(phi, incomingValues, incomingBlocks, 2);
1020 return phi;
1021 }
1022
compileExpression(LLVMBuilderRef builder,const Expression & expr)1023 LLVMValueRef JIT::compileExpression(LLVMBuilderRef builder, const Expression& expr) {
1024 switch (expr.fKind) {
1025 case Expression::kAppendStage_Kind: {
1026 this->appendStage(builder, (const AppendStage&) expr);
1027 return LLVMValueRef();
1028 }
1029 case Expression::kBinary_Kind:
1030 return this->compileBinary(builder, (BinaryExpression&) expr);
1031 case Expression::kBoolLiteral_Kind:
1032 return LLVMConstInt(fInt1Type, ((BoolLiteral&) expr).fValue, false);
1033 case Expression::kConstructor_Kind:
1034 return this->compileConstructor(builder, (Constructor&) expr);
1035 case Expression::kIntLiteral_Kind:
1036 return LLVMConstInt(this->getType(expr.fType), ((IntLiteral&) expr).fValue, true);
1037 case Expression::kFieldAccess_Kind:
1038 abort();
1039 case Expression::kFloatLiteral_Kind:
1040 return LLVMConstReal(this->getType(expr.fType), ((FloatLiteral&) expr).fValue);
1041 case Expression::kFunctionCall_Kind:
1042 return this->compileFunctionCall(builder, (FunctionCall&) expr);
1043 case Expression::kIndex_Kind:
1044 return this->compileIndex(builder, (IndexExpression&) expr);
1045 case Expression::kPrefix_Kind:
1046 return this->compilePrefix(builder, (PrefixExpression&) expr);
1047 case Expression::kPostfix_Kind:
1048 return this->compilePostfix(builder, (PostfixExpression&) expr);
1049 case Expression::kSetting_Kind:
1050 abort();
1051 case Expression::kSwizzle_Kind:
1052 return this->compileSwizzle(builder, (Swizzle&) expr);
1053 case Expression::kVariableReference_Kind:
1054 return this->compileVariableReference(builder, (VariableReference&) expr);
1055 case Expression::kTernary_Kind:
1056 return this->compileTernary(builder, (TernaryExpression&) expr);
1057 case Expression::kTypeReference_Kind:
1058 abort();
1059 default:
1060 abort();
1061 }
1062 ABORT("unsupported expression: %s\n", expr.description().c_str());
1063 }
1064
compileBlock(LLVMBuilderRef builder,const Block & block)1065 void JIT::compileBlock(LLVMBuilderRef builder, const Block& block) {
1066 for (const auto& stmt : block.fStatements) {
1067 this->compileStatement(builder, *stmt);
1068 }
1069 }
1070
compileVarDeclarations(LLVMBuilderRef builder,const VarDeclarationsStatement & decls)1071 void JIT::compileVarDeclarations(LLVMBuilderRef builder, const VarDeclarationsStatement& decls) {
1072 for (const auto& declStatement : decls.fDeclaration->fVars) {
1073 const VarDeclaration& decl = (VarDeclaration&) *declStatement;
1074 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1075 LLVMValueRef alloca = LLVMBuildAlloca(builder, this->getType(decl.fVar->fType),
1076 String(decl.fVar->fName).c_str());
1077 fVariables[decl.fVar] = alloca;
1078 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
1079 if (decl.fValue) {
1080 LLVMValueRef result = this->compileExpression(builder, *decl.fValue);
1081 LLVMBuildStore(builder, result, alloca);
1082 }
1083 }
1084 }
1085
compileIf(LLVMBuilderRef builder,const IfStatement & i)1086 void JIT::compileIf(LLVMBuilderRef builder, const IfStatement& i) {
1087 LLVMValueRef test = this->compileExpression(builder, *i.fTest);
1088 LLVMBasicBlockRef ifTrue = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if true");
1089 LLVMBasicBlockRef merge = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1090 "if merge");
1091 LLVMBasicBlockRef ifFalse;
1092 if (i.fIfFalse) {
1093 ifFalse = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "if false");
1094 } else {
1095 ifFalse = merge;
1096 }
1097 LLVMBuildCondBr(builder, test, ifTrue, ifFalse);
1098 this->setBlock(builder, ifTrue);
1099 this->compileStatement(builder, *i.fIfTrue);
1100 if (!ends_with_branch(*i.fIfTrue)) {
1101 LLVMBuildBr(builder, merge);
1102 }
1103 if (i.fIfFalse) {
1104 this->setBlock(builder, ifFalse);
1105 this->compileStatement(builder, *i.fIfFalse);
1106 if (!ends_with_branch(*i.fIfFalse)) {
1107 LLVMBuildBr(builder, merge);
1108 }
1109 }
1110 this->setBlock(builder, merge);
1111 }
1112
compileFor(LLVMBuilderRef builder,const ForStatement & f)1113 void JIT::compileFor(LLVMBuilderRef builder, const ForStatement& f) {
1114 if (f.fInitializer) {
1115 this->compileStatement(builder, *f.fInitializer);
1116 }
1117 LLVMBasicBlockRef start;
1118 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for body");
1119 LLVMBasicBlockRef next = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for next");
1120 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for end");
1121 if (f.fTest) {
1122 start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "for test");
1123 LLVMBuildBr(builder, start);
1124 this->setBlock(builder, start);
1125 LLVMValueRef test = this->compileExpression(builder, *f.fTest);
1126 LLVMBuildCondBr(builder, test, body, end);
1127 } else {
1128 start = body;
1129 LLVMBuildBr(builder, body);
1130 }
1131 this->setBlock(builder, body);
1132 fBreakTarget.push_back(end);
1133 fContinueTarget.push_back(next);
1134 this->compileStatement(builder, *f.fStatement);
1135 fBreakTarget.pop_back();
1136 fContinueTarget.pop_back();
1137 if (!ends_with_branch(*f.fStatement)) {
1138 LLVMBuildBr(builder, next);
1139 }
1140 this->setBlock(builder, next);
1141 if (f.fNext) {
1142 this->compileExpression(builder, *f.fNext);
1143 }
1144 LLVMBuildBr(builder, start);
1145 this->setBlock(builder, end);
1146 }
1147
compileDo(LLVMBuilderRef builder,const DoStatement & d)1148 void JIT::compileDo(LLVMBuilderRef builder, const DoStatement& d) {
1149 LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1150 "do test");
1151 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1152 "do body");
1153 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1154 "do end");
1155 LLVMBuildBr(builder, body);
1156 this->setBlock(builder, testBlock);
1157 LLVMValueRef test = this->compileExpression(builder, *d.fTest);
1158 LLVMBuildCondBr(builder, test, body, end);
1159 this->setBlock(builder, body);
1160 fBreakTarget.push_back(end);
1161 fContinueTarget.push_back(body);
1162 this->compileStatement(builder, *d.fStatement);
1163 fBreakTarget.pop_back();
1164 fContinueTarget.pop_back();
1165 if (!ends_with_branch(*d.fStatement)) {
1166 LLVMBuildBr(builder, testBlock);
1167 }
1168 this->setBlock(builder, end);
1169 }
1170
compileWhile(LLVMBuilderRef builder,const WhileStatement & w)1171 void JIT::compileWhile(LLVMBuilderRef builder, const WhileStatement& w) {
1172 LLVMBasicBlockRef testBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1173 "while test");
1174 LLVMBasicBlockRef body = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1175 "while body");
1176 LLVMBasicBlockRef end = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction,
1177 "while end");
1178 LLVMBuildBr(builder, testBlock);
1179 this->setBlock(builder, testBlock);
1180 LLVMValueRef test = this->compileExpression(builder, *w.fTest);
1181 LLVMBuildCondBr(builder, test, body, end);
1182 this->setBlock(builder, body);
1183 fBreakTarget.push_back(end);
1184 fContinueTarget.push_back(testBlock);
1185 this->compileStatement(builder, *w.fStatement);
1186 fBreakTarget.pop_back();
1187 fContinueTarget.pop_back();
1188 if (!ends_with_branch(*w.fStatement)) {
1189 LLVMBuildBr(builder, testBlock);
1190 }
1191 this->setBlock(builder, end);
1192 }
1193
compileBreak(LLVMBuilderRef builder,const BreakStatement & b)1194 void JIT::compileBreak(LLVMBuilderRef builder, const BreakStatement& b) {
1195 LLVMBuildBr(builder, fBreakTarget.back());
1196 }
1197
compileContinue(LLVMBuilderRef builder,const ContinueStatement & b)1198 void JIT::compileContinue(LLVMBuilderRef builder, const ContinueStatement& b) {
1199 LLVMBuildBr(builder, fContinueTarget.back());
1200 }
1201
compileReturn(LLVMBuilderRef builder,const ReturnStatement & r)1202 void JIT::compileReturn(LLVMBuilderRef builder, const ReturnStatement& r) {
1203 if (r.fExpression) {
1204 LLVMBuildRet(builder, this->compileExpression(builder, *r.fExpression));
1205 } else {
1206 LLVMBuildRetVoid(builder);
1207 }
1208 }
1209
compileStatement(LLVMBuilderRef builder,const Statement & stmt)1210 void JIT::compileStatement(LLVMBuilderRef builder, const Statement& stmt) {
1211 switch (stmt.fKind) {
1212 case Statement::kBlock_Kind:
1213 this->compileBlock(builder, (Block&) stmt);
1214 break;
1215 case Statement::kBreak_Kind:
1216 this->compileBreak(builder, (BreakStatement&) stmt);
1217 break;
1218 case Statement::kContinue_Kind:
1219 this->compileContinue(builder, (ContinueStatement&) stmt);
1220 break;
1221 case Statement::kDiscard_Kind:
1222 abort();
1223 case Statement::kDo_Kind:
1224 this->compileDo(builder, (DoStatement&) stmt);
1225 break;
1226 case Statement::kExpression_Kind:
1227 this->compileExpression(builder, *((ExpressionStatement&) stmt).fExpression);
1228 break;
1229 case Statement::kFor_Kind:
1230 this->compileFor(builder, (ForStatement&) stmt);
1231 break;
1232 case Statement::kGroup_Kind:
1233 abort();
1234 case Statement::kIf_Kind:
1235 this->compileIf(builder, (IfStatement&) stmt);
1236 break;
1237 case Statement::kNop_Kind:
1238 break;
1239 case Statement::kReturn_Kind:
1240 this->compileReturn(builder, (ReturnStatement&) stmt);
1241 break;
1242 case Statement::kSwitch_Kind:
1243 abort();
1244 case Statement::kVarDeclarations_Kind:
1245 this->compileVarDeclarations(builder, (VarDeclarationsStatement&) stmt);
1246 break;
1247 case Statement::kWhile_Kind:
1248 this->compileWhile(builder, (WhileStatement&) stmt);
1249 break;
1250 default:
1251 abort();
1252 }
1253 }
1254
compileStageFunctionLoop(const FunctionDefinition & f,LLVMValueRef newFunc)1255 void JIT::compileStageFunctionLoop(const FunctionDefinition& f, LLVMValueRef newFunc) {
1256 // loop over fVectorCount pixels, running the body of the stage function for each of them
1257 LLVMValueRef oldFunction = fCurrentFunction;
1258 fCurrentFunction = newFunc;
1259 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
1260 LLVMGetParams(fCurrentFunction, params.get());
1261 LLVMValueRef programParam = params.get()[1];
1262 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1263 LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
1264 LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
1265 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1266 this->setBlock(builder, fAllocaBlock);
1267 // temporaries to store the color channel vectors
1268 LLVMValueRef rVec = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
1269 LLVMBuildStore(builder, params.get()[4], rVec);
1270 LLVMValueRef gVec = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
1271 LLVMBuildStore(builder, params.get()[5], gVec);
1272 LLVMValueRef bVec = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
1273 LLVMBuildStore(builder, params.get()[6], bVec);
1274 LLVMValueRef aVec = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
1275 LLVMBuildStore(builder, params.get()[7], aVec);
1276 LLVMValueRef color = LLVMBuildAlloca(builder, fFloat32Vector4Type, "color");
1277 fVariables[f.fDeclaration.fParameters[1]] = LLVMBuildTrunc(builder, params.get()[3], fInt32Type,
1278 "y->Int32");
1279 fVariables[f.fDeclaration.fParameters[2]] = color;
1280 LLVMValueRef ivar = LLVMBuildAlloca(builder, fInt32Type, "i");
1281 LLVMBuildStore(builder, LLVMConstInt(fInt32Type, 0, false), ivar);
1282 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1283 this->setBlock(builder, start);
1284 LLVMValueRef iload = LLVMBuildLoad(builder, ivar, "load i");
1285 fVariables[f.fDeclaration.fParameters[0]] = LLVMBuildAdd(builder,
1286 LLVMBuildTrunc(builder,
1287 params.get()[2],
1288 fInt32Type,
1289 "x->Int32"),
1290 iload,
1291 "x");
1292 LLVMValueRef vectorSize = LLVMConstInt(fInt32Type, fVectorCount, false);
1293 LLVMValueRef test = LLVMBuildICmp(builder, LLVMIntSLT, iload, vectorSize, "i < vectorSize");
1294 LLVMBasicBlockRef loopBody = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "body");
1295 LLVMBasicBlockRef loopEnd = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "end");
1296 LLVMBuildCondBr(builder, test, loopBody, loopEnd);
1297 this->setBlock(builder, loopBody);
1298 LLVMValueRef vec = LLVMGetUndef(fFloat32Vector4Type);
1299 // extract the r, g, b, and a values from the color channel vectors and store them into "color"
1300 for (int i = 0; i < 4; ++i) {
1301 vec = LLVMBuildInsertElement(builder, vec,
1302 LLVMBuildExtractElement(builder,
1303 params.get()[4 + i],
1304 iload, "initial"),
1305 LLVMConstInt(fInt32Type, i, false),
1306 "vec build");
1307 }
1308 LLVMBuildStore(builder, vec, color);
1309 // write actual loop body
1310 this->compileStatement(builder, *f.fBody);
1311 // extract the r, g, b, and a values from "color" and stick them back into the color channel
1312 // vectors
1313 LLVMValueRef colorLoad = LLVMBuildLoad(builder, color, "color load");
1314 LLVMBuildStore(builder,
1315 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, rVec, "rVec"),
1316 LLVMBuildExtractElement(builder, colorLoad,
1317 LLVMConstInt(fInt32Type, 0,
1318 false),
1319 "rExtract"),
1320 iload, "rInsert"),
1321 rVec);
1322 LLVMBuildStore(builder,
1323 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, gVec, "gVec"),
1324 LLVMBuildExtractElement(builder, colorLoad,
1325 LLVMConstInt(fInt32Type, 1,
1326 false),
1327 "gExtract"),
1328 iload, "gInsert"),
1329 gVec);
1330 LLVMBuildStore(builder,
1331 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, bVec, "bVec"),
1332 LLVMBuildExtractElement(builder, colorLoad,
1333 LLVMConstInt(fInt32Type, 2,
1334 false),
1335 "bExtract"),
1336 iload, "bInsert"),
1337 bVec);
1338 LLVMBuildStore(builder,
1339 LLVMBuildInsertElement(builder, LLVMBuildLoad(builder, aVec, "aVec"),
1340 LLVMBuildExtractElement(builder, colorLoad,
1341 LLVMConstInt(fInt32Type, 3,
1342 false),
1343 "aExtract"),
1344 iload, "aInsert"),
1345 aVec);
1346 LLVMValueRef inc = LLVMBuildAdd(builder, iload, LLVMConstInt(fInt32Type, 1, false), "inc i");
1347 LLVMBuildStore(builder, inc, ivar);
1348 LLVMBuildBr(builder, start);
1349 this->setBlock(builder, loopEnd);
1350 // increment program pointer, call the next stage
1351 LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
1352 LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
1353 LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType, "cast next->func");
1354 LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
1355 LLVMBuildAdd(builder,
1356 LLVMBuildPtrToInt(builder,
1357 programParam,
1358 fInt64Type,
1359 "cast 1"),
1360 LLVMConstInt(fInt64Type, PTR_SIZE, false),
1361 "add"),
1362 LLVMPointerType(fInt8PtrType, 0), "cast 2");
1363 LLVMValueRef args[STAGE_PARAM_COUNT] = {
1364 params.get()[0],
1365 nextInc,
1366 params.get()[2],
1367 params.get()[3],
1368 LLVMBuildLoad(builder, rVec, "rVec"),
1369 LLVMBuildLoad(builder, gVec, "gVec"),
1370 LLVMBuildLoad(builder, bVec, "bVec"),
1371 LLVMBuildLoad(builder, aVec, "aVec"),
1372 params.get()[8],
1373 params.get()[9],
1374 params.get()[10],
1375 params.get()[11]
1376 };
1377 LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
1378 LLVMBuildRetVoid(builder);
1379 // finish
1380 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1381 LLVMBuildBr(builder, start);
1382 LLVMDisposeBuilder(builder);
1383 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1384 ABORT("verify failed\n");
1385 }
1386 fAllocaBlock = oldAllocaBlock;
1387 fCurrentBlock = oldCurrentBlock;
1388 fCurrentFunction = oldFunction;
1389 }
1390
// FIXME: consider pluggable code generators. The normal (scalar) codegen needs to be separated
// from the vector codegen, and this code should be broken up into multiple classes.
1394
getVectorLValue(LLVMBuilderRef builder,const Expression & e,LLVMValueRef out[CHANNELS])1395 bool JIT::getVectorLValue(LLVMBuilderRef builder, const Expression& e,
1396 LLVMValueRef out[CHANNELS]) {
1397 switch (e.fKind) {
1398 case Expression::kVariableReference_Kind:
1399 if (fColorParam == &((VariableReference&) e).fVariable) {
1400 memcpy(out, fChannels, sizeof(fChannels));
1401 return true;
1402 }
1403 return false;
1404 case Expression::kSwizzle_Kind: {
1405 const Swizzle& s = (const Swizzle&) e;
1406 LLVMValueRef base[CHANNELS];
1407 if (!this->getVectorLValue(builder, *s.fBase, base)) {
1408 return false;
1409 }
1410 for (size_t i = 0; i < s.fComponents.size(); ++i) {
1411 out[i] = base[s.fComponents[i]];
1412 }
1413 return true;
1414 }
1415 default:
1416 return false;
1417 }
1418 }
1419
getVectorBinaryOperands(LLVMBuilderRef builder,const Expression & left,LLVMValueRef outLeft[CHANNELS],const Expression & right,LLVMValueRef outRight[CHANNELS])1420 bool JIT::getVectorBinaryOperands(LLVMBuilderRef builder, const Expression& left,
1421 LLVMValueRef outLeft[CHANNELS], const Expression& right,
1422 LLVMValueRef outRight[CHANNELS]) {
1423 if (!this->compileVectorExpression(builder, left, outLeft)) {
1424 return false;
1425 }
1426 int leftColumns = left.fType.columns();
1427 int rightColumns = right.fType.columns();
1428 if (leftColumns == 1 && rightColumns > 1) {
1429 for (int i = 1; i < rightColumns; ++i) {
1430 outLeft[i] = outLeft[0];
1431 }
1432 }
1433 if (!this->compileVectorExpression(builder, right, outRight)) {
1434 return false;
1435 }
1436 if (rightColumns == 1 && leftColumns > 1) {
1437 for (int i = 1; i < leftColumns; ++i) {
1438 outRight[i] = outRight[0];
1439 }
1440 }
1441 return true;
1442 }
1443
compileVectorBinary(LLVMBuilderRef builder,const BinaryExpression & b,LLVMValueRef out[CHANNELS])1444 bool JIT::compileVectorBinary(LLVMBuilderRef builder, const BinaryExpression& b,
1445 LLVMValueRef out[CHANNELS]) {
1446 LLVMValueRef left[CHANNELS];
1447 LLVMValueRef right[CHANNELS];
1448 #define VECTOR_BINARY(signedOp, unsignedOp, floatOp) { \
1449 if (!this->getVectorBinaryOperands(builder, *b.fLeft, left, *b.fRight, right)) { \
1450 return false; \
1451 } \
1452 for (int i = 0; i < b.fLeft->fType.columns(); ++i) { \
1453 switch (this->typeKind(b.fLeft->fType)) { \
1454 case kInt_TypeKind: \
1455 out[i] = signedOp(builder, left[i], right[i], "binary"); \
1456 break; \
1457 case kUInt_TypeKind: \
1458 out[i] = unsignedOp(builder, left[i], right[i], "binary"); \
1459 break; \
1460 case kFloat_TypeKind: \
1461 out[i] = floatOp(builder, left[i], right[i], "binary"); \
1462 break; \
1463 case kBool_TypeKind: \
1464 SkASSERT(false); \
1465 break; \
1466 } \
1467 } \
1468 return true; \
1469 }
1470 switch (b.fOperator) {
1471 case Token::EQ: {
1472 if (!this->getVectorLValue(builder, *b.fLeft, left)) {
1473 return false;
1474 }
1475 if (!this->compileVectorExpression(builder, *b.fRight, right)) {
1476 return false;
1477 }
1478 int columns = b.fRight->fType.columns();
1479 for (int i = 0; i < columns; ++i) {
1480 LLVMBuildStore(builder, right[i], left[i]);
1481 }
1482 return true;
1483 }
1484 case Token::PLUS:
1485 VECTOR_BINARY(LLVMBuildAdd, LLVMBuildAdd, LLVMBuildFAdd);
1486 case Token::MINUS:
1487 VECTOR_BINARY(LLVMBuildSub, LLVMBuildSub, LLVMBuildFSub);
1488 case Token::STAR:
1489 VECTOR_BINARY(LLVMBuildMul, LLVMBuildMul, LLVMBuildFMul);
1490 case Token::SLASH:
1491 VECTOR_BINARY(LLVMBuildSDiv, LLVMBuildUDiv, LLVMBuildFDiv);
1492 case Token::PERCENT:
1493 VECTOR_BINARY(LLVMBuildSRem, LLVMBuildURem, LLVMBuildSRem);
1494 case Token::BITWISEAND:
1495 VECTOR_BINARY(LLVMBuildAnd, LLVMBuildAnd, LLVMBuildAnd);
1496 case Token::BITWISEOR:
1497 VECTOR_BINARY(LLVMBuildOr, LLVMBuildOr, LLVMBuildOr);
1498 default:
1499 printf("unsupported operator: %s\n", b.description().c_str());
1500 return false;
1501 }
1502 }
1503
compileVectorConstructor(LLVMBuilderRef builder,const Constructor & c,LLVMValueRef out[CHANNELS])1504 bool JIT::compileVectorConstructor(LLVMBuilderRef builder, const Constructor& c,
1505 LLVMValueRef out[CHANNELS]) {
1506 switch (c.fType.kind()) {
1507 case Type::kScalar_Kind: {
1508 SkASSERT(c.fArguments.size() == 1);
1509 TypeKind from = this->typeKind(c.fArguments[0]->fType);
1510 TypeKind to = this->typeKind(c.fType);
1511 LLVMValueRef base[CHANNELS];
1512 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1513 return false;
1514 }
1515 #define CONSTRUCT(fn) \
1516 out[0] = LLVMGetUndef(LLVMVectorType(this->getType(c.fType), fVectorCount)); \
1517 for (int i = 0; i < fVectorCount; ++i) { \
1518 LLVMValueRef index = LLVMConstInt(fInt32Type, i, false); \
1519 LLVMValueRef baseVal = LLVMBuildExtractElement(builder, base[0], index, \
1520 "construct extract"); \
1521 out[0] = LLVMBuildInsertElement(builder, out[0], \
1522 fn(builder, baseVal, this->getType(c.fType), \
1523 "cast"), \
1524 index, "construct insert"); \
1525 } \
1526 return true;
1527 if (kFloat_TypeKind == to) {
1528 if (kInt_TypeKind == from) {
1529 CONSTRUCT(LLVMBuildSIToFP);
1530 }
1531 if (kUInt_TypeKind == from) {
1532 CONSTRUCT(LLVMBuildUIToFP);
1533 }
1534 }
1535 if (kInt_TypeKind == to) {
1536 if (kFloat_TypeKind == from) {
1537 CONSTRUCT(LLVMBuildFPToSI);
1538 }
1539 if (kUInt_TypeKind == from) {
1540 return true;
1541 }
1542 }
1543 if (kUInt_TypeKind == to) {
1544 if (kFloat_TypeKind == from) {
1545 CONSTRUCT(LLVMBuildFPToUI);
1546 }
1547 if (kInt_TypeKind == from) {
1548 return base;
1549 }
1550 }
1551 printf("%s\n", c.description().c_str());
1552 ABORT("unsupported constructor");
1553 }
1554 case Type::kVector_Kind: {
1555 if (c.fArguments.size() == 1) {
1556 LLVMValueRef base[CHANNELS];
1557 if (!this->compileVectorExpression(builder, *c.fArguments[0], base)) {
1558 return false;
1559 }
1560 for (int i = 0; i < c.fType.columns(); ++i) {
1561 out[i] = base[0];
1562 }
1563 } else {
1564 SkASSERT(c.fArguments.size() == (size_t) c.fType.columns());
1565 for (int i = 0; i < c.fType.columns(); ++i) {
1566 LLVMValueRef base[CHANNELS];
1567 if (!this->compileVectorExpression(builder, *c.fArguments[i], base)) {
1568 return false;
1569 }
1570 out[i] = base[0];
1571 }
1572 }
1573 return true;
1574 }
1575 default:
1576 break;
1577 }
1578 ABORT("unsupported constructor");
1579 }
1580
compileVectorFloatLiteral(LLVMBuilderRef builder,const FloatLiteral & f,LLVMValueRef out[CHANNELS])1581 bool JIT::compileVectorFloatLiteral(LLVMBuilderRef builder,
1582 const FloatLiteral& f,
1583 LLVMValueRef out[CHANNELS]) {
1584 LLVMValueRef value = LLVMConstReal(this->getType(f.fType), f.fValue);
1585 LLVMValueRef values[MAX_VECTOR_COUNT];
1586 for (int i = 0; i < fVectorCount; ++i) {
1587 values[i] = value;
1588 }
1589 out[0] = LLVMConstVector(values, fVectorCount);
1590 return true;
1591 }
1592
1593
compileVectorSwizzle(LLVMBuilderRef builder,const Swizzle & s,LLVMValueRef out[CHANNELS])1594 bool JIT::compileVectorSwizzle(LLVMBuilderRef builder, const Swizzle& s,
1595 LLVMValueRef out[CHANNELS]) {
1596 LLVMValueRef all[CHANNELS];
1597 if (!this->compileVectorExpression(builder, *s.fBase, all)) {
1598 return false;
1599 }
1600 for (size_t i = 0; i < s.fComponents.size(); ++i) {
1601 out[i] = all[s.fComponents[i]];
1602 }
1603 return true;
1604 }
1605
compileVectorVariableReference(LLVMBuilderRef builder,const VariableReference & v,LLVMValueRef out[CHANNELS])1606 bool JIT::compileVectorVariableReference(LLVMBuilderRef builder, const VariableReference& v,
1607 LLVMValueRef out[CHANNELS]) {
1608 if (&v.fVariable == fColorParam) {
1609 for (int i = 0; i < CHANNELS; ++i) {
1610 out[i] = LLVMBuildLoad(builder, fChannels[i], "variable reference");
1611 }
1612 return true;
1613 }
1614 return false;
1615 }
1616
compileVectorExpression(LLVMBuilderRef builder,const Expression & expr,LLVMValueRef out[CHANNELS])1617 bool JIT::compileVectorExpression(LLVMBuilderRef builder, const Expression& expr,
1618 LLVMValueRef out[CHANNELS]) {
1619 switch (expr.fKind) {
1620 case Expression::kBinary_Kind:
1621 return this->compileVectorBinary(builder, (const BinaryExpression&) expr, out);
1622 case Expression::kConstructor_Kind:
1623 return this->compileVectorConstructor(builder, (const Constructor&) expr, out);
1624 case Expression::kFloatLiteral_Kind:
1625 return this->compileVectorFloatLiteral(builder, (const FloatLiteral&) expr, out);
1626 case Expression::kSwizzle_Kind:
1627 return this->compileVectorSwizzle(builder, (const Swizzle&) expr, out);
1628 case Expression::kVariableReference_Kind:
1629 return this->compileVectorVariableReference(builder, (const VariableReference&) expr,
1630 out);
1631 default:
1632 return false;
1633 }
1634 }
1635
compileVectorStatement(LLVMBuilderRef builder,const Statement & stmt)1636 bool JIT::compileVectorStatement(LLVMBuilderRef builder, const Statement& stmt) {
1637 switch (stmt.fKind) {
1638 case Statement::kBlock_Kind:
1639 for (const auto& s : ((const Block&) stmt).fStatements) {
1640 if (!this->compileVectorStatement(builder, *s)) {
1641 return false;
1642 }
1643 }
1644 return true;
1645 case Statement::kExpression_Kind:
1646 LLVMValueRef result;
1647 return this->compileVectorExpression(builder,
1648 *((const ExpressionStatement&) stmt).fExpression,
1649 &result);
1650 default:
1651 return false;
1652 }
1653 }
1654
compileStageFunctionVector(const FunctionDefinition & f,LLVMValueRef newFunc)1655 bool JIT::compileStageFunctionVector(const FunctionDefinition& f, LLVMValueRef newFunc) {
1656 LLVMValueRef oldFunction = fCurrentFunction;
1657 fCurrentFunction = newFunc;
1658 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[STAGE_PARAM_COUNT]);
1659 LLVMGetParams(fCurrentFunction, params.get());
1660 LLVMValueRef programParam = params.get()[1];
1661 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1662 LLVMBasicBlockRef oldAllocaBlock = fAllocaBlock;
1663 LLVMBasicBlockRef oldCurrentBlock = fCurrentBlock;
1664 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1665 this->setBlock(builder, fAllocaBlock);
1666 fChannels[0] = LLVMBuildAlloca(builder, fFloat32VectorType, "rVec");
1667 LLVMBuildStore(builder, params.get()[4], fChannels[0]);
1668 fChannels[1] = LLVMBuildAlloca(builder, fFloat32VectorType, "gVec");
1669 LLVMBuildStore(builder, params.get()[5], fChannels[1]);
1670 fChannels[2] = LLVMBuildAlloca(builder, fFloat32VectorType, "bVec");
1671 LLVMBuildStore(builder, params.get()[6], fChannels[2]);
1672 fChannels[3] = LLVMBuildAlloca(builder, fFloat32VectorType, "aVec");
1673 LLVMBuildStore(builder, params.get()[7], fChannels[3]);
1674 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1675 this->setBlock(builder, start);
1676 bool success = this->compileVectorStatement(builder, *f.fBody);
1677 if (success) {
1678 // increment program pointer, call next
1679 LLVMValueRef rawNextPtr = LLVMBuildLoad(builder, programParam, "next load");
1680 LLVMTypeRef stageFuncType = LLVMTypeOf(newFunc);
1681 LLVMValueRef nextPtr = LLVMBuildBitCast(builder, rawNextPtr, stageFuncType,
1682 "cast next->func");
1683 LLVMValueRef nextInc = LLVMBuildIntToPtr(builder,
1684 LLVMBuildAdd(builder,
1685 LLVMBuildPtrToInt(builder,
1686 programParam,
1687 fInt64Type,
1688 "cast 1"),
1689 LLVMConstInt(fInt64Type, PTR_SIZE,
1690 false),
1691 "add"),
1692 LLVMPointerType(fInt8PtrType, 0), "cast 2");
1693 LLVMValueRef args[STAGE_PARAM_COUNT] = {
1694 params.get()[0],
1695 nextInc,
1696 params.get()[2],
1697 params.get()[3],
1698 LLVMBuildLoad(builder, fChannels[0], "rVec"),
1699 LLVMBuildLoad(builder, fChannels[1], "gVec"),
1700 LLVMBuildLoad(builder, fChannels[2], "bVec"),
1701 LLVMBuildLoad(builder, fChannels[3], "aVec"),
1702 params.get()[8],
1703 params.get()[9],
1704 params.get()[10],
1705 params.get()[11]
1706 };
1707 LLVMBuildCall(builder, nextPtr, args, STAGE_PARAM_COUNT, "");
1708 LLVMBuildRetVoid(builder);
1709 // finish
1710 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1711 LLVMBuildBr(builder, start);
1712 LLVMDisposeBuilder(builder);
1713 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1714 ABORT("verify failed\n");
1715 }
1716 } else {
1717 LLVMDeleteBasicBlock(fAllocaBlock);
1718 LLVMDeleteBasicBlock(start);
1719 }
1720
1721 fAllocaBlock = oldAllocaBlock;
1722 fCurrentBlock = oldCurrentBlock;
1723 fCurrentFunction = oldFunction;
1724 return success;
1725 }
1726
compileStageFunction(const FunctionDefinition & f)1727 LLVMValueRef JIT::compileStageFunction(const FunctionDefinition& f) {
1728 LLVMTypeRef returnType = fVoidType;
1729 LLVMTypeRef parameterTypes[12] = { fSizeTType, LLVMPointerType(fInt8PtrType, 0), fSizeTType,
1730 fSizeTType, fFloat32VectorType, fFloat32VectorType,
1731 fFloat32VectorType, fFloat32VectorType, fFloat32VectorType,
1732 fFloat32VectorType, fFloat32VectorType, fFloat32VectorType };
1733 LLVMTypeRef stageFuncType = LLVMFunctionType(returnType, parameterTypes, 12, false);
1734 LLVMValueRef result = LLVMAddFunction(fModule,
1735 (String(f.fDeclaration.fName) + "$stage").c_str(),
1736 stageFuncType);
1737 fColorParam = f.fDeclaration.fParameters[2];
1738 if (!this->compileStageFunctionVector(f, result)) {
1739 // vectorization failed, fall back to looping over the pixels
1740 this->compileStageFunctionLoop(f, result);
1741 }
1742 return result;
1743 }
1744
hasStageSignature(const FunctionDeclaration & f)1745 bool JIT::hasStageSignature(const FunctionDeclaration& f) {
1746 return f.fReturnType == *fProgram->fContext->fVoid_Type &&
1747 f.fParameters.size() == 3 &&
1748 f.fParameters[0]->fType == *fProgram->fContext->fInt_Type &&
1749 f.fParameters[0]->fModifiers.fFlags == 0 &&
1750 f.fParameters[1]->fType == *fProgram->fContext->fInt_Type &&
1751 f.fParameters[1]->fModifiers.fFlags == 0 &&
1752 f.fParameters[2]->fType == *fProgram->fContext->fHalf4_Type &&
1753 f.fParameters[2]->fModifiers.fFlags == (Modifiers::kIn_Flag | Modifiers::kOut_Flag);
1754 }
1755
compileFunction(const FunctionDefinition & f)1756 LLVMValueRef JIT::compileFunction(const FunctionDefinition& f) {
1757 if (this->hasStageSignature(f.fDeclaration)) {
1758 this->compileStageFunction(f);
1759 // we compile foo$stage *in addition* to compiling foo, as we can't be sure that the intent
1760 // was to produce an SkJumper stage just because the signature matched or that the function
1761 // is not otherwise called. May need a better way to handle this.
1762 }
1763 LLVMTypeRef returnType = this->getType(f.fDeclaration.fReturnType);
1764 std::vector<LLVMTypeRef> parameterTypes;
1765 for (const auto& p : f.fDeclaration.fParameters) {
1766 LLVMTypeRef type = this->getType(p->fType);
1767 if (p->fModifiers.fFlags & Modifiers::kOut_Flag) {
1768 type = LLVMPointerType(type, 0);
1769 }
1770 parameterTypes.push_back(type);
1771 }
1772 fCurrentFunction = LLVMAddFunction(fModule,
1773 String(f.fDeclaration.fName).c_str(),
1774 LLVMFunctionType(returnType, parameterTypes.data(),
1775 parameterTypes.size(), false));
1776 fFunctions[&f.fDeclaration] = fCurrentFunction;
1777
1778 std::unique_ptr<LLVMValueRef[]> params(new LLVMValueRef[parameterTypes.size()]);
1779 LLVMGetParams(fCurrentFunction, params.get());
1780 for (size_t i = 0; i < f.fDeclaration.fParameters.size(); ++i) {
1781 fVariables[f.fDeclaration.fParameters[i]] = params.get()[i];
1782 }
1783 LLVMBuilderRef builder = LLVMCreateBuilderInContext(fContext);
1784 fAllocaBlock = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "alloca");
1785 LLVMBasicBlockRef start = LLVMAppendBasicBlockInContext(fContext, fCurrentFunction, "start");
1786 fCurrentBlock = start;
1787 LLVMPositionBuilderAtEnd(builder, fCurrentBlock);
1788 this->compileStatement(builder, *f.fBody);
1789 if (!ends_with_branch(*f.fBody)) {
1790 if (f.fDeclaration.fReturnType == *fProgram->fContext->fVoid_Type) {
1791 LLVMBuildRetVoid(builder);
1792 } else {
1793 LLVMBuildUnreachable(builder);
1794 }
1795 }
1796 LLVMPositionBuilderAtEnd(builder, fAllocaBlock);
1797 LLVMBuildBr(builder, start);
1798 LLVMDisposeBuilder(builder);
1799 if (LLVMVerifyFunction(fCurrentFunction, LLVMPrintMessageAction)) {
1800 ABORT("verify failed\n");
1801 }
1802 return fCurrentFunction;
1803 }
1804
createModule()1805 void JIT::createModule() {
1806 fPromotedParameters.clear();
1807 fModule = LLVMModuleCreateWithNameInContext("skslmodule", fContext);
1808 this->loadBuiltinFunctions();
1809 LLVMTypeRef fold2Params[1] = { fInt1Vector2Type };
1810 fFoldAnd2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v2i1",
1811 LLVMFunctionType(fInt1Type, fold2Params, 1, false));
1812 fFoldOr2Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v2i1",
1813 LLVMFunctionType(fInt1Type, fold2Params, 1, false));
1814 LLVMTypeRef fold3Params[1] = { fInt1Vector3Type };
1815 fFoldAnd3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v3i1",
1816 LLVMFunctionType(fInt1Type, fold3Params, 1, false));
1817 fFoldOr3Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v3i1",
1818 LLVMFunctionType(fInt1Type, fold3Params, 1, false));
1819 LLVMTypeRef fold4Params[1] = { fInt1Vector4Type };
1820 fFoldAnd4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.and.i1.v4i1",
1821 LLVMFunctionType(fInt1Type, fold4Params, 1, false));
1822 fFoldOr4Func = LLVMAddFunction(fModule, "llvm.experimental.vector.reduce.or.i1.v4i1",
1823 LLVMFunctionType(fInt1Type, fold4Params, 1, false));
1824 // LLVM doesn't do void*, have to declare it as int8*
1825 LLVMTypeRef appendParams[3] = { fInt8PtrType, fInt32Type, fInt8PtrType };
1826 fAppendFunc = LLVMAddFunction(fModule, "sksl_pipeline_append", LLVMFunctionType(fVoidType,
1827 appendParams,
1828 3,
1829 false));
1830 LLVMTypeRef appendCallbackParams[2] = { fInt8PtrType, fInt8PtrType };
1831 fAppendCallbackFunc = LLVMAddFunction(fModule, "sksl_pipeline_append_callback",
1832 LLVMFunctionType(fVoidType, appendCallbackParams, 2,
1833 false));
1834
1835 LLVMTypeRef debugParams[3] = { fFloat32Type };
1836 fDebugFunc = LLVMAddFunction(fModule, "sksl_debug_print", LLVMFunctionType(fVoidType,
1837 debugParams,
1838 1,
1839 false));
1840
1841 for (const auto& e : *fProgram) {
1842 if (e.fKind == ProgramElement::kFunction_Kind) {
1843 this->compileFunction((FunctionDefinition&) e);
1844 }
1845 }
1846 }
1847
compile(std::unique_ptr<Program> program)1848 std::unique_ptr<JIT::Module> JIT::compile(std::unique_ptr<Program> program) {
1849 fCompiler.optimize(*program);
1850 fProgram = std::move(program);
1851 this->createModule();
1852 this->optimize();
1853 return std::unique_ptr<Module>(new Module(std::move(fProgram), fSharedModule, fJITStack));
1854 }
1855
optimize()1856 void JIT::optimize() {
1857 LLVMPassManagerBuilderRef pmb = LLVMPassManagerBuilderCreate();
1858 LLVMPassManagerBuilderSetOptLevel(pmb, 3);
1859 LLVMPassManagerRef functionPM = LLVMCreateFunctionPassManagerForModule(fModule);
1860 LLVMPassManagerBuilderPopulateFunctionPassManager(pmb, functionPM);
1861 LLVMPassManagerRef modulePM = LLVMCreatePassManager();
1862 LLVMPassManagerBuilderPopulateModulePassManager(pmb, modulePM);
1863 LLVMInitializeFunctionPassManager(functionPM);
1864
1865 LLVMValueRef func = LLVMGetFirstFunction(fModule);
1866 for (;;) {
1867 if (!func) {
1868 break;
1869 }
1870 LLVMRunFunctionPassManager(functionPM, func);
1871 func = LLVMGetNextFunction(func);
1872 }
1873 LLVMRunPassManager(modulePM, fModule);
1874 LLVMDisposePassManager(functionPM);
1875 LLVMDisposePassManager(modulePM);
1876 LLVMPassManagerBuilderDispose(pmb);
1877
1878 std::string error_string;
1879 if (LLVMLoadLibraryPermanently(nullptr)) {
1880 ABORT("LLVMLoadLibraryPermanently failed");
1881 }
1882 char* defaultTriple = LLVMGetDefaultTargetTriple();
1883 char* error;
1884 LLVMTargetRef target;
1885 if (LLVMGetTargetFromTriple(defaultTriple, &target, &error)) {
1886 ABORT("LLVMGetTargetFromTriple failed");
1887 }
1888
1889 if (!LLVMTargetHasJIT(target)) {
1890 ABORT("!LLVMTargetHasJIT");
1891 }
1892 LLVMTargetMachineRef targetMachine = LLVMCreateTargetMachine(target,
1893 defaultTriple,
1894 fCPU,
1895 nullptr,
1896 LLVMCodeGenLevelDefault,
1897 LLVMRelocDefault,
1898 LLVMCodeModelJITDefault);
1899 LLVMDisposeMessage(defaultTriple);
1900 LLVMTargetDataRef dataLayout = LLVMCreateTargetDataLayout(targetMachine);
1901 LLVMSetModuleDataLayout(fModule, dataLayout);
1902 LLVMDisposeTargetData(dataLayout);
1903
1904 fJITStack = LLVMOrcCreateInstance(targetMachine);
1905 fSharedModule = LLVMOrcMakeSharedModule(fModule);
1906 LLVMOrcModuleHandle orcModule;
1907 LLVMOrcAddEagerlyCompiledIR(fJITStack, &orcModule, fSharedModule,
1908 (LLVMOrcSymbolResolverFn) resolveSymbol, this);
1909 LLVMDisposeTargetMachine(targetMachine);
1910 }
1911
getSymbol(const char * name)1912 void* JIT::Module::getSymbol(const char* name) {
1913 LLVMOrcTargetAddress result;
1914 if (LLVMOrcGetSymbolAddress(fJITStack, &result, name)) {
1915 ABORT("GetSymbolAddress error");
1916 }
1917 if (!result) {
1918 ABORT("symbol not found");
1919 }
1920 return (void*) result;
1921 }
1922
getJumperStage(const char * name)1923 void* JIT::Module::getJumperStage(const char* name) {
1924 return this->getSymbol((String(name) + "$stage").c_str());
1925 }
1926
1927 } // namespace
1928
1929 #endif // SK_LLVM_AVAILABLE
1930
1931 #endif // SKSL_STANDALONE
1932