/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
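// For example (illustrative IR, assuming static shapes):
//   %0 = "tf.BatchMatMul"(%lhs, %rhs)
//       : (tensor<2x3x4xf32>, tensor<2x4x5xf32>) -> tensor<2x3x5xf32>
// becomes two tf.Slice + tf.MatMul pairs whose [3, 5] products are combined
// by tf.Pack and reshaped back to tensor<2x3x5xf32>.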
struct UnrollBatchMatMulPass
    : public PassWrapper<UnrollBatchMatMulPass, FunctionPass> {
  void runOnFunction() override;
};

void UnrollBatchMatMulPass::runOnFunction() {
  OwningRewritePatternList patterns;
  auto func = getFunction();

  patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>>(&getContext());
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

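// Builds a tf.Reshape of `value` to `shape`: the target shape is materialized
// as a rank-1 i64 tf.Const, which feeds the tf.Reshape as its shape operand.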
template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
    Value value, ArrayRef<int64_t> shape, Type element_type, Location loc,
    PatternRewriter& rewriter) {
  int64_t shape_rank = shape.size();
  auto shape_spec_type =
      RankedTensorType::get({shape_rank}, rewriter.getIntegerType(64));
  Type result_type = RankedTensorType::get(shape, element_type);
  auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
  auto shape_tensor =
      rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
  return rewriter.create<TF::ReshapeOp>(loc, result_type, /*tensor=*/value,
                                        /*shape=*/shape_tensor);
}

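// Slices the input into `batch_size` two-dimensional matrices. For example,
// a [2, 3, 4, 5] input passed with batch_size 6 is reshaped to [6, 4, 5] and
// sliced into six [4, 5] values.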
template <typename BatchMatMulOpType>
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
    Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
  RankedTensorType tensor_type = value.getType().cast<RankedTensorType>();
  Type element_type = tensor_type.getElementType();

  int rank = tensor_type.getShape().size();
  int num_rows = tensor_type.getShape()[rank - 2];
  int num_cols = tensor_type.getShape()[rank - 1];

  // Reshape to a rank-3 Tensor whose first dimension is the batch size.
  auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols},
                                    element_type, loc, rewriter);

  SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols};

  std::vector<Value> sliced;
  Type int64_type = rewriter.getIntegerType(64);
  Type slice_result_type = RankedTensorType::get(slice_size, element_type);

  // Slice along each batch index and remember the slice output for future
  // use.
  for (int batch_idx = 0; batch_idx < batch_size; ++batch_idx) {
    auto vector3_type = RankedTensorType::get({3}, int64_type);

    auto begin_attr =
        DenseElementsAttr::get<int64_t>(vector3_type, {batch_idx, 0, 0});
    auto size_attr = DenseElementsAttr::get<int64_t>(vector3_type, slice_size);
    auto begin = rewriter.create<TF::ConstOp>(loc, vector3_type, begin_attr);
    auto size = rewriter.create<TF::ConstOp>(loc, vector3_type, size_attr);
    auto slice_op = rewriter.create<TF::SliceOp>(loc, slice_result_type,
                                                 /*input=*/reshape_op.output(),
                                                 begin, size);

    // Squeeze the matrix, i.e. reshape [1, num_rows, num_cols] ->
    // [num_rows, num_cols].
    auto squeeze_op = createReshapeOp(slice_op.output(), {num_rows, num_cols},
                                      element_type, loc, rewriter);

    sliced.emplace_back(squeeze_op.output());
  }
  return sliced;
}

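// Swaps the two innermost dimensions of `value` with a tf.Transpose whose
// permutation leaves all leading dimensions in place; e.g. a [2, 3, 4, 5]
// input gets permutation [0, 1, 3, 2] and produces a [2, 3, 5, 4] result.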
template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
    Value value, Location loc, PatternRewriter& rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  int dims = shape.size();

  std::vector<int32_t> perm(dims);
  for (int i = 0; i < dims - 2; i++) {
    perm[i] = i;
  }
  perm[dims - 2] = dims - 1;
  perm[dims - 1] = dims - 2;

  auto perm_type = RankedTensorType::get({static_cast<int32_t>(perm.size())},
                                         rewriter.getIntegerType(32));

  auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
  auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);

  std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
  std::swap(transposed_shape[dims - 1], transposed_shape[dims - 2]);

  auto transposed_type =
      RankedTensorType::get(transposed_shape, value_type.getElementType());
  return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
}

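// Emits one tf.MatMul per output batch entry. MatMulBCast maps each output
// batch index back to the LHS/RHS slice that feeds it, so a broadcast input
// slice is simply reused; e.g. with an LHS batch of 2 and an RHS batch of 1,
// both MatMuls read the single RHS slice. The [rows, cols] products are then
// packed into one [output_batch_size, rows, cols] tensor.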
template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
    const std::vector<Value>& sliced_lhs, const std::vector<Value>& sliced_rhs,
    const tensorflow::MatMulBCast& bcast, int rows, int cols, Type element_type,
    Location loc, PatternRewriter& rewriter) {
  auto matmul_type = RankedTensorType::get({rows, cols}, element_type);

  std::vector<Value> matmuls;
  for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
    int lhs_batch_idx, rhs_batch_idx;
    if (bcast.IsBroadcastingRequired()) {
      lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
      rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
    } else {
      lhs_batch_idx = batch_idx;
      rhs_batch_idx = batch_idx;
    }
    auto false_attr = rewriter.getBoolAttr(false);
    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmul_type,
                                                /*a=*/sliced_lhs[lhs_batch_idx],
                                                /*b=*/sliced_rhs[rhs_batch_idx],
                                                /*transpose_a=*/false_attr,
                                                /*transpose_b=*/false_attr);
    matmuls.emplace_back(matmul.product());
  }

  // Combine the result of each individual MatMul into a rank-3 Tensor.
  Type packed_type = RankedTensorType::get(
      {bcast.output_batch_size(), rows, cols}, element_type);

  auto axis = rewriter.getI64IntegerAttr(0);
  return rewriter.create<TF::PackOp>(loc, packed_type,
                                     /*values=*/matmuls, axis);
}

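// Pattern entry point. Rewrites a single tf.BatchMatMul / tf.BatchMatMulV2:
//   1. Transposes an input when its adjoint flag (adj_x / adj_y) is set.
//   2. If both inputs are plain matrices, lowers directly to one tf.MatMul.
//   3. Otherwise slices each input per batch, multiplies the slices pairwise
//      (honoring broadcasting), packs the products, and reshapes the packed
//      rank-3 result to the broadcast output shape.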
template <typename BatchMatMulOpType>
LogicalResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value input_lhs = op.x();
  Value input_rhs = op.y();

  if (!input_lhs.getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type.
    return failure();
  }
  if (!input_rhs.getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type.
    return failure();
  }

  auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs.getType().cast<RankedTensorType>();

  auto element_type = lhs_type.getElementType();

  if (element_type != rhs_type.getElementType()) {
    // The element type of LHS must match the element type of RHS.
    return failure();
  }

  auto lhs_shape = lhs_type.getShape();
  auto rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Ensure that input ranks are at least 2.
  const int dims_a = lhs_shape.size();
  const int dims_b = rhs_shape.size();
  if (dims_a < 2 || dims_b < 2) {
    // Both inputs must have rank >= 2.
    return failure();
  }

  // Transpose LHS input if necessary.
  if (op.adj_x()) {
    input_lhs = createTransposeOp(input_lhs, loc, rewriter);

    lhs_type = input_lhs.getType().cast<RankedTensorType>();
    lhs_shape = lhs_type.getShape();
  }

  // Transpose RHS input if necessary.
  if (op.adj_y()) {
    input_rhs = createTransposeOp(input_rhs, loc, rewriter);

    rhs_type = input_rhs.getType().cast<RankedTensorType>();
    rhs_shape = rhs_type.getShape();
  }

  if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
    // Input dimensions must be compatible for multiplication.
    return failure();
  }

  if (dims_a == 2 && dims_b == 2) {
    // When both inputs are matrices, just replace the op with a single
    // tf.MatMul op.
    Type result_type =
        RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, element_type);
    auto false_attr = rewriter.getBoolAttr(false);
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, result_type,
                                              /*a=*/input_lhs,
                                              /*b=*/input_rhs,
                                              /*transpose_a=*/false_attr,
                                              /*transpose_b=*/false_attr);
    return success();
  }

  // Input dimensions must be defined. MatMulBCast does not support partial
  // shapes.
  for (auto dim : lhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  for (auto dim : rhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  // Ensure that batch shapes are broadcastable.
  tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
                                    lhs_shape.begin(), lhs_shape.end()),
                                absl::InlinedVector<tensorflow::int64, 4>(
                                    rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable.
    return failure();
  }

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute a (single batch) MatMul for each output batch. The MatMul outputs
  // are then packed together into one output Tensor.
  auto pack_op =
      createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
                      rhs_shape[dims_b - 1], element_type, loc, rewriter);

  // Reshape the rank-3 Tensor into the correct output shape.
  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> result_shape(result_batch_shape.begin(),
                                    result_batch_shape.end());
  result_shape.push_back(lhs_shape[dims_a - 2]);
  result_shape.push_back(rhs_shape[dims_b - 1]);

  auto reshape_op = createReshapeOp(pack_op.output(), result_shape,
                                    element_type, loc, rewriter);
  rewriter.replaceOp(op, reshape_op.output());
  return success();
}

static PassRegistration<UnrollBatchMatMulPass> pass(
    "tf-unroll-batch-matmul",
    "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.");
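
// Illustrative invocation (assuming a tf-opt binary with this pass linked in):
//   tf-opt -tf-unroll-batch-matmul input.mlir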

std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass() {
  return std::make_unique<UnrollBatchMatMulPass>();
}

}  // namespace TF
}  // namespace mlir