// // Copyright 2020 The ANGLE Project Authors. All rights reserved. // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. // #include "compiler/translator/TranslatorMetalDirect/RewriteUnaddressableReferences.h" #include "compiler/translator/TranslatorMetalDirect/AstHelpers.h" #include "compiler/translator/tree_util/AsNode.h" #include "compiler/translator/tree_util/IntermRebuild.h" using namespace sh; namespace { bool IsOutParam(const TType ¶mType) { const TQualifier qual = paramType.getQualifier(); switch (qual) { case TQualifier::EvqParamInOut: case TQualifier::EvqParamOut: return true; default: return false; } } bool IsVectorAccess(TIntermBinary &binary) { TOperator op = binary.getOp(); switch (op) { case TOperator::EOpIndexDirect: case TOperator::EOpIndexIndirect: break; default: return false; } const TType &leftType = binary.getLeft()->getType(); if (!leftType.isVector() || leftType.isArray()) { return false; } ASSERT(IsScalarBasicType(binary.getType())); return true; } bool IsVectorAccess(TIntermNode &node) { if (auto *bin = node.getAsBinaryNode()) { return IsVectorAccess(*bin); } return false; } // Differs from IsAssignment in that it does not include (++) or (--). bool IsAssignEqualsSign(TOperator op) { switch (op) { case TOperator::EOpAssign: case TOperator::EOpInitialize: case TOperator::EOpAddAssign: case TOperator::EOpSubAssign: case TOperator::EOpMulAssign: case TOperator::EOpVectorTimesMatrixAssign: case TOperator::EOpVectorTimesScalarAssign: case TOperator::EOpMatrixTimesScalarAssign: case TOperator::EOpMatrixTimesMatrixAssign: case TOperator::EOpDivAssign: case TOperator::EOpIModAssign: case TOperator::EOpBitShiftLeftAssign: case TOperator::EOpBitShiftRightAssign: case TOperator::EOpBitwiseAndAssign: case TOperator::EOpBitwiseXorAssign: case TOperator::EOpBitwiseOrAssign: return true; default: return false; } } // Only includes ($=) style assigns, where ($) is a binary op. bool IsCompoundAssign(TOperator op) { switch (op) { case TOperator::EOpAddAssign: case TOperator::EOpSubAssign: case TOperator::EOpMulAssign: case TOperator::EOpVectorTimesMatrixAssign: case TOperator::EOpVectorTimesScalarAssign: case TOperator::EOpMatrixTimesScalarAssign: case TOperator::EOpMatrixTimesMatrixAssign: case TOperator::EOpDivAssign: case TOperator::EOpIModAssign: case TOperator::EOpBitShiftLeftAssign: case TOperator::EOpBitShiftRightAssign: case TOperator::EOpBitwiseAndAssign: case TOperator::EOpBitwiseXorAssign: case TOperator::EOpBitwiseOrAssign: return true; default: return false; } } bool ReturnsReference(TOperator op) { switch (op) { case TOperator::EOpAssign: case TOperator::EOpInitialize: case TOperator::EOpAddAssign: case TOperator::EOpSubAssign: case TOperator::EOpMulAssign: case TOperator::EOpVectorTimesMatrixAssign: case TOperator::EOpVectorTimesScalarAssign: case TOperator::EOpMatrixTimesScalarAssign: case TOperator::EOpMatrixTimesMatrixAssign: case TOperator::EOpDivAssign: case TOperator::EOpIModAssign: case TOperator::EOpBitShiftLeftAssign: case TOperator::EOpBitShiftRightAssign: case TOperator::EOpBitwiseAndAssign: case TOperator::EOpBitwiseXorAssign: case TOperator::EOpBitwiseOrAssign: case TOperator::EOpPostIncrement: case TOperator::EOpPostDecrement: case TOperator::EOpPreIncrement: case TOperator::EOpPreDecrement: case TOperator::EOpIndexDirect: case TOperator::EOpIndexIndirect: case TOperator::EOpIndexDirectStruct: case TOperator::EOpIndexDirectInterfaceBlock: return true; default: return false; } } TIntermTyped &DecomposeCompoundAssignment(TIntermBinary &node) { TOperator op = node.getOp(); switch (op) { case TOperator::EOpAddAssign: op = TOperator::EOpAdd; break; case TOperator::EOpSubAssign: op = TOperator::EOpSub; break; case TOperator::EOpMulAssign: op = TOperator::EOpMul; break; case TOperator::EOpVectorTimesMatrixAssign: op = TOperator::EOpVectorTimesMatrix; break; case TOperator::EOpVectorTimesScalarAssign: op = TOperator::EOpVectorTimesScalar; break; case TOperator::EOpMatrixTimesScalarAssign: op = TOperator::EOpMatrixTimesScalar; break; case TOperator::EOpMatrixTimesMatrixAssign: op = TOperator::EOpMatrixTimesMatrix; break; case TOperator::EOpDivAssign: op = TOperator::EOpDiv; break; case TOperator::EOpIModAssign: op = TOperator::EOpIMod; break; case TOperator::EOpBitShiftLeftAssign: op = TOperator::EOpBitShiftLeft; break; case TOperator::EOpBitShiftRightAssign: op = TOperator::EOpBitShiftRight; break; case TOperator::EOpBitwiseAndAssign: op = TOperator::EOpBitwiseAnd; break; case TOperator::EOpBitwiseXorAssign: op = TOperator::EOpBitwiseXor; break; case TOperator::EOpBitwiseOrAssign: op = TOperator::EOpBitwiseOr; break; default: UNREACHABLE(); } // This assumes SeparateCompoundExpressions has already been called. // This assumption allows this code to not need to introduce temporaries. // // e.g. dont have to worry about: // vec[hasSideEffect()] *= 4 // becoming // vec[hasSideEffect()] = vec[hasSideEffect()] * 4 TIntermTyped *left = node.getLeft(); TIntermTyped *right = node.getRight(); return *new TIntermBinary(TOperator::EOpAssign, left->deepCopy(), new TIntermBinary(op, left, right)); } class Rewriter1 : public TIntermRebuild { public: Rewriter1(TCompiler &compiler) : TIntermRebuild(compiler, false, true) {} PostResult visitBinaryPost(TIntermBinary &binaryNode) override { const TOperator op = binaryNode.getOp(); if (IsCompoundAssign(op)) { TIntermTyped &left = *binaryNode.getLeft(); if (left.getAsSwizzleNode() || IsVectorAccess(left)) { return DecomposeCompoundAssignment(binaryNode); } } return binaryNode; } }; class Rewriter2 : public TIntermRebuild { std::vector mRequiresAddressingStack; SymbolEnv &mSymbolEnv; private: bool requiresAddressing() const { if (mRequiresAddressingStack.empty()) { return false; } return mRequiresAddressingStack.back(); } public: ~Rewriter2() override { ASSERT(mRequiresAddressingStack.empty()); } Rewriter2(TCompiler &compiler, SymbolEnv &symbolEnv) : TIntermRebuild(compiler, true, true), mSymbolEnv(symbolEnv) {} PreResult visitAggregatePre(TIntermAggregate &aggregateNode) override { const TFunction *func = aggregateNode.getFunction(); if (!func) { return aggregateNode; } TIntermSequence &args = *aggregateNode.getSequence(); size_t argCount = args.size(); for (size_t i = 0; i < argCount; ++i) { const TVariable ¶m = *func->getParam(i); const TType ¶mType = param.getType(); TIntermNode *arg = args[i]; ASSERT(arg); mRequiresAddressingStack.push_back(IsOutParam(paramType)); args[i] = rebuild(*arg).single(); ASSERT(args[i]); ASSERT(!mRequiresAddressingStack.empty()); mRequiresAddressingStack.pop_back(); } return {aggregateNode, VisitBits::Neither}; } PostResult visitSwizzlePost(TIntermSwizzle &swizzleNode) override { if (!requiresAddressing()) { return swizzleNode; } TIntermTyped &vecNode = *swizzleNode.getOperand(); const TQualifierList &offsets = swizzleNode.getSwizzleOffsets(); ASSERT(!offsets.empty()); ASSERT(offsets.size() <= 4); auto &args = *new TIntermSequence(); args.reserve(offsets.size() + 1); args.push_back(&vecNode); for (int offset : offsets) { args.push_back(new TIntermConstantUnion(new TConstantUnion(offset), *new TType(TBasicType::EbtInt))); } return mSymbolEnv.callFunctionOverload(Name("swizzle_ref"), swizzleNode.getType(), args); } PreResult visitBinaryPre(TIntermBinary &binaryNode) override { const TOperator op = binaryNode.getOp(); const bool isAccess = IsVectorAccess(binaryNode); const bool disableTop = !ReturnsReference(op) || !requiresAddressing(); const bool disableLeft = disableTop; const bool disableRight = disableTop || isAccess || IsAssignEqualsSign(op); auto traverse = [&](TIntermTyped &node, const bool disable) -> TIntermTyped & { if (disable) { mRequiresAddressingStack.push_back(false); } auto *newNode = asNode(rebuild(node).single()); ASSERT(newNode); if (disable) { mRequiresAddressingStack.pop_back(); } return *newNode; }; TIntermTyped &leftNode = *binaryNode.getLeft(); TIntermTyped &rightNode = *binaryNode.getRight(); TIntermTyped &newLeft = traverse(leftNode, disableLeft); TIntermTyped &newRight = traverse(rightNode, disableRight); if (!isAccess || disableTop) { if (&leftNode == &newLeft && &rightNode == &newRight) { return {&binaryNode, VisitBits::Neither}; } return {*new TIntermBinary(op, &newLeft, &newRight), VisitBits::Neither}; } return {mSymbolEnv.callFunctionOverload(Name("elem_ref"), binaryNode.getType(), *new TIntermSequence{&newLeft, &newRight}), VisitBits::Neither}; } }; } // anonymous namespace bool sh::RewriteUnaddressableReferences(TCompiler &compiler, TIntermBlock &root, SymbolEnv &symbolEnv) { if (!Rewriter1(compiler).rebuildRoot(root)) { return false; } if (!Rewriter2(compiler, symbolEnv).rebuildRoot(root)) { return false; } return true; }