/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch
// out of the inputs, matmul them individually, then stack them all back
// together at the end.
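// For example, a BatchMatMul of a [2, 3, 4] LHS with a [2, 4, 5] RHS becomes
// two [3, 4] x [4, 5] MatMuls whose results are packed back into a [2, 3, 5]
// result.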
struct UnrollBatchMatMulPass
    : public PassWrapper<UnrollBatchMatMulPass, FunctionPass> {
  StringRef getArgument() const final { return "tf-unroll-batch-matmul"; }

  StringRef getDescription() const final {
    return "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.";
  }

  void runOnFunction() override;
};

void UnrollBatchMatMulPass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  auto func = getFunction();

  patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV3Op>>(&getContext());
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

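// Creates a TF::ReshapeOp that reshapes `value` to `shape`. The target shape
// is materialized as a rank-1 i64 ConstOp feeding the ReshapeOp's shape
// operand.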
template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
    Value value, ArrayRef<int64_t> shape, Type element_type, Location loc,
    PatternRewriter& rewriter) {
  int64_t shape_rank = shape.size();
  auto shape_spec_type =
      RankedTensorType::get({shape_rank}, rewriter.getIntegerType(64));
  Type resultType = RankedTensorType::get(shape, element_type);
  auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
  auto shape_tensor =
      rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
  return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                        /*shape=*/shape_tensor);
}

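// Slices `value`, viewed as `batch_size` stacked [num_rows, num_cols]
// matrices, into rank-2 tensors of shape [num_rows, num_cols]; e.g. a
// [2, 3, 4] input yields two [3, 4] slices.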
template <typename BatchMatMulOpType>
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
    Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
  RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
  Type element_type = tensorType.getElementType();

  int rank = tensorType.getShape().size();
  int num_rows = tensorType.getShape()[rank - 2];
  int num_cols = tensorType.getShape()[rank - 1];

  std::vector<Value> sliced;

  if (batch_size == 1) {
    // Batch size is 1, no splitting is required.
    // Squeeze the batch dimension, i.e. reshape
    // [1, num_rows, num_cols] -> [num_rows, num_cols].
    auto squeeze_op = createReshapeOp(value, {num_rows, num_cols}, element_type,
                                      loc, rewriter);
    sliced.emplace_back(squeeze_op.output());
  } else {
    // Reshape to rank-3 Tensor with first dimension as the batch size.
    auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols},
                                      element_type, loc, rewriter);

    // Create a constant op for the split axis (= 0).
    auto split_dimension_type =
        RankedTensorType::get({}, rewriter.getIntegerType(32));
    auto split_dimension_attr = DenseElementsAttr::get(split_dimension_type, 0);
    auto split_dimension_op = rewriter.create<TF::ConstOp>(
        loc, split_dimension_type, split_dimension_attr);

    // Split along each batch.
    SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols};
    Type slice_result_type = RankedTensorType::get(slice_size, element_type);
    llvm::SmallVector<Type, 4> output_types(batch_size, slice_result_type);
    auto split_op = rewriter.create<TF::SplitOp>(
        loc, output_types, split_dimension_op.output(), reshape_op.output());

    // Squeeze each batch, i.e. reshape
    // [1, num_rows, num_cols] -> [num_rows, num_cols].
    for (const auto& split_value : split_op.output()) {
      auto squeeze_op = createReshapeOp(split_value, {num_rows, num_cols},
                                        element_type, loc, rewriter);

      sliced.emplace_back(squeeze_op.output());
    }
  }
  return sliced;
}

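// Creates a TF::TransposeOp that swaps the two innermost dimensions of
// `value` while keeping any leading batch dimensions in place; used to lower
// the adj_x/adj_y attributes.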
template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
    Value value, Location loc, PatternRewriter& rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  int dims = shape.size();

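  // Build a permutation that is the identity on the leading dimensions and
  // swaps the two innermost ones.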
  std::vector<int32_t> perm(dims);
  for (int i = 0; i < dims - 2; i++) {
    perm[i] = i;
  }
  perm[dims - 2] = dims - 1;
  perm[dims - 1] = dims - 2;

  auto perm_type = RankedTensorType::get({static_cast<int32_t>(perm.size())},
                                         rewriter.getIntegerType(32));

  auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
  auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);

  std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
  int64_t r = transposed_shape[dims - 1];
  int64_t c = transposed_shape[dims - 2];

  transposed_shape[dims - 1] = c;
  transposed_shape[dims - 2] = r;

  auto transposed_type =
      RankedTensorType::get(transposed_shape, value_type.getElementType());
  return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
}

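// Emits one TF::MatMulOp per output batch, selecting the LHS and RHS slices
// via the broadcast index maps when broadcasting is required, and packs the
// rank-2 results into a single [output_batch_size, rows, cols] tensor.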
template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
    const std::vector<Value>& sliced_lhs, const std::vector<Value>& sliced_rhs,
    const tensorflow::MatMulBCast& bcast, int rows, int cols, Type element_type,
    Location loc, PatternRewriter& rewriter) {
  auto matmul_type = RankedTensorType::get({rows, cols}, element_type);

  std::vector<Value> matmuls;
  for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
    int lhs_batch_idx, rhs_batch_idx;
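    // When broadcasting is required, map the output batch index back to the
    // corresponding input batch indices.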
    if (bcast.IsBroadcastingRequired()) {
      lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
      rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
    } else {
      lhs_batch_idx = batch_idx;
      rhs_batch_idx = batch_idx;
    }
    auto false_attr = rewriter.getBoolAttr(false);
    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmul_type,
                                                /*a=*/sliced_lhs[lhs_batch_idx],
                                                /*b=*/sliced_rhs[rhs_batch_idx],
                                                /*transpose_a=*/false_attr,
                                                /*transpose_b=*/false_attr);
    matmuls.emplace_back(matmul.product());
  }

  // Combine the result of each individual MatMul into a rank-3 Tensor.
  Type packed_type = RankedTensorType::get(
      {bcast.output_batch_size(), rows, cols}, element_type);

  auto axis = rewriter.getI64IntegerAttr(0);
  return rewriter.create<TF::PackOp>(loc, packed_type,
                                     /*values=*/matmuls, axis);
}

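// Rewrites a BatchMatMul/BatchMatMulV2/BatchMatMulV3 op into Reshape, Split,
// MatMul, and Pack ops (or a single MatMul when both operands are already
// rank 2). Fails to match on unranked operands, mismatched element types,
// dynamic dimensions, or non-broadcastable batch shapes.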
template <typename BatchMatMulOpType>
LogicalResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value input_lhs = op.x();
  Value input_rhs = op.y();

  if (!input_lhs.getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type.
    return failure();
  }
  if (!input_rhs.getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type.
    return failure();
  }

  auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs.getType().cast<RankedTensorType>();

  // Skip int8 x int8 => int32.
  if (lhs_type.getElementType().isInteger(8) &&
      rhs_type.getElementType().isInteger(8)) {
    return rewriter.notifyMatchFailure(op,
                                       "skip unrolling for int8 BatchMatMulV3");
  }

  auto element_type = lhs_type.getElementType();

  if (element_type != rhs_type.getElementType()) {
    // The element type of LHS must be the same as the element type of RHS.
    return failure();
  }

  auto lhs_shape = lhs_type.getShape();
  auto rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Ensure that input ranks are at least 2.
  const int dims_a = lhs_shape.size();
  const int dims_b = rhs_shape.size();
  if (dims_a < 2 || dims_b < 2) {
    // Both inputs must have rank >= 2.
    return failure();
  }

  // Transpose LHS input if necessary.
  if (op.adj_x()) {
    input_lhs = createTransposeOp(input_lhs, loc, rewriter);

    lhs_type = input_lhs.getType().cast<RankedTensorType>();
    lhs_shape = lhs_type.getShape();
  }

  // Transpose RHS input if necessary.
  if (op.adj_y()) {
    input_rhs = createTransposeOp(input_rhs, loc, rewriter);

    rhs_type = input_rhs.getType().cast<RankedTensorType>();
    rhs_shape = rhs_type.getShape();
  }

  if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
    // Input dimensions must be compatible for multiplication.
    return failure();
  }
  if (dims_a == 2 && dims_b == 2) {
    // When both inputs are matrices, just replace the op with a MatMul op.
    Type result_type =
        RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, element_type);
    auto false_attr = rewriter.getBoolAttr(false);
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, result_type,
                                              /*a=*/input_lhs,
                                              /*b=*/input_rhs,
                                              /*transpose_a=*/false_attr,
                                              /*transpose_b=*/false_attr);
    return success();
  }

  // Input dimensions must be defined. MatMulBCast does not support partial
  // shapes.
  for (auto dim : lhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  for (auto dim : rhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  // Ensure that batch shapes are broadcastable.
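  // E.g., an LHS batch shape [2, 1] and an RHS batch shape [5] broadcast to
  // an output batch shape of [2, 5].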
  tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
                                    lhs_shape.begin(), lhs_shape.end()),
                                absl::InlinedVector<tensorflow::int64, 4>(
                                    rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable.
    return failure();
  }

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute (single batch) MatMul for each output batch. The MatMul outputs
  // are then packed together into one output Tensor.
  auto pack_op =
      createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
                      rhs_shape[dims_b - 1], element_type, loc, rewriter);

  // Reshape the rank-3 Tensor into the correct output shape.
  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> result_shape(result_batch_shape.begin(),
                                    result_batch_shape.end());
  result_shape.push_back(lhs_shape[dims_a - 2]);
  result_shape.push_back(rhs_shape[dims_b - 1]);

  auto reshape_op = createReshapeOp(pack_op.output(), result_shape,
                                    element_type, loc, rewriter);
  rewriter.replaceOp(op, reshape_op.output());
  return success();
}

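// Registers the pass; it is exposed under the "tf-unroll-batch-matmul"
// argument returned by getArgument() above.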
static PassRegistration<UnrollBatchMatMulPass> pass;

std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass() {
  return std::make_unique<UnrollBatchMatMulPass>();
}

}  // namespace TF
}  // namespace mlir