1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

#include <climits>
#include <cstdint>
#include <numeric>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"
39
40 namespace mlir {
41 namespace TF {
42
43 namespace {
44 // Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
45 // of the inputs, matmul them individually, then stack them all back together at
46 // the end.
// Function pass that unrolls tf.BatchMatMul / tf.BatchMatMulV2 ops found in
// the function body into per-batch Reshape/Slice/MatMul/Pack ops.
struct UnrollBatchMatMulPass
    : public PassWrapper<UnrollBatchMatMulPass, FunctionPass> {
  void runOnFunction() override;
};
51
runOnFunction()52 void UnrollBatchMatMulPass::runOnFunction() {
53 OwningRewritePatternList patterns;
54 auto func = getFunction();
55
56 patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
57 ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
58 (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
59 }
60
61 } // namespace
62
63 template <typename BatchMatMulOpType>
createReshapeOp(Value value,ArrayRef<int64_t> shape,Type element_type,Location loc,PatternRewriter & rewriter)64 TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
65 Value value, ArrayRef<int64_t> shape, Type element_type, Location loc,
66 PatternRewriter& rewriter) {
67 int64_t shape_rank = shape.size();
68 auto shape_spec_type =
69 RankedTensorType::get({shape_rank}, rewriter.getIntegerType(64));
70 Type resultType = RankedTensorType::get(shape, element_type);
71 auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
72 auto shape_tensor =
73 rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
74 return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
75 /*shape=*/shape_tensor);
76 }
77
78 template <typename BatchMatMulOpType>
sliceInput(Value value,int batch_size,Location loc,PatternRewriter & rewriter)79 std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
80 Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
81 RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
82 Type element_type = tensorType.getElementType();
83
84 int rank = tensorType.getShape().size();
85 int num_rows = tensorType.getShape()[rank - 2];
86 int num_cols = tensorType.getShape()[rank - 1];
87
88 // Reshape to rank-3 Tensor with first dimension as the batch size.
89 auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols},
90 element_type, loc, rewriter);
91
92 SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols};
93
94 std::vector<Value> sliced;
95 Type int64_type = rewriter.getIntegerType(64);
96 Type slice_result_type = RankedTensorType::get(slice_size, element_type);
97
98 // Slice along each batch index and remember the slice output for future
99 // use.
100 for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
101 auto vector3_type = RankedTensorType::get({3}, int64_type);
102
103 auto begin_attr =
104 DenseElementsAttr::get<int64_t>(vector3_type, {batch_idx, 0, 0});
105 auto size_attr = DenseElementsAttr::get<int64_t>(vector3_type, slice_size);
106 auto begin = rewriter.create<TF::ConstOp>(loc, vector3_type, begin_attr);
107 auto size = rewriter.create<TF::ConstOp>(loc, vector3_type, size_attr);
108 auto slice_op = rewriter.create<TF::SliceOp>(loc, slice_result_type,
109 /*input=*/reshape_op.output(),
110 begin, size);
111
112 // Squeeze matrix, i.e. reshape [1, num_rows, num_cols] -> [num_rows,
113 // num_cols]
114 auto squeeze_op = createReshapeOp(slice_op.output(), {num_rows, num_cols},
115 element_type, loc, rewriter);
116
117 sliced.emplace_back(squeeze_op.output());
118 }
119 return sliced;
120 }
121
122 template <typename BatchMatMulOpType>
createTransposeOp(Value value,Location loc,PatternRewriter & rewriter)123 TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
124 Value value, Location loc, PatternRewriter& rewriter) {
125 auto value_type = value.getType().cast<RankedTensorType>();
126 auto shape = value_type.getShape();
127 int dims = shape.size();
128
129 std::vector<int32_t> perm(dims);
130 for (int i = 0; i < dims - 2; i++) {
131 perm[i] = i;
132 }
133 perm[dims - 2] = dims - 1;
134 perm[dims - 1] = dims - 2;
135
136 auto perm_type = RankedTensorType::get({static_cast<int32_t>(perm.size())},
137 rewriter.getIntegerType(32));
138
139 auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
140 auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);
141
142 std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
143 int64_t r = transposed_shape[dims - 1];
144 int64_t c = transposed_shape[dims - 2];
145
146 transposed_shape[dims - 1] = c;
147 transposed_shape[dims - 2] = r;
148
149 auto transposed_type =
150 RankedTensorType::get(transposed_shape, value_type.getElementType());
151 return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
152 }
153
154 template <typename BatchMatMulOpType>
createMatMulOps(const std::vector<Value> & sliced_lhs,const std::vector<Value> & sliced_rhs,const tensorflow::MatMulBCast & bcast,int rows,int cols,Type element_type,Location loc,PatternRewriter & rewriter)155 TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
156 const std::vector<Value>& sliced_lhs, const std::vector<Value>& sliced_rhs,
157 const tensorflow::MatMulBCast& bcast, int rows, int cols, Type element_type,
158 Location loc, PatternRewriter& rewriter) {
159 auto matmul_type = RankedTensorType::get({rows, cols}, element_type);
160
161 std::vector<Value> matmuls;
162 for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
163 int lhs_batch_idx, rhs_batch_idx;
164 if (bcast.IsBroadcastingRequired()) {
165 lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
166 rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
167 } else {
168 lhs_batch_idx = batch_idx;
169 rhs_batch_idx = batch_idx;
170 }
171 auto false_attr = rewriter.getBoolAttr(false);
172 auto matmul = rewriter.create<TF::MatMulOp>(loc, matmul_type,
173 /*a=*/sliced_lhs[lhs_batch_idx],
174 /*b=*/sliced_rhs[rhs_batch_idx],
175 /*transpose_a=*/false_attr,
176 /*transpose_b=*/false_attr);
177 matmuls.emplace_back(matmul.product());
178 }
179
180 // Combine the result of each individual MatMul into a rank-3 Tensor.
181 Type packed_type = RankedTensorType::get(
182 {bcast.output_batch_size(), rows, cols}, element_type);
183
184 auto axis = rewriter.getI64IntegerAttr(0);
185 return rewriter.create<TF::PackOp>(loc, packed_type,
186 /*values=*/matmuls, axis);
187 }
188
189 template <typename BatchMatMulOpType>
matchAndRewrite(BatchMatMulOpType op,PatternRewriter & rewriter) const190 LogicalResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
191 BatchMatMulOpType op, PatternRewriter& rewriter) const {
192 Value input_lhs = op.x();
193 Value input_rhs = op.y();
194
195 if (!input_lhs.getType().isa<RankedTensorType>()) {
196 // LHS must be a ranked tensor type
197 return failure();
198 }
199 if (!input_rhs.getType().isa<RankedTensorType>()) {
200 // RHS must be a ranked tensor type
201 return failure();
202 }
203
204 auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
205 auto rhs_type = input_rhs.getType().cast<RankedTensorType>();
206
207 auto element_type = lhs_type.getElementType();
208
209 if (element_type != rhs_type.getElementType()) {
210 // The element type of LHS must be the same with element type of RHS
211 return failure();
212 }
213
214 auto lhs_shape = lhs_type.getShape();
215 auto rhs_shape = rhs_type.getShape();
216
217 Location loc = op.getLoc();
218
219 // Ensure that input ranks are at least 2.
220 const int dims_a = lhs_shape.size();
221 const int dims_b = rhs_shape.size();
222 if (dims_a < 2 || dims_b < 2) {
223 // Both inputs must have rank >= 2
224 return failure();
225 }
226
227 // Transpose LHS input if necessary.
228 if (op.adj_x()) {
229 input_lhs = createTransposeOp(input_lhs, loc, rewriter);
230
231 lhs_type = input_lhs.getType().cast<RankedTensorType>();
232 lhs_shape = lhs_type.getShape();
233 }
234
235 // Transpose RHS input if necessary.
236 if (op.adj_y()) {
237 input_rhs = createTransposeOp(input_rhs, loc, rewriter);
238
239 rhs_type = input_rhs.getType().cast<RankedTensorType>();
240 rhs_shape = rhs_type.getShape();
241 }
242
243 if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
244 // Input dimensions must be compatible for multiplication.
245 return failure();
246 }
247
248 if (dims_a == 2 && dims_b == 2) {
249 // When both inputs are matrices, just replace the op to a matmul op.
250 Type result_type =
251 RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, element_type);
252 auto false_attr = rewriter.getBoolAttr(false);
253 rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, result_type,
254 /*a=*/input_lhs,
255 /*b=*/input_rhs,
256 /*transpose_a=*/false_attr,
257 /*transpose_b=*/false_attr);
258 return success();
259 }
260
261 // Input dimensions must be defined. MatMulBCast does not support partial
262 // shapes.
263 for (auto dim : lhs_shape) {
264 if (dim == -1) {
265 return failure();
266 }
267 }
268 for (auto dim : rhs_shape) {
269 if (dim == -1) {
270 return failure();
271 }
272 }
273 // Ensure that batch shapes are broadcastable.
274 tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
275 lhs_shape.begin(), lhs_shape.end()),
276 absl::InlinedVector<tensorflow::int64, 4>(
277 rhs_shape.begin(), rhs_shape.end()));
278
279 if (!bcast.IsValid()) {
280 // Input batch dimensions must be broadcastable
281 return failure();
282 }
283
284 // Compute slices for each batch in the LHS and RHS.
285 std::vector<Value> sliced_lhs =
286 sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
287 std::vector<Value> sliced_rhs =
288 sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);
289
290 // Compute (single batch) MatMul for each output batch. The MatMul outputs
291 // are then packed together into one output Tensor.
292 auto pack_op =
293 createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
294 rhs_shape[dims_b - 1], element_type, loc, rewriter);
295
296 // Reshape the rank-3 Tensor into the correct output shape.
297 const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
298 std::vector<int64_t> result_shape(result_batch_shape.begin(),
299 result_batch_shape.end());
300 result_shape.push_back(lhs_shape[dims_a - 2]);
301 result_shape.push_back(rhs_shape[dims_b - 1]);
302
303 auto reshape_op = createReshapeOp(pack_op.output(), result_shape,
304 element_type, loc, rewriter);
305 rewriter.replaceOp(op, reshape_op.output());
306 return success();
307 }
308
// Registers the pass with the global registry under
// -tf-unroll-batch-matmul for use from mlir-opt-style tools.
static PassRegistration<UnrollBatchMatMulPass> pass(
    "tf-unroll-batch-matmul",
    "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
312
CreateUnrollBatchMatMulPassPass()313 std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass() {
314 return std::make_unique<UnrollBatchMatMulPass>();
315 }
316
317 } // namespace TF
318 } // namespace mlir
319