1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "llvm/ADT/DenseMap.h"
16 #include "llvm/ADT/DenseSet.h"
17 #include "llvm/ADT/StringRef.h"
18 #include "llvm/Support/Debug.h"
19 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
20 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
21 #include "mlir-hlo/utils/hlo_utils.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // TF:llvm-project
23 #include "mlir/IR/BuiltinAttributes.h"
24 #include "mlir/IR/MLIRContext.h" // TF:llvm-project
25 #include "mlir/Pass/Pass.h" // TF:local_config_mlir
26
27 namespace mlir {
28
29 using hlo::kCpu;
30 using hlo::kDiscShapeCalcAttr;
31
32 namespace mhlo {
33 namespace {
34
35 // Check Op if it is a mhlo Op.
isMhloDialect(Operation * op)36 bool isMhloDialect(Operation* op) {
37 return (op->getDialect()->getTypeID() == TypeID::get<mhlo::MhloDialect>());
38 }
39
// This pass explicitly marks each shape-calculating Op by attaching the
// kDiscShapeCalcAttr attribute. Nested FuncOps should be taken into
// consideration.
// The following Ops are shape Ops:
//  - producers of i64 scalar outputs placed on CPU
//  - a shape Op's operands (transitively)
//  - producers of shape operands according to kShapeCalcOperandMap
// The following Ops are regarded as shape Ops:
//  - GetDimensionSizeOp, PrintOp
//  - ConstOp, SelectOp, IotaOp, DynamicIotaOp if the result type is i32
//  - mhlo.dynamic_gather and mhlo.gather if operand 0's element type is i32
//  - Data operands whose type is i32 according to kShapeCalcOperandMap
class MarkShapeCalc : public MarkShapeCalculationPassBase<MarkShapeCalc> {
 public:
  using MarkShapeCalculationPassBase<
      MarkShapeCalc>::MarkShapeCalculationPassBase;

  MarkShapeCalc() = default;
  MarkShapeCalc(const MarkShapeCalc& o) = default;

  // Interns context-owned attributes/identifiers once, so pass execution can
  // use cheap pointer comparisons instead of re-interning strings per op.
  LogicalResult initialize(MLIRContext* context) final {
    // Cache these during initialization to enable pointer comparison during
    // pass execution.
    cpu_placement_attr_ = StringAttr::get(context, kCpu);
    output_placement_attr_key_ =
        Identifier::get(hlo::kOutputPlacementAttr, context);
    true_attr_ = BoolAttr::get(context, true);
    return success();
  }
  void runOnOperation() final;

 private:
  // Mark the shape-calculation subgraph (producers of shape operands).
  void MarkShapeCalcOps();

  // Regard any mhlo Ops that calculate i32 values as shape calculation Ops.
  void MarkRegardAsShapeCalcOps();

  // For the rule-based placement strategy, the placement of an op in this
  // list follows the placement of its dominant operand.
  const DenseMap<TypeID, /*dominant operand index*/ int> kPlaceRuleMap = {
      {TypeID::get<DynamicGatherOp>(), /*operand*/ 0},
      {TypeID::get<GatherOp>(), /*operand*/ 0}};

  // For each listed op type, the operand indices that carry shape (rather
  // than data) information.
  const DenseMap<TypeID, SmallVector<int, 3>> kShapeCalcOperandMap = {
      {TypeID::get<RealDynamicSliceOp>(),
       {/*start_indices*/ 1, /*limit_indices*/ 2, /*strides*/ 3}},
      {TypeID::get<DynamicPadOp>(),
       {/*edge_padding_low*/ 2, /*edge_padding_high*/ 3,
        /*interior_padding*/ 4}},
      {TypeID::get<DynamicReshapeOp>(), {/*shape*/ 1}},
      {TypeID::get<DynamicIotaOp>(), {/*shape*/ 0}},
      {TypeID::get<DynamicBroadcastInDimOp>(), {/*out_dim_size*/ 1}},
      {TypeID::get<DynamicGatherOp>(), {/*slice_sizes*/ 2}},
      {TypeID::get<DynamicConvOp>(), {/*paddings*/ 2}},
      {TypeID::get<IfOp>(), {/*pred*/ 0}}};

  // Add an output Op into the marked set if it is an i64 scalar and its
  // placement is CPU.
  void markI64ReturnedCpuScalarOps(FuncOp func,
                                   DenseSet<Operation*>& shape_calc_ops);
  // Update the marked set:
  // If an Op is in the marked set, add all of its operands to the marked set.
  // Add some operands of dynamic shape Ops into the marked set according to
  // the lookup table.
  void markShapeCalculationOps(FuncOp func,
                               DenseSet<Operation*>& shape_calc_ops);

  // Cached context-owned entities for fast pointer-based access.
  StringAttr cpu_placement_attr_;
  Optional<Identifier> output_placement_attr_key_;
  BoolAttr true_attr_;
};
111
runOnOperation()112 void MarkShapeCalc::runOnOperation() {
113 // Mark shape calculation subgraph
114 MarkShapeCalcOps();
115
116 // Mark any mhlo Ops that calculates I32 as shape calculation Ops
117 MarkRegardAsShapeCalcOps();
118 }
119
120 // Mark the Ops that is the producer of any shape operands
121 // TODO(disc): handle when TupleOp exists in shape_calc_ops
MarkShapeCalcOps()122 void MarkShapeCalc::MarkShapeCalcOps() {
123 ModuleOp module = getOperation();
124 Builder builder(&getContext());
125 llvm::DenseSet<Operation*> shape_calc_ops;
126
127 module.walk([&](FuncOp func) {
128 // Mark the i64 Scalar output as shape calculation Op.
129 // TODO(disc): revisit this if we have outputs on CPU for TF in the future.
130 if (func.getName() == "main")
131 markI64ReturnedCpuScalarOps(func, shape_calc_ops);
132 // Skip if this function is external
133 if (func.isExternal()) return;
134 // no target ops
135 if (llvm::none_of(func.getBlocks().front(),
136 [](Operation& op) { return isMhloDialect(&op); })) {
137 return;
138 }
139 markShapeCalculationOps(func, shape_calc_ops);
140 });
141
142 for (Operation* op : shape_calc_ops) {
143 // We suppose that mhlo op only has single output, either having tensor
144 // type or tuple type.
145 if (auto tp = op->getResult(0).getType().dyn_cast<TupleType>()) {
146 // If an op is placed on cpu, then we suppose all its outputs are
147 // placed on cpu.
148 SmallVector<Attribute> attrs(tp.size(), true_attr_);
149 op->setAttr(kDiscShapeCalcAttr, ArrayAttr::get(tp.getContext(), attrs));
150 } else {
151 op->setAttr(kDiscShapeCalcAttr, true_attr_);
152 }
153 }
154 }
155
156 // Regard any mhlo Ops that calculates i32 as shape Ops. This is an rule based
157 // optimization that mimicking the behavior of tensorflow
MarkRegardAsShapeCalcOps()158 void MarkShapeCalc::MarkRegardAsShapeCalcOps() {
159 ModuleOp module = getOperation();
160 Builder builder(&getContext());
161
162 module.walk([&](Operation* op) {
163 if (!isMhloDialect(op)) return;
164 if (isa<mhlo::TupleOp, mhlo::GetTupleElementOp, mhlo::WhileOp, mhlo::IfOp,
165 mhlo::ReturnOp>(op))
166 return;
167
168 // Skip the Op that is already marked shape Op
169 auto attr = op->getAttrOfType<BoolAttr>(kDiscShapeCalcAttr);
170 if ((attr != nullptr) && (attr.getValue() == true)) return;
171
172 if (isa<mhlo::GetDimensionSizeOp, mhlo::PrintOp>(op)) {
173 op->setAttr(kDiscShapeCalcAttr, true_attr_);
174 return;
175 }
176
177 // Ops that only cares about the output element type
178 if (isa<mhlo::ConstOp, mhlo::SelectOp, mhlo::IotaOp, mhlo::DynamicIotaOp>(
179 op)) {
180 auto result_ty = op->getResult(0).getType().dyn_cast<RankedTensorType>();
181 assert(result_ty && "unexpected non ranked type for ConstOp");
182 auto elem_type = result_ty.getElementType();
183 if (elem_type.isInteger(32)) {
184 op->setAttr(kDiscShapeCalcAttr, true_attr_);
185 }
186 return;
187 }
188
189 auto op_type_id = op->getAbstractOperation()->typeID;
190 bool is_shape_calc_op = false;
191 // Follow the rule of kPlaceRuleMap exist, or else follow
192 // kShapeCalcOperandMap
193 auto it = kPlaceRuleMap.find(op_type_id);
194 if (it != kPlaceRuleMap.end()) {
195 auto dominant_idx = it->second;
196 auto operand_ty =
197 op->getOperand(dominant_idx).getType().dyn_cast<RankedTensorType>();
198 assert(operand_ty && "unexpected non unranked type of operand");
199 if (operand_ty.getElementType().isInteger(32)) {
200 is_shape_calc_op = true;
201 }
202 } else {
203 auto iter = kShapeCalcOperandMap.find(op_type_id);
204 if (iter != kShapeCalcOperandMap.end()) {
205 const SmallVector<int, 3>& shape_operand_indices = iter->second;
206 for (int idx : shape_operand_indices) {
207 auto operand_ty =
208 op->getOperand(idx).getType().dyn_cast<RankedTensorType>();
209 if (!operand_ty) continue;
210 auto elem_type = operand_ty.getElementType();
211 if (elem_type.isInteger(32)) {
212 is_shape_calc_op = true;
213 break;
214 }
215 }
216 }
217 }
218 // Set attr if it is a shape Op
219 if (is_shape_calc_op) {
220 if (auto tp = op->getResult(0).getType().dyn_cast<TupleType>()) {
221 SmallVector<Attribute, 4> attrs(tp.size(), true_attr_);
222 op->setAttr(kDiscShapeCalcAttr, ArrayAttr::get(tp.getContext(), attrs));
223 } else {
224 op->setAttr(kDiscShapeCalcAttr, true_attr_);
225 }
226 }
227 return;
228 });
229 }
230
markI64ReturnedCpuScalarOps(FuncOp func,llvm::DenseSet<Operation * > & shape_calc_ops)231 void MarkShapeCalc::markI64ReturnedCpuScalarOps(
232 FuncOp func, llvm::DenseSet<Operation*>& shape_calc_ops) {
233 assert(func.getName() == "main");
234 auto return_op = func.front().getTerminator();
235 if (!isa<mlir::ReturnOp>(return_op)) return;
236 auto result_attrs = func.getAllResultAttrs();
237 if (!result_attrs) return;
238 auto returned_ops = return_op->getOperands();
239 assert(returned_ops.size() == result_attrs.size());
240 for (auto output : llvm::enumerate(returned_ops)) {
241 Operation* op = output.value().getDefiningOp();
242 if (!op || !isMhloDialect(op)) continue;
243 int idx = output.index();
244 if (auto type = op->getResult(0).getType().dyn_cast<RankedTensorType>()) {
245 if (type.getElementType().isInteger(64) && (type.getRank() == 0) &&
246 (result_attrs[idx].cast<DictionaryAttr>().getAs<StringAttr>(
247 *output_placement_attr_key_) == cpu_placement_attr_))
248 shape_calc_ops.insert(op);
249 }
250 }
251 }
252
markShapeCalculationOps(FuncOp func,llvm::DenseSet<Operation * > & shape_calc_ops)253 void MarkShapeCalc::markShapeCalculationOps(
254 FuncOp func, llvm::DenseSet<Operation*>& shape_calc_ops) {
255 auto& block = func.getBlocks().front();
256 for (Operation& op : block) {
257 if (!isMhloDialect(&op)) continue;
258
259 // If the op is already in shape calculation op set, insert all of its
260 // operands into shape calculation op set
261 if (shape_calc_ops.contains(&op)) {
262 for (auto operand_value : op.getOperands()) {
263 Operation* operand = operand_value.getDefiningOp();
264 if (operand == nullptr || !isMhloDialect(operand)) continue;
265 shape_calc_ops.insert(operand);
266 }
267 } else {
268 // Mark operands into shape calculation set according to the lookup table.
269 auto op_type_id = op.getAbstractOperation()->typeID;
270 auto iter = kShapeCalcOperandMap.find(op_type_id);
271 if (iter != kShapeCalcOperandMap.end()) {
272 for (auto operand_idx : iter->second) {
273 auto operand = op.getOperand(operand_idx).getDefiningOp();
274 if (operand == nullptr || !isMhloDialect(operand)) continue;
275 shape_calc_ops.insert(operand);
276 }
277 }
278 }
279 // TODO(disc): If the operand of the op is a nested FuncOp, mark the
280 // associated producer in the nested FuncOp
281 }
282 }
283
284 } // namespace
285
createMarkShapeCalcOpPass()286 std::unique_ptr<OperationPass<ModuleOp>> createMarkShapeCalcOpPass() {
287 return std::make_unique<MarkShapeCalc>();
288 }
289
290 } // namespace mhlo
291 } // namespace mlir
292