/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Debug.h"
#include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h"
#include "mlir-hlo/utils/hlo_utils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"  // TF:llvm-project
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/MLIRContext.h"  // TF:llvm-project
#include "mlir/Pass/Pass.h"       // TF:local_config_mlir

namespace mlir {

using hlo::kCpu;
using hlo::kDiscShapeCalcAttr;

namespace mhlo {
namespace {

// Check whether an Op belongs to the mhlo dialect.
bool isMhloDialect(Operation* op) {
  return (op->getDialect()->getTypeID() == TypeID::get<mhlo::MhloDialect>());
}

// This pass explicitly marks shape-calculating Ops by adding an attribute.
// Nested FuncOps should be taken into consideration.
// The following Ops are shape Ops:
//  - i64 scalar outputs
//  - Shape Ops' operands
//  - Shape operands according to kShapeCalcOperandMap
// The following Ops are regarded as shape Ops:
//  - GetDimensionSizeOp, PrintOp
//  - ConstOp, SelectOp, IotaOp, DynamicIotaOp if the result type is i32
//  - mhlo.dynamic_gather and mhlo.gather if operand_0's type is i32
//  - Data operands whose type is i32, according to kShapeCalcOperandMap
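//
// Illustrative sketch (hypothetical IR, not taken from a real test case; the
// attribute key is whatever string kDiscShapeCalcAttr resolves to, shown here
// as <shape-calc-attr>, and the tensor shapes are made up). An op feeding the
// shape operand of mhlo.dynamic_reshape would end up marked roughly like this:
//
//   %shape = "mhlo.concatenate"(%d0, %d1)
//       {dimension = 0 : i64, <shape-calc-attr> = true}
//       : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
//   %out = "mhlo.dynamic_reshape"(%arg, %shape)
//       : (tensor<?x?xf32>, tensor<2xi32>) -> tensor<?x?xf32>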
class MarkShapeCalc : public MarkShapeCalculationPassBase<MarkShapeCalc> {
 public:
  using MarkShapeCalculationPassBase<
      MarkShapeCalc>::MarkShapeCalculationPassBase;

  MarkShapeCalc() = default;
  MarkShapeCalc(const MarkShapeCalc& o) = default;

  LogicalResult initialize(MLIRContext* context) final {
    // Cache these during initialization to enable pointer comparison during
    // pass execution.
    cpu_placement_attr_ = StringAttr::get(context, kCpu);
    output_placement_attr_key_ =
        Identifier::get(hlo::kOutputPlacementAttr, context);
    true_attr_ = BoolAttr::get(context, true);
    return success();
  }
  void runOnOperation() final;

 private:
  // Mark the shape calculation subgraph.
  void MarkShapeCalcOps();

  // Regard any mhlo Ops that calculate i32 values as shape calculation Ops.
  void MarkRegardAsShapeCalcOps();

  // For the rule-based placement strategy, the placement of an op in this
  // list is determined by the placement of its dominant operand.
  const DenseMap<TypeID, /*dominant operand index*/ int> kPlaceRuleMap = {
      {TypeID::get<DynamicGatherOp>(), /*operand*/ 0},
      {TypeID::get<GatherOp>(), /*operand*/ 0}};

  const DenseMap<TypeID, SmallVector<int, 3>> kShapeCalcOperandMap = {
      {TypeID::get<RealDynamicSliceOp>(),
       {/*start_indices*/ 1, /*limit_indices*/ 2, /*strides*/ 3}},
      {TypeID::get<DynamicPadOp>(),
       {/*edge_padding_low*/ 2, /*edge_padding_high*/ 3,
        /*interior_padding*/ 4}},
      {TypeID::get<DynamicReshapeOp>(), {/*shape*/ 1}},
      {TypeID::get<DynamicIotaOp>(), {/*shape*/ 0}},
      {TypeID::get<DynamicBroadcastInDimOp>(), {/*out_dim_size*/ 1}},
      {TypeID::get<DynamicGatherOp>(), {/*slice_sizes*/ 2}},
      {TypeID::get<DynamicConvOp>(), {/*paddings*/ 2}},
      {TypeID::get<IfOp>(), {/*pred*/ 0}}};
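  // How this table is consumed (summary of the two methods below): for an op
  // whose TypeID appears here, markShapeCalculationOps() marks the mhlo
  // producers of the listed operand indices as shape calculation ops, and
  // MarkRegardAsShapeCalcOps() regards the op itself as a shape op when any
  // listed operand has an i32 element type. For example, operand #1 (the
  // target shape) of DynamicReshapeOp is treated this way.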

  // Add an output Op into the marked set if it is an i64 scalar and its
  // placement is CPU.
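  // Illustrative sketch (hypothetical IR, names made up): if main returns
  //   %cnt = "mhlo.reduce"(...) : (...) -> tensor<i64>
  // and the corresponding result attribute dictionary carries the output
  // placement key (hlo::kOutputPlacementAttr) with value kCpu, then the
  // defining op of %cnt is inserted into shape_calc_ops.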
  void markI64ReturnedCpuScalarOps(FuncOp func,
                                   DenseSet<Operation*>& shape_calc_ops);
  // Update the marked set:
  // If an Op is in the marked set, add all of its operands to the marked set.
  // Also add some operands of dynamic-shape Ops to the marked set according
  // to the lookup table.
  void markShapeCalculationOps(FuncOp func,
                               DenseSet<Operation*>& shape_calc_ops);

  // Cached context-owned entities for fast pointer-based access.
  StringAttr cpu_placement_attr_;
  Optional<Identifier> output_placement_attr_key_;
  BoolAttr true_attr_;
};

void MarkShapeCalc::runOnOperation() {
  // Mark the shape calculation subgraph.
  MarkShapeCalcOps();

  // Mark any mhlo Ops that calculate i32 values as shape calculation Ops.
  MarkRegardAsShapeCalcOps();
}

// Mark the Ops that are producers of any shape operands.
// TODO(disc): handle when TupleOp exists in shape_calc_ops
void MarkShapeCalc::MarkShapeCalcOps() {
  ModuleOp module = getOperation();
  Builder builder(&getContext());
  llvm::DenseSet<Operation*> shape_calc_ops;

  module.walk([&](FuncOp func) {
    // Mark the i64 scalar outputs as shape calculation Ops.
    // TODO(disc): revisit this if we have outputs on CPU for TF in the future.
    if (func.getName() == "main")
      markI64ReturnedCpuScalarOps(func, shape_calc_ops);
    // Skip if this function is external.
    if (func.isExternal()) return;
    // No target ops.
    if (llvm::none_of(func.getBlocks().front(),
                      [](Operation& op) { return isMhloDialect(&op); })) {
      return;
    }
    markShapeCalculationOps(func, shape_calc_ops);
  });

  for (Operation* op : shape_calc_ops) {
    // We assume that a mhlo op has a single output, of either tensor type or
    // tuple type.
    if (auto tp = op->getResult(0).getType().dyn_cast<TupleType>()) {
      // If an op is placed on cpu, we assume all of its outputs are placed
      // on cpu as well.
      SmallVector<Attribute> attrs(tp.size(), true_attr_);
      op->setAttr(kDiscShapeCalcAttr, ArrayAttr::get(tp.getContext(), attrs));
    } else {
      op->setAttr(kDiscShapeCalcAttr, true_attr_);
    }
  }
}

// Regard any mhlo Ops that calculate i32 values as shape Ops. This is a
// rule-based optimization that mimics the behavior of TensorFlow.
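// Illustrative sketch (hypothetical IR; the exact printed form of the
// constant may differ between mhlo versions): an i32-producing op such as
//   %sizes = mhlo.constant dense<[2, 4]> : tensor<2xi32>
// would be tagged with kDiscShapeCalcAttr = true by the walk below, whereas
// an f32 or i64 constant would be left untouched by this rule.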
void MarkShapeCalc::MarkRegardAsShapeCalcOps() {
  ModuleOp module = getOperation();
  Builder builder(&getContext());

  module.walk([&](Operation* op) {
    if (!isMhloDialect(op)) return;
    if (isa<mhlo::TupleOp, mhlo::GetTupleElementOp, mhlo::WhileOp, mhlo::IfOp,
            mhlo::ReturnOp>(op))
      return;

    // Skip Ops that are already marked as shape Ops.
    auto attr = op->getAttrOfType<BoolAttr>(kDiscShapeCalcAttr);
    if ((attr != nullptr) && (attr.getValue() == true)) return;

    if (isa<mhlo::GetDimensionSizeOp, mhlo::PrintOp>(op)) {
      op->setAttr(kDiscShapeCalcAttr, true_attr_);
      return;
    }

    // Ops that only care about the output element type.
    if (isa<mhlo::ConstOp, mhlo::SelectOp, mhlo::IotaOp, mhlo::DynamicIotaOp>(
            op)) {
      auto result_ty = op->getResult(0).getType().dyn_cast<RankedTensorType>();
      assert(result_ty && "unexpected unranked result type");
      auto elem_type = result_ty.getElementType();
      if (elem_type.isInteger(32)) {
        op->setAttr(kDiscShapeCalcAttr, true_attr_);
      }
      return;
    }

    auto op_type_id = op->getAbstractOperation()->typeID;
    bool is_shape_calc_op = false;
    // Follow the rule in kPlaceRuleMap if one exists; otherwise follow
    // kShapeCalcOperandMap.
    auto it = kPlaceRuleMap.find(op_type_id);
    if (it != kPlaceRuleMap.end()) {
      auto dominant_idx = it->second;
      auto operand_ty =
          op->getOperand(dominant_idx).getType().dyn_cast<RankedTensorType>();
      assert(operand_ty && "unexpected unranked operand type");
      if (operand_ty.getElementType().isInteger(32)) {
        is_shape_calc_op = true;
      }
    } else {
      auto iter = kShapeCalcOperandMap.find(op_type_id);
      if (iter != kShapeCalcOperandMap.end()) {
        const SmallVector<int, 3>& shape_operand_indices = iter->second;
        for (int idx : shape_operand_indices) {
          auto operand_ty =
              op->getOperand(idx).getType().dyn_cast<RankedTensorType>();
          if (!operand_ty) continue;
          auto elem_type = operand_ty.getElementType();
          if (elem_type.isInteger(32)) {
            is_shape_calc_op = true;
            break;
          }
        }
      }
    }
    // Set the attr if it is a shape Op.
    if (is_shape_calc_op) {
      if (auto tp = op->getResult(0).getType().dyn_cast<TupleType>()) {
        SmallVector<Attribute, 4> attrs(tp.size(), true_attr_);
        op->setAttr(kDiscShapeCalcAttr, ArrayAttr::get(tp.getContext(), attrs));
      } else {
        op->setAttr(kDiscShapeCalcAttr, true_attr_);
      }
    }
    return;
  });
}

void MarkShapeCalc::markI64ReturnedCpuScalarOps(
    FuncOp func, llvm::DenseSet<Operation*>& shape_calc_ops) {
  assert(func.getName() == "main");
  auto return_op = func.front().getTerminator();
  if (!isa<mlir::ReturnOp>(return_op)) return;
  auto result_attrs = func.getAllResultAttrs();
  if (!result_attrs) return;
  auto returned_ops = return_op->getOperands();
  assert(returned_ops.size() == result_attrs.size());
  for (auto output : llvm::enumerate(returned_ops)) {
    Operation* op = output.value().getDefiningOp();
    if (!op || !isMhloDialect(op)) continue;
    int idx = output.index();
    if (auto type = op->getResult(0).getType().dyn_cast<RankedTensorType>()) {
      if (type.getElementType().isInteger(64) && (type.getRank() == 0) &&
          (result_attrs[idx].cast<DictionaryAttr>().getAs<StringAttr>(
               *output_placement_attr_key_) == cpu_placement_attr_))
        shape_calc_ops.insert(op);
    }
  }
}

void MarkShapeCalc::markShapeCalculationOps(
    FuncOp func, llvm::DenseSet<Operation*>& shape_calc_ops) {
  auto& block = func.getBlocks().front();
  for (Operation& op : block) {
    if (!isMhloDialect(&op)) continue;

    // If the op is already in the shape calculation op set, insert all of its
    // operands into the shape calculation op set.
    if (shape_calc_ops.contains(&op)) {
      for (auto operand_value : op.getOperands()) {
        Operation* operand = operand_value.getDefiningOp();
        if (operand == nullptr || !isMhloDialect(operand)) continue;
        shape_calc_ops.insert(operand);
      }
    } else {
      // Mark operands into the shape calculation set according to the lookup
      // table.
      auto op_type_id = op.getAbstractOperation()->typeID;
      auto iter = kShapeCalcOperandMap.find(op_type_id);
      if (iter != kShapeCalcOperandMap.end()) {
        for (auto operand_idx : iter->second) {
          auto operand = op.getOperand(operand_idx).getDefiningOp();
          if (operand == nullptr || !isMhloDialect(operand)) continue;
          shape_calc_ops.insert(operand);
        }
      }
    }
    // TODO(disc): If the operand of the op is a nested FuncOp, mark the
    // associated producer in the nested FuncOp
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> createMarkShapeCalcOpPass() {
  return std::make_unique<MarkShapeCalc>();
}
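// Usage sketch (assumption: a standard mlir::PassManager set-up; this is not
// taken from this repository's pipeline code):
//
//   mlir::PassManager pm(&context);
//   pm.addPass(mhlo::createMarkShapeCalcOpPass());
//   if (failed(pm.run(module))) { /* handle the error */ }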

}  // namespace mhlo
}  // namespace mlir