1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This file implements logic for lowering HLO dialect ops to the Arithmetic,
// MemRef and Tensor dialects.
17 
18 #include <memory>
19 #include <utility>
20 
21 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
22 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
23 #include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
24 #include "mlir-hlo/Dialect/mhlo/transforms/rewriters.h"
25 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"
26 #include "mlir/Dialect/MemRef/IR/MemRef.h"
27 #include "mlir/Dialect/Tensor/IR/Tensor.h"
28 #include "mlir/IR/BuiltinDialect.h"
29 #include "mlir/Pass/Pass.h"
30 #include "mlir/Transforms/DialectConversion.h"
31 
32 namespace mlir {
33 namespace mhlo {
34 namespace {
35 
36 struct RngGetAndUpdateStatePattern
37     : public OpConversionPattern<mhlo::XlaRngGetAndUpdateStateOp> {
38   using OpConversionPattern<
39       mhlo::XlaRngGetAndUpdateStateOp>::OpConversionPattern;
40 
matchAndRewritemlir::mhlo::__anon98c622840111::RngGetAndUpdateStatePattern41   LogicalResult matchAndRewrite(
42       mhlo::XlaRngGetAndUpdateStateOp op,
43       XlaRngGetAndUpdateStateOpAdaptor adaptor,
44       ConversionPatternRewriter& rewriter) const final {
45     // Get various type related information
46     auto loc = op->getLoc();
47 
48     const auto globalName = rewriter.getStringAttr("rng_state");
49     constexpr auto initialSeed = 0x7012395ull;
50     auto seedType = rewriter.getIntegerType(128);
51     auto memrefType = MemRefType::get({}, seedType);
52 
53     auto resultType = op.getType();
54     auto wordSize = resultType.getElementType().getIntOrFloatBitWidth();
55     auto smallerIntType = rewriter.getIntegerType(wordSize);
56     auto numElements = resultType.getNumElements();
57 
58     // Get or define the global variable
59     auto* globalOp = mlir::SymbolTable::lookupNearestSymbolFrom(op, globalName);
60     if (!globalOp) {
61       auto* parent = mlir::SymbolTable::getNearestSymbolTable(op);
62       OpBuilder::InsertionGuard g(rewriter);
63       rewriter.setInsertionPointToStart(&parent->getRegions().front().front());
64 
65       const auto priv = rewriter.getStringAttr("private");
66       auto initialValue = mlir::DenseElementsAttr::get(
67           mlir::RankedTensorType::get({}, seedType),
68           rewriter.getIntegerAttr(seedType, initialSeed));
69       globalOp = rewriter.create<memref::GlobalOp>(
70           loc, globalName, priv, memrefType, initialValue, /*constant=*/false,
71           /*alignment=*/IntegerAttr());
72     }
73     assert(isa<memref::GlobalOp>(globalOp) &&
74            "rng_state was defined somewhere else, not as a global op");
75 
76     // Get and update
77     Value rngState =
78         rewriter.create<memref::GetGlobalOp>(loc, memrefType, globalName);
79     Value oldVal = rewriter.create<memref::LoadOp>(loc, rngState);
80     Value delta = rewriter.create<arith::ConstantOp>(
81         loc, rewriter.getIntegerAttr(seedType,
82                                      static_cast<int64_t>(adaptor.delta())));
83     Value newVal = rewriter.create<arith::AddIOp>(loc, oldVal, delta);
84     (void)rewriter.create<memref::StoreOp>(loc, newVal, rngState);
85 
86     // Create the proper return type by packing the old seed into a tensor
87     SmallVector<Value> pieces;
88     for (int i = (numElements - 1) * wordSize; i >= 0; i -= wordSize) {
89       Value shiftDistance = rewriter.create<arith::ConstantOp>(
90           loc, rewriter.getIntegerAttr(seedType, i));
91       pieces.push_back(rewriter.create<arith::TruncIOp>(
92           loc, smallerIntType,
93           rewriter.create<arith::ShRUIOp>(loc, oldVal, shiftDistance)));
94     }
95 
96     // Obtain a tensor with the correct shape and bit widths but the incorrect
97     // integer signedness, then cast the tensor to the correct signedness to
98     // ensure that unrealized casts will successfully lower later.
99     Value resultTensor = rewriter.create<tensor::FromElementsOp>(
100         loc, mlir::RankedTensorType::get(resultType.getShape(), smallerIntType),
101         pieces);
102     rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(op, resultType,
103                                                             resultTensor);
104     return success();
105   }
106 };
107 
108 struct HloLegalizeToArithmeticPass
109     : public HloLegalizeToArithmeticPassBase<HloLegalizeToArithmeticPass> {
getDependentDialectsmlir::mhlo::__anon98c622840111::HloLegalizeToArithmeticPass110   void getDependentDialects(DialectRegistry& registry) const override {
111     registry.insert<arith::ArithmeticDialect, memref::MemRefDialect,
112                     tensor::TensorDialect>();
113   }
114 
115  public:
runOnOperationmlir::mhlo::__anon98c622840111::HloLegalizeToArithmeticPass116   void runOnOperation() override {
117     auto& context = getContext();
118     RewritePatternSet patterns(&context);
119     ConversionTarget target(context);
120 
121     populateHloToArithmeticConversionPatterns(&patterns);
122 
123     target.addIllegalOp<XlaRngGetAndUpdateStateOp>();
124     target.addLegalDialect<arith::ArithmeticDialect, BuiltinDialect,
125                            memref::MemRefDialect, tensor::TensorDialect>();
126 
127     auto module = getOperation();
128     if (failed(applyPartialConversion(module, target, std::move(patterns))))
129       signalPassFailure();
130   }
131 };
132 
133 }  // namespace
134 
populateHloToArithmeticConversionPatterns(RewritePatternSet * patterns)135 void populateHloToArithmeticConversionPatterns(RewritePatternSet* patterns) {
136   patterns->add<RngGetAndUpdateStatePattern>(patterns->getContext());
137 }
138 
createLegalizeToArithmeticPass()139 std::unique_ptr<OperationPass<ModuleOp>> createLegalizeToArithmeticPass() {
140   return std::make_unique<HloLegalizeToArithmeticPass>();
141 }
142 
143 }  // namespace mhlo
144 }  // namespace mlir
145