1 //===- TosaMakeBroadcastable.cpp ------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Insert reshape to binary op's input if needed to match rank
10 //
11 //===----------------------------------------------------------------------===//
12
13 #include "mlir/Dialect/StandardOps/IR/Ops.h"
14 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
15 #include "mlir/Dialect/Tosa/Transforms/PassDetail.h"
16 #include "mlir/Dialect/Tosa/Transforms/Passes.h"
17 #include "mlir/Dialect/Tosa/Utils/QuantUtils.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
20
21 using namespace mlir;
22 using namespace mlir::tosa;
23
/// There are two potential ways of implementing broadcast:
25 /// a. https://www.tensorflow.org/xla/broadcasting#formal_definition
26 /// b. https://numpy.org/doc/stable/user/basics.broadcasting.html
27 /// TBD: picking option (a) now.
28
29 /// In this pass, we insert RESHAPE operators to increase the rank of the
30 /// lower rank operand as a first step in the broadcasting process. The TOSA
31 /// operators that support broadcast require that the rank of the operands
32 /// are equal.
33
34 // Examples:
35 // If lower=[a], target=[a, b, c], [a] reshaped into [a, 1, 1].
// TODO: If lower=[b], target=[a, b, c], [b] should be (but is NOT YET)
// reshaped into [1, b, 1].
38 // If lower=[c], target=[a, b, c], [c] reshaped into [1, 1, c].
39 // If lower=[a, c], target=[a, b, c], [a, c] reshaped into [a, 1, c].
40 // If lower=[a, b], target=[a, b, c], [a, b] reshaped into [a, b, 1].
41 // If lower=[b, c], target=[a, b, c], [b, c] reshaped into [1, b, c].
42 // If lower=[a], target=[a, a], [a] reshaped into [1, a] instead of [a, 1].
43 // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a].
44 // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1].
45
/// Computes the shape that the lower-rank operand must be reshaped to so it
/// lines up with the higher-rank operand (option (a) above).
///
/// \p reshapeOutputShape is assigned `higherRankShape.size()` entries, all
/// initially 1. Dimensions of \p lowerRankShape are matched greedily against
/// \p higherRankShape — from the rightmost dimension first, then from the
/// leftmost — and each matched dimension is copied into the corresponding
/// position of \p reshapeOutputShape; unmatched positions stay 1.
static void computeReshapeOutput(ArrayRef<int64_t> higherRankShape,
                                 ArrayRef<int64_t> lowerRankShape,
                                 SmallVectorImpl<int64_t> &reshapeOutputShape) {
  // Initialize new shapes with [1] * higherRank.
  int64_t higherRank = higherRankShape.size();
  int64_t lowerRank = lowerRankShape.size();

  reshapeOutputShape.assign(higherRank, 1);

  // [left, right) windows into each shape; successful matches shrink the
  // windows from the right first, then from the left.
  int64_t higherLeftIndex = 0;
  int64_t higherRightIndex = higherRank;
  int64_t lowerLeftIndex = 0;
  int64_t lowerRightIndex = lowerRank;
  int64_t higherRankDim, lowerRankDim;

  if (lowerRightIndex != 0 && higherRightIndex != 0) {
    // Matches lower rank shape from right dimension first, until not
    // matching high rank shape or reaching dimension 0.
    while (true) {
      higherRankDim = higherRankShape[higherRightIndex - 1];
      lowerRankDim = lowerRankShape[lowerRightIndex - 1];
      if (higherRankDim != lowerRankDim)
        break;

      // Matched: propagate the dimension into the output shape.
      reshapeOutputShape[higherRightIndex - 1] = higherRankDim;

      if (higherRightIndex > 0)
        higherRightIndex--;

      if (lowerRightIndex > 0)
        lowerRightIndex--;

      if (higherRightIndex == 0 || lowerRightIndex == 0)
        break;
    }
    if (lowerRightIndex != 0 && higherRightIndex != 0) {
      // Matches lower rank shape from left dimension, until not matching
      // high rank shape or reaching right index.
      while (true) {
        higherRankDim = higherRankShape[higherLeftIndex];
        lowerRankDim = lowerRankShape[lowerLeftIndex];
        if (higherRankDim != lowerRankDim)
          break;

        // Matched: propagate the dimension into the output shape.
        reshapeOutputShape[higherLeftIndex] = higherRankDim;

        if (higherLeftIndex < higherRightIndex)
          higherLeftIndex++;

        if (lowerLeftIndex < lowerRightIndex)
          lowerLeftIndex++;

        if (higherLeftIndex == higherRightIndex ||
            lowerLeftIndex == lowerRightIndex)
          break;
      }
    }
  }
}
105
106 /// Common code to reate the reshape op where necessary to make the rank of the
107 /// operations equal. Returns the updated input1 and input2 for the original
108 /// input. The caller is expected to use these to rewrite the original operator
109 /// with the RESHAPE now in the graph.
reshapeLowerToHigher(PatternRewriter & rewriter,Location loc,RankedTensorType outputType,Value input1,Value input2,Value & outInput1,Value & outInput2)110 static int reshapeLowerToHigher(PatternRewriter &rewriter, Location loc,
111 RankedTensorType outputType, Value input1,
112 Value input2, Value &outInput1,
113 Value &outInput2) {
114
115 int64_t input1Rank = input1.getType().cast<RankedTensorType>().getRank();
116 int64_t input2Rank = input2.getType().cast<RankedTensorType>().getRank();
117
118 Value higherTensorValue, lowerTensorValue;
119 // return if rank already match
120 if (input1Rank == input2Rank)
121 return 1;
122
123 if (input1Rank > input2Rank) {
124 higherTensorValue = input1;
125 lowerTensorValue = input2;
126 } else {
127 higherTensorValue = input2;
128 lowerTensorValue = input1;
129 }
130
131 ArrayRef<int64_t> outputRankShape = outputType.getShape();
132 ArrayRef<int64_t> higherRankShape =
133 higherTensorValue.getType().cast<RankedTensorType>().getShape();
134 (void)higherRankShape;
135 ArrayRef<int64_t> lowerRankShape =
136 lowerTensorValue.getType().cast<RankedTensorType>().getShape();
137
138 // outputRank == higherRank == max(input1Rank, input2Rank)
139 assert(higherRankShape.size() == outputRankShape.size());
140
141 SmallVector<int64_t, 4> reshapeOutputShape;
142
143 computeReshapeOutput(outputRankShape, lowerRankShape, reshapeOutputShape);
144
145 auto reshapeInputType = lowerTensorValue.getType().cast<RankedTensorType>();
146 auto reshapeOutputType = RankedTensorType::get(
147 ArrayRef<int64_t>(reshapeOutputShape), reshapeInputType.getElementType());
148
149 auto reshapeLower = rewriter.create<tosa::ReshapeOp>(
150 loc, reshapeOutputType, lowerTensorValue,
151 rewriter.getI64ArrayAttr(reshapeOutputShape));
152
153 if (input1Rank > input2Rank) {
154 outInput1 = higherTensorValue;
155 outInput2 = reshapeLower.getResult();
156 } else {
157 outInput1 = reshapeLower.getResult();
158 outInput2 = higherTensorValue;
159 }
160
161 return 0;
162 }
163
164 namespace {
165 template <typename OpTy>
166 struct ConvertTosaOp : public OpRewritePattern<OpTy> {
167 using OpRewritePattern<OpTy>::OpRewritePattern;
168
matchAndRewrite__anon8d70db250111::ConvertTosaOp169 LogicalResult matchAndRewrite(OpTy tosaBinaryOp,
170 PatternRewriter &rewriter) const override {
171
172 Value input1 = tosaBinaryOp.input1();
173 Value input2 = tosaBinaryOp.input2();
174 Value output = tosaBinaryOp.getResult();
175 auto outputType = output.getType().cast<RankedTensorType>();
176
177 Value outInput1, outInput2;
178 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
179 input1, input2, outInput1, outInput2))
180 return failure();
181
182 rewriter.replaceOpWithNewOp<OpTy>(tosaBinaryOp, outputType, outInput1,
183 outInput2);
184
185 return success();
186 }
187 };
188
189 // The MulOp has an extra parameter 'shift' not present in other elementwise
190 // binary ops, that necessitates special handling of its builder.
191 template <>
192 struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
193 using OpRewritePattern<tosa::MulOp>::OpRewritePattern;
194
matchAndRewrite__anon8d70db250111::ConvertTosaOp195 LogicalResult matchAndRewrite(tosa::MulOp tosaBinaryOp,
196 PatternRewriter &rewriter) const override {
197
198 Value input1 = tosaBinaryOp.input1();
199 Value input2 = tosaBinaryOp.input2();
200 int32_t shift = tosaBinaryOp.shift();
201 Value output = tosaBinaryOp.getResult();
202 auto outputType = output.getType().cast<RankedTensorType>();
203
204 Value outInput1, outInput2;
205 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
206 input1, input2, outInput1, outInput2))
207 return failure();
208
209 rewriter.replaceOpWithNewOp<tosa::MulOp>(tosaBinaryOp, outputType,
210 outInput1, outInput2, shift);
211
212 return success();
213 }
214 };
215
216 // The ArithmeticRightShiftOp has an extra parameter 'round' not present in
217 // other elementwise binary ops, that necessitates special handling of its
218 // builder.
219 template <>
220 struct ConvertTosaOp<tosa::ArithmeticRightShiftOp>
221 : public OpRewritePattern<tosa::ArithmeticRightShiftOp> {
222 using OpRewritePattern<tosa::ArithmeticRightShiftOp>::OpRewritePattern;
223
matchAndRewrite__anon8d70db250111::ConvertTosaOp224 LogicalResult matchAndRewrite(tosa::ArithmeticRightShiftOp tosaBinaryOp,
225 PatternRewriter &rewriter) const override {
226
227 Value input1 = tosaBinaryOp.input1();
228 Value input2 = tosaBinaryOp.input2();
229 int32_t round = tosaBinaryOp.round();
230 Value output = tosaBinaryOp.getResult();
231 auto outputType = output.getType().dyn_cast<RankedTensorType>();
232
233 Value outInput1, outInput2;
234 if (reshapeLowerToHigher(rewriter, tosaBinaryOp.getLoc(), outputType,
235 input1, input2, outInput1, outInput2))
236 return failure();
237
238 rewriter.replaceOpWithNewOp<tosa::ArithmeticRightShiftOp>(
239 tosaBinaryOp, outputType, outInput1, outInput2, round);
240
241 return success();
242 }
243 };
244 } // end anonymous namespace
245
246 namespace {
247 /// Pass that enables broadcast by making all input arrays have the same
248 /// number of dimensions. Insert RESHAPE operations to lower rank operand
249 struct TosaMakeBroadcastable
250 : public TosaMakeBroadcastableBase<TosaMakeBroadcastable> {
251 public:
runOnFunction__anon8d70db250211::TosaMakeBroadcastable252 void runOnFunction() override {
253 auto func = getFunction();
254 OwningRewritePatternList patterns;
255 MLIRContext *ctx = func.getContext();
256 // Add the generated patterns to the list.
257 patterns.insert<ConvertTosaOp<tosa::AddOp>>(ctx);
258 patterns.insert<ConvertTosaOp<tosa::SubOp>>(ctx);
259 patterns.insert<ConvertTosaOp<tosa::MulOp>>(ctx);
260 patterns.insert<ConvertTosaOp<tosa::MaximumOp>>(ctx);
261 patterns.insert<ConvertTosaOp<tosa::MinimumOp>>(ctx);
262 patterns.insert<ConvertTosaOp<tosa::EqualOp>>(ctx);
263 patterns.insert<ConvertTosaOp<tosa::GreaterOp>>(ctx);
264 patterns.insert<ConvertTosaOp<tosa::GreaterEqualOp>>(ctx);
265 patterns.insert<ConvertTosaOp<tosa::LogicalLeftShiftOp>>(ctx);
266 patterns.insert<ConvertTosaOp<tosa::ArithmeticRightShiftOp>>(ctx);
267 patterns.insert<ConvertTosaOp<tosa::LogicalRightShiftOp>>(ctx);
268 applyPatternsAndFoldGreedily(func, std::move(patterns));
269 }
270 };
271 } // end anonymous namespace
272
createTosaMakeBroadcastablePass()273 std::unique_ptr<Pass> mlir::tosa::createTosaMakeBroadcastablePass() {
274 return std::make_unique<TosaMakeBroadcastable>();
275 }
276