/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/transforms/unroll_batch_matmul.h"

#include <climits>
#include <cstdint>
#include <utility>

#include "absl/memory/memory.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include "mlir/Analysis/LoopAnalysis.h"  // from @llvm-project
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/OpImplementation.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/core/util/matmul_bcast.h"

namespace mlir {
namespace TF {

namespace {
// Unrolls a BatchMatMul on the batch dimension. We need to slice each batch out
// of the inputs, matmul them individually, then stack them all back together at
// the end.
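// For example, operands of shapes [2, 3, 5] x [2, 5, 7] become two 3x5 * 5x7
// MatMuls whose 3x7 products are packed back into a single [2, 3, 7] result.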
struct UnrollBatchMatMulPass
    : public PassWrapper<UnrollBatchMatMulPass, FunctionPass> {
  StringRef getArgument() const final { return "tf-unroll-batch-matmul"; }

  StringRef getDescription() const final {
    return "Unroll TF BatchMatMul op into Reshape, Slice, MatMul, Pack ops.";
  }

  void runOnFunction() override;
};

void UnrollBatchMatMulPass::runOnFunction() {
  OwningRewritePatternList patterns(&getContext());
  auto func = getFunction();

  patterns.insert<ConvertTFBatchMatMulOp<TF::BatchMatMulOp>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV2Op>,
                  ConvertTFBatchMatMulOp<TF::BatchMatMulV3Op>>(&getContext());
  (void)applyPatternsAndFoldGreedily(func, std::move(patterns));
}

}  // namespace

template <typename BatchMatMulOpType>
TF::ReshapeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createReshapeOp(
    Value value, ArrayRef<int64_t> shape, Type element_type, Location loc,
    PatternRewriter& rewriter) {
  int64_t shape_rank = shape.size();
  auto shape_spec_type =
      RankedTensorType::get({shape_rank}, rewriter.getIntegerType(64));
  Type resultType = RankedTensorType::get(shape, element_type);
  auto constant_attr = DenseElementsAttr::get(shape_spec_type, shape);
  auto shape_tensor =
      rewriter.create<TF::ConstOp>(loc, shape_spec_type, constant_attr);
  return rewriter.create<TF::ReshapeOp>(loc, resultType, /*tensor=*/value,
                                        /*shape=*/shape_tensor);
}

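// Slices `value` into `batch_size` rank-2 [num_rows, num_cols] tensors. When
// batch_size > 1 the input is reshaped to [batch_size, num_rows, num_cols],
// split along the first axis, and each slice is squeezed back to rank 2; a
// single batch is simply reshaped.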
template <typename BatchMatMulOpType>
std::vector<Value> ConvertTFBatchMatMulOp<BatchMatMulOpType>::sliceInput(
    Value value, int batch_size, Location loc, PatternRewriter& rewriter) {
  RankedTensorType tensorType = value.getType().cast<RankedTensorType>();
  Type element_type = tensorType.getElementType();

  int rank = tensorType.getShape().size();
  int num_rows = tensorType.getShape()[rank - 2];
  int num_cols = tensorType.getShape()[rank - 1];

  std::vector<Value> sliced;

  if (batch_size == 1) {
    // Batch size is 1, no splitting is required
    // Squeeze the batch dimension, i.e. reshape
    // [1, num_rows, num_cols] -> [num_rows, num_cols]
    auto squeeze_op = createReshapeOp(value, {num_rows, num_cols}, element_type,
                                      loc, rewriter);
    sliced.emplace_back(squeeze_op.output());
  } else {
    // Reshape to rank-3 Tensor with first dimension as the batch size.
    auto reshape_op = createReshapeOp(value, {batch_size, num_rows, num_cols},
                                      element_type, loc, rewriter);

    // Create a constant op for the split axis (=0)
    auto split_dimension_type =
        RankedTensorType::get({}, rewriter.getIntegerType(32));
    auto split_dimension_attr = DenseElementsAttr::get(split_dimension_type, 0);
    auto split_dimension_op = rewriter.create<TF::ConstOp>(
        loc, split_dimension_type, split_dimension_attr);

    // Split along each batch.
    SmallVector<int64_t, 3> slice_size = {1, num_rows, num_cols};
    Type slice_result_type = RankedTensorType::get(slice_size, element_type);
    llvm::SmallVector<Type, 4> output_types(batch_size, slice_result_type);
    auto split_op = rewriter.create<TF::SplitOp>(
        loc, output_types, split_dimension_op.output(), reshape_op.output());

    // Squeeze each batch, i.e. reshape
    // [1, num_rows, num_cols] -> [num_rows, num_cols]
    for (const auto& split_value : split_op.output()) {
      auto squeeze_op = createReshapeOp(split_value, {num_rows, num_cols},
                                        element_type, loc, rewriter);

      sliced.emplace_back(squeeze_op.output());
    }
  }
  return sliced;
}

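// Transposes the two innermost dimensions of `value` by building a permutation
// that keeps the leading batch dimensions in place and swaps the last two
// (e.g. [0, 1, 3, 2] for a rank-4 tensor). Used to materialize adj_x/adj_y.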
template <typename BatchMatMulOpType>
TF::TransposeOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createTransposeOp(
    Value value, Location loc, PatternRewriter& rewriter) {
  auto value_type = value.getType().cast<RankedTensorType>();
  auto shape = value_type.getShape();
  int dims = shape.size();

  std::vector<int32_t> perm(dims);
  for (int i = 0; i < dims - 2; i++) {
    perm[i] = i;
  }
  perm[dims - 2] = dims - 1;
  perm[dims - 1] = dims - 2;

  auto perm_type = RankedTensorType::get({static_cast<int32_t>(perm.size())},
                                         rewriter.getIntegerType(32));

  auto perm_attr = DenseElementsAttr::get(perm_type, llvm::makeArrayRef(perm));
  auto perm_op = rewriter.create<ConstantOp>(loc, perm_type, perm_attr);

  std::vector<int64_t> transposed_shape(shape.begin(), shape.end());
  int64_t r = transposed_shape[dims - 1];
  int64_t c = transposed_shape[dims - 2];

  transposed_shape[dims - 1] = c;
  transposed_shape[dims - 2] = r;

  auto transposed_type =
      RankedTensorType::get(transposed_shape, value_type.getElementType());
  return rewriter.create<TF::TransposeOp>(loc, transposed_type, value, perm_op);
}

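// Emits one MatMul per output batch, pairing LHS/RHS slices via the broadcast
// indices in `bcast` when broadcasting is required, and packs the [rows, cols]
// products into a single [output_batch_size, rows, cols] Tensor.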
template <typename BatchMatMulOpType>
TF::PackOp ConvertTFBatchMatMulOp<BatchMatMulOpType>::createMatMulOps(
    const std::vector<Value>& sliced_lhs, const std::vector<Value>& sliced_rhs,
    const tensorflow::MatMulBCast& bcast, int rows, int cols, Type element_type,
    Location loc, PatternRewriter& rewriter) {
  auto matmul_type = RankedTensorType::get({rows, cols}, element_type);

  std::vector<Value> matmuls;
  for (int batch_idx = 0; batch_idx < bcast.output_batch_size(); ++batch_idx) {
    int lhs_batch_idx, rhs_batch_idx;
    if (bcast.IsBroadcastingRequired()) {
      lhs_batch_idx = bcast.x_batch_indices()[batch_idx];
      rhs_batch_idx = bcast.y_batch_indices()[batch_idx];
    } else {
      lhs_batch_idx = batch_idx;
      rhs_batch_idx = batch_idx;
    }
    auto false_attr = rewriter.getBoolAttr(false);
    auto matmul = rewriter.create<TF::MatMulOp>(loc, matmul_type,
                                                /*a=*/sliced_lhs[lhs_batch_idx],
                                                /*b=*/sliced_rhs[rhs_batch_idx],
                                                /*transpose_a=*/false_attr,
                                                /*transpose_b=*/false_attr);
    matmuls.emplace_back(matmul.product());
  }

  // Combine the result of each individual MatMul into a rank-3 Tensor.
  Type packed_type = RankedTensorType::get(
      {bcast.output_batch_size(), rows, cols}, element_type);

  auto axis = rewriter.getI64IntegerAttr(0);
  return rewriter.create<TF::PackOp>(loc, packed_type,
                                     /*values=*/matmuls, axis);
}

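// Rewrites a single tf.BatchMatMul/V2/V3: validates the operand types,
// materializes adj_x/adj_y as transposes, slices both operands into rank-2
// batches, multiplies them pairwise, packs the products, and reshapes the
// result to the broadcasted output shape.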
template <typename BatchMatMulOpType>
LogicalResult ConvertTFBatchMatMulOp<BatchMatMulOpType>::matchAndRewrite(
    BatchMatMulOpType op, PatternRewriter& rewriter) const {
  Value input_lhs = op.x();
  Value input_rhs = op.y();

  if (!input_lhs.getType().isa<RankedTensorType>()) {
    // LHS must be a ranked tensor type.
    return failure();
  }
  if (!input_rhs.getType().isa<RankedTensorType>()) {
    // RHS must be a ranked tensor type.
    return failure();
  }

  auto lhs_type = input_lhs.getType().cast<RankedTensorType>();
  auto rhs_type = input_rhs.getType().cast<RankedTensorType>();

  // Skip int8 x int8 => int32.
  if (lhs_type.getElementType().isInteger(8) &&
      rhs_type.getElementType().isInteger(8)) {
    return rewriter.notifyMatchFailure(op,
                                       "skip unrolling for int8 BatchMatMulV3");
  }

  auto element_type = lhs_type.getElementType();

  if (element_type != rhs_type.getElementType()) {
    // The element type of the LHS must match the element type of the RHS.
    return failure();
  }

  auto lhs_shape = lhs_type.getShape();
  auto rhs_shape = rhs_type.getShape();

  Location loc = op.getLoc();

  // Ensure that input ranks are at least 2.
  const int dims_a = lhs_shape.size();
  const int dims_b = rhs_shape.size();
  if (dims_a < 2 || dims_b < 2) {
    // Both inputs must have rank >= 2.
    return failure();
  }

  // Transpose LHS input if necessary.
  if (op.adj_x()) {
    input_lhs = createTransposeOp(input_lhs, loc, rewriter);

    lhs_type = input_lhs.getType().cast<RankedTensorType>();
    lhs_shape = lhs_type.getShape();
  }

  // Transpose RHS input if necessary.
  if (op.adj_y()) {
    input_rhs = createTransposeOp(input_rhs, loc, rewriter);

    rhs_type = input_rhs.getType().cast<RankedTensorType>();
    rhs_shape = rhs_type.getShape();
  }

  if (lhs_shape[dims_a - 1] != rhs_shape[dims_b - 2]) {
    // Input dimensions must be compatible for multiplication.
    return failure();
  }

  if (dims_a == 2 && dims_b == 2) {
    // When both inputs are matrices, just replace the op with a MatMul op.
    Type result_type =
        RankedTensorType::get({lhs_shape[0], rhs_shape[1]}, element_type);
    auto false_attr = rewriter.getBoolAttr(false);
    rewriter.replaceOpWithNewOp<TF::MatMulOp>(op, result_type,
                                              /*a=*/input_lhs,
                                              /*b=*/input_rhs,
                                              /*transpose_a=*/false_attr,
                                              /*transpose_b=*/false_attr);
    return success();
  }

  // Input dimensions must be defined. MatMulBCast does not support partial
  // shapes.
  for (auto dim : lhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  for (auto dim : rhs_shape) {
    if (dim == -1) {
      return failure();
    }
  }
  // Ensure that batch shapes are broadcastable.
  tensorflow::MatMulBCast bcast(absl::InlinedVector<tensorflow::int64, 4>(
                                    lhs_shape.begin(), lhs_shape.end()),
                                absl::InlinedVector<tensorflow::int64, 4>(
                                    rhs_shape.begin(), rhs_shape.end()));

  if (!bcast.IsValid()) {
    // Input batch dimensions must be broadcastable.
    return failure();
  }

  // Compute slices for each batch in the LHS and RHS.
  std::vector<Value> sliced_lhs =
      sliceInput(input_lhs, bcast.x_batch_size(), loc, rewriter);
  std::vector<Value> sliced_rhs =
      sliceInput(input_rhs, bcast.y_batch_size(), loc, rewriter);

  // Compute a (single-batch) MatMul for each output batch. The MatMul outputs
  // are then packed together into one output Tensor.
  auto pack_op =
      createMatMulOps(sliced_lhs, sliced_rhs, bcast, lhs_shape[dims_a - 2],
                      rhs_shape[dims_b - 1], element_type, loc, rewriter);

  // Reshape the rank-3 Tensor into the correct output shape.
  const auto& result_batch_shape = bcast.output_batch_shape().dim_sizes();
  std::vector<int64_t> result_shape(result_batch_shape.begin(),
                                    result_batch_shape.end());
  result_shape.push_back(lhs_shape[dims_a - 2]);
  result_shape.push_back(rhs_shape[dims_b - 1]);

  auto reshape_op = createReshapeOp(pack_op.output(), result_shape,
                                    element_type, loc, rewriter);
  rewriter.replaceOp(op, reshape_op.output());
  return success();
}

static PassRegistration<UnrollBatchMatMulPass> pass;

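// Factory for adding the pass to a pass manager programmatically; the
// PassRegistration above also exposes it under the "tf-unroll-batch-matmul"
// flag (see getArgument()).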
std::unique_ptr<OperationPass<FuncOp>> CreateUnrollBatchMatMulPassPass() {
  return std::make_unique<UnrollBatchMatMulPass>();
}

}  // namespace TF
}  // namespace mlir