/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file canonicalizes reduction ops in the HLO dialect to match the
// capabilities of the codegen backend.
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"

23 namespace mlir {
24 namespace mhlo {
25 namespace {
26 
27 // All the reduce ops can be divided into following four types:
28 //  - a) column reduction, only reduce the most significant dimensions.
29 //  - b) row reduction, only reduce the least significant dimensions.
30 //  - c) reduce to scalar, all dimensions are reduced.
31 //  - d) others. (not support now, maybe use transpose to canonicalize)
32 //
33 // Currently we do following canonicalization to match the capacity of codegen
34 // backend.
35 //
36 // For case a):
37 // ====================================================================================
38 //   we convert all column reduction to rank-2 column reduction.
39 //   For example, suppose we have:
40 //   ```
41 //     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
42 //       ...
43 //       %2 = "mhlo.reduce"(%arg0, ...) ( {...})
44 //         {dimensions = dense<[0]> : tensor<1xi64>} :
45 //         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
46 //       return %2 : tensor<?x?xf32>
47 //     }
48 //  ```
49 //   After conversion:
50 //     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
51 //       // [a, b, c] -> [a, b*c]
52 //       %1 = mhlo.dynamic_reshape(%arg0, ...) : (tensor<?x?x?xf32>,
53 //       tensor<2xi64>) -> tensor<?x?xf32> %2 = "mhlo.reduce"(%1, ...) ( {...})
54 //         {dimensions = dense<[0]> : tensor<1xi64>} :
55 //         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
56 //       %3 = "mhlo.dynamic_reshape"(%2, ...) : (tensor<?xf32>, tensor<1xi64>)
57 //       -> tensor<?x?f32> return %3 : tensor<?x?xf32>
58 //     }
59 //  ```
60 //
61 // For case b):
62 // ====================================================================================
63 //   we convert all row reduction to rank-2 row reduction.
64 //   For example, suppose we have:
65 //   ```
66 //     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
67 //       ...
68 //       %2 = "mhlo.reduce"(%arg0, ...) ( {...})
69 //         {dimensions = dense<[2]> : tensor<1xi64>} :
70 //         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<?x?xf32>
71 //       return %2 : tensor<?x?xf32>
72 //     }
73 //  ```
74 //   After conversion:
75 //     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<?x?xf32> {
76 //       // [a, b, c] -> [a*b, c]
77 //       %1 = mhlo.dynamic_reshape(%arg0, ...) : (tensor<?x?x?xf32>,
78 //       tensor<2xi64>) -> tensor<?x?xf32> %2 = "mhlo.reduce"(%1, ...) ( {...})
79 //         {dimensions = dense<[1]> : tensor<1xi64>} :
80 //         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
81 //       %3 = "mhlo.dynamic_reshape"(%2, ...) : (tensor<?xf32>, tensor<1xi64>)
82 //       -> tensor<?x?f32> return %3 : tensor<?x?xf32>
83 //     }
84 //  ```
85 //
86 // For case c):
87 // ====================================================================================
88 //   we convert all reduce-to-scalar to rank-2 column reduction.
89 //
90 //   For example, suppose we have:
91 //   ```
92 //     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
93 //       ...
94 //       %2 = "mhlo.reduce"(%arg0, ...) ( {...})
95 //         {dimensions = dense<[0,1,2]> : tensor<3xi64>} :
96 //         (tensor<?x?x?xf32>, tensor<f32>) -> tensor<f32>
97 //       return %2 : tensor<f32>
98 //     }
99 //  ```
100 //   After conversion:
101 //     func @test(%arg0: tensor<?x?x?xf32>) -> tensor<f32> {
102 //       // [a, b, c] -> [a*b*c, 1]
103 //       %1 = mhlo.dynamic_reshape(%arg0, ...) : (tensor<?x?x?xf32>,
104 //       tensor<2xi64>) -> tensor<?x?xf32> %2 = "mhlo.reduce"(%1, ...) ( {...})
105 //         {dimensions = dense<[0]> : tensor<1xi64>} :
106 //         (tensor<?x?xf32>, tensor<f32>) -> tensor<?xf32>
107 //       %3 = "mhlo.reshape"(%2, ...) : (tensor<?xf32>, tensor<1xi64>) ->
108 //       tensor<f32> return %3 : tensor<f32>
109 //     }
110 //  ```
111 
112 struct HloCanonicalizeReductionPass
113     : HloCanonicalizeReductionPassBase<HloCanonicalizeReductionPass> {
getDependentDialectsmlir::mhlo::__anon8940eb270111::HloCanonicalizeReductionPass114   void getDependentDialects(DialectRegistry& registry) const override {
115     registry.insert<tensor::TensorDialect>();
116   }
runOnFunctionmlir::mhlo::__anon8940eb270111::HloCanonicalizeReductionPass117   void runOnFunction() override {
118     getFunction().walk([&](ReduceOp op) {
119       SmallVector<int64_t, 4> dims_to_reduce;
120       DenseSet<int64_t> dims_to_reduce_set;
121       for (auto dim : op.dimensions().getIntValues()) {
122         dims_to_reduce.push_back(dim.getSExtValue());
123         dims_to_reduce_set.insert(dims_to_reduce.back());
124       }
125 
126       // empty reduction is just a no-op, thus no need to do codegen.
127       if (dims_to_reduce.empty()) return;
128 
129       // suppose reduce input is a ranked tensor
130       auto ty = op.getOperand(0).getType().dyn_cast<RankedTensorType>();
131       if (!ty) return signalPassFailure();
132       int rank = ty.getRank();
133       int ndims_to_reduce = dims_to_reduce.size();
134       auto elem_ty = ty.getElementType();
135       llvm::sort(dims_to_reduce);
136 
137       // skip case d) form since we don't support it.
138       if ((dims_to_reduce.back() - dims_to_reduce[0]) !=
139               (ndims_to_reduce - 1) ||
140           (dims_to_reduce[0] != 0 && dims_to_reduce.back() != (rank - 1))) {
141         return;
142       }
143 
144       // rank 2 row/column reduction is already supported.
145       if (rank == 2 && ndims_to_reduce == 1) {
146         return;
147       }
148 
149       SmallVector<int64_t, 4> dims_to_keep;
150       for (int i = 0; i < rank; ++i) {
151         if (!dims_to_reduce_set.count(i)) dims_to_keep.push_back(i);
152       }
153 
154       OpBuilder b(op);
155       auto loc = op.getLoc();
156       // TODO(disc): uniformed shape_scalar_type with shape_derivation
157       auto shape_scalar_type = b.getIntegerType(32);
158       auto one = b.create<ConstantIntOp>(loc, 1ll, shape_scalar_type);
159 
160       // funtion to get total elements in selected dimensions
161       auto dim_prod = [&](ArrayRef<int64_t> dims) {
162         Value nelems = one;
163         for (int64_t v : dims) {
164           Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), v);
165           nelems = b.create<MulIOp>(
166               loc, nelems,
167               b.create<IndexCastOp>(loc, dim_index, shape_scalar_type));
168         }
169         return nelems;
170       };
171 
172       SmallVector<Value, 2> new_operand_dims;
173       DenseIntElementsAttr attr;
174       Value nelem_to_reduce = dim_prod(dims_to_reduce);
175       Value nelem_to_keep = dim_prod(dims_to_keep);
176       if (rank == ndims_to_reduce) {
177         // case c) Reduce to scalar.
178         // Currently we don't support reduce to scalar directly.
179         // As a workaround, we convert the `reduce to scalar` to a rank 2
180         // column reduction having following form:
181         // Suppose nelems = ProdutionOp(ShapeOp(I)), We convert I into
182         // shape `[nelems, 1]`.
183         // TODO(disc): this may have performance issue. Implements a reduce to
184         // scalar schedule if necessary.
185         new_operand_dims.push_back(nelem_to_reduce);
186         new_operand_dims.push_back(nelem_to_keep);
187         attr = DenseIntElementsAttr::get(
188             RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
189       } else if (dims_to_reduce[0] == 0) {
190         // case a) column reduction
191         new_operand_dims.push_back(nelem_to_reduce);
192         new_operand_dims.push_back(nelem_to_keep);
193         attr = DenseIntElementsAttr::get(
194             RankedTensorType::get({1}, b.getIntegerType(64)), {0ll});
195       } else {
196         // case b) row reduction
197         new_operand_dims.push_back(nelem_to_keep);
198         new_operand_dims.push_back(nelem_to_reduce);
199         attr = DenseIntElementsAttr::get(
200             RankedTensorType::get({1}, b.getIntegerType(64)), {1ll});
201       }
202 
203       Value new_operand_shape =
204           b.create<tensor::FromElementsOp>(loc, new_operand_dims);
205 
206       SmallVector<Value, 4> new_operands;
207       for (Value operand : op.inputs()) {
208         new_operands.push_back(b.create<DynamicReshapeOp>(
209             loc,
210             RankedTensorType::get(
211                 SmallVector<int64_t, 4>(new_operand_dims.size(),
212                                         ShapedType::kDynamicSize),
213                 elem_ty),
214             operand, new_operand_shape));
215       }
216       auto new_op =
217           b.create<ReduceOp>(loc, new_operands, op.init_values(), attr);
218       new_op.body().takeBody(op.body());
219 
220       SmallVector<Value, 4> new_results;
221       if (dims_to_keep.empty()) {
222         // case c) reduce to scalar
223         // reshape rank 1 tensor with size 1 to a rank 0 tensor
224         for (Value result : new_op.getResults()) {
225           new_results.push_back(b.create<ReshapeOp>(
226               loc, RankedTensorType::get({}, elem_ty), result));
227         }
228       } else {
229         SmallVector<Value, 4> result_dims;
230         for (int64_t i : dims_to_keep) {
231           Value dim_index = b.create<tensor::DimOp>(loc, op.getOperand(0), i);
232           result_dims.push_back(
233               b.create<IndexCastOp>(loc, dim_index, shape_scalar_type));
234         }
235         Value result_shape = b.create<tensor::FromElementsOp>(loc, result_dims);
236         for (auto&& e : llvm::zip(op.getResults(), new_op.getResults())) {
237           new_results.push_back(b.create<DynamicReshapeOp>(
238               loc, std::get<0>(e).getType(), std::get<1>(e), result_shape));
239         }
240       }
241       for (auto&& e : llvm::zip(op.getResults(), new_results)) {
242         std::get<0>(e).replaceAllUsesWith(std::get<1>(e));
243       }
244       op.erase();
245     });
246   }
247 };
248 
249 }  // namespace
250 
createHloCanonicalizeReductionPass()251 std::unique_ptr<FunctionPass> createHloCanonicalizeReductionPass() {
252   return std::make_unique<HloCanonicalizeReductionPass>();
253 }
254 
255 }  // namespace mhlo
256 }  // namespace mlir