//===- ConvertSimQuant.cpp - Converts simulated quant ops------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "PassDetail.h" #include "mlir/Dialect/Quant/FakeQuantSupport.h" #include "mlir/Dialect/Quant/Passes.h" #include "mlir/Dialect/Quant/QuantOps.h" #include "mlir/Dialect/Quant/UniformSupport.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" using namespace mlir; using namespace mlir::quant; namespace { struct ConvertSimulatedQuantPass : public QuantConvertSimulatedQuantBase { void runOnFunction() override; }; /// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. template class FakeQuantRewrite : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) : OpRewritePattern(ctx), hadFailure(hadFailure) {} LogicalResult matchAndRewrite(FakeQuantOp op, PatternRewriter &rewriter) const override { // TODO: If this pattern comes up more frequently, consider adding core // support for failable rewrites. if (failableRewrite(op, rewriter)) { *hadFailure = true; return failure(); } return success(); } private: bool *hadFailure; bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); if (!converter) { return (op.emitError("unsupported quantized type conversion"), true); } QuantizedType elementType = static_cast(this) ->convertFakeQuantAttrsToType(op, converter.expressedType); if (!elementType) { // Note that the fakeQuantAttrsToType will have emitted the error. return true; } Type quantizedType = converter.convert(elementType); assert(quantizedType && "Converter accepted a type that it did not convert"); // TODO: Map to a qbarrier with an attribute like [Forced] to signal that // this is a forced/hard-coded constraint. auto qbarrier = rewriter.create(op.getLoc(), quantizedType, op.inputs()); rewriter.replaceOpWithNewOp(op, converter.inputType, qbarrier.getResult()); return false; } }; class ConstFakeQuantRewrite : public FakeQuantRewrite { public: using BaseRewrite = FakeQuantRewrite; ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) : BaseRewrite(ctx, hadFailure) {} QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, Type expressedType) const { return fakeQuantAttrsToType( fqOp.getLoc(), fqOp.num_bits(), fqOp.min().convertToFloat(), fqOp.max().convertToFloat(), fqOp.narrow_range(), expressedType, fqOp.is_signed()); } }; class ConstFakeQuantPerAxisRewrite : public FakeQuantRewrite { public: using BaseRewrite = FakeQuantRewrite; ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) : BaseRewrite(ctx, hadFailure) {} QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, Type expressedType) const { SmallVector min, max; min.reserve(fqOp.min().size()); max.reserve(fqOp.max().size()); for (auto m : fqOp.min()) min.push_back(m.cast().getValueAsDouble()); for (auto m : fqOp.max()) max.push_back(m.cast().getValueAsDouble()); return fakeQuantAttrsToType(fqOp.getLoc(), fqOp.num_bits(), fqOp.axis(), min, max, fqOp.narrow_range(), expressedType, fqOp.is_signed()); } }; } // namespace void ConvertSimulatedQuantPass::runOnFunction() { bool hadFailure = false; OwningRewritePatternList patterns; auto func = getFunction(); auto ctx = func.getContext(); patterns.insert( ctx, &hadFailure); applyPatternsAndFoldGreedily(func, std::move(patterns)); if (hadFailure) signalPassFailure(); } std::unique_ptr> mlir::quant::createConvertSimulatedQuantPass() { return std::make_unique(); }