1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include "llvm/ADT/DenseMap.h" 17 #include "llvm/Support/Casting.h" 18 #include "mlir-hlo/Dialect/mhlo/IR/hlo_ops.h" 19 #include "mlir-hlo/Dialect/mhlo/transforms/PassDetail.h" 20 #include "mlir/Dialect/StandardOps/IR/Ops.h" 21 #include "mlir/IR/Operation.h" 22 #include "mlir/Pass/Pass.h" 23 #include "mlir/Pass/PassManager.h" 24 #include "mlir/Support/LLVM.h" 25 #include "mlir/Transforms/RegionUtils.h" 26 27 namespace mlir { 28 namespace mhlo { 29 30 namespace { 31 32 // A pass that sinks constants implicitly captured in control flow regions. This 33 // is necessary to export to XLA. 34 // 35 // TODO(hinsu): Generalize this pass to handle all the ops with regions. Any 36 // value used within the region that is defined outside of op's region should be 37 // sank to the regions and not just the constants. Ops such as If and While 38 // whose computations doesn't require fixed signature like Sort or Reduce have 39 // an option to pass outside values as operands of the op to avoid recomputing 40 // those within internally. Note that doing so is the only option in case of 41 // values defined outside that are BlockArguments of any of the parent region. 42 class SinkConstantsToControlFlowPass 43 : public SinkConstantsToControlFlowPassBase< 44 SinkConstantsToControlFlowPass> { runOnFunction()45 void runOnFunction() override { 46 getFunction().walk([](Operation* op) { 47 if (auto while_op = llvm::dyn_cast<WhileOp>(op)) { 48 SinkToRegion(&while_op.body()); 49 SinkToRegion(&while_op.cond()); 50 } else if (auto if_op = llvm::dyn_cast<IfOp>(op)) { 51 SinkToRegion(&if_op.true_branch()); 52 SinkToRegion(&if_op.false_branch()); 53 } else if (auto reduce_window_op = llvm::dyn_cast<ReduceWindowOp>(op)) { 54 SinkToRegion(&reduce_window_op.body()); 55 } else if (auto sort_op = llvm::dyn_cast<SortOp>(op)) { 56 SinkToRegion(&sort_op.comparator()); 57 } 58 }); 59 } 60 61 private: 62 // Performs constant sinking into a region. SinkToRegion(Region * region)63 static void SinkToRegion(Region* region) { 64 llvm::DenseMap<Value, Operation*> sunk_constant; 65 visitUsedValuesDefinedAbove({*region}, [&](OpOperand* use) { 66 Value constant = use->get(); 67 auto op = constant.getDefiningOp(); 68 if (!op || !op->hasTrait<OpTrait::ConstantLike>()) return; 69 auto map_entry = sunk_constant.try_emplace(constant, nullptr); 70 if (!map_entry.second) { 71 // This constant has already been cloned into the region, reuse it. 72 use->set(map_entry.first->getSecond()->getResult(0)); 73 if (op->use_empty()) op->erase(); 74 return; 75 } 76 if (constant.hasOneUse()) { 77 op->moveBefore(®ion->front().front()); 78 return; 79 } 80 map_entry.first->getSecond() = op->clone(); 81 region->front().getOperations().insert(region->front().begin(), 82 map_entry.first->getSecond()); 83 use->set(map_entry.first->getSecond()->getResult(0)); 84 }); 85 } 86 }; 87 88 } // anonymous namespace 89 90 // TODO(hinsu): Rename this pass and move to a different file along with the 91 // generalization to make all ops isolated from above. createSinkConstantsToControlFlowPass()92std::unique_ptr<OperationPass<FuncOp>> createSinkConstantsToControlFlowPass() { 93 return std::make_unique<SinkConstantsToControlFlowPass>(); 94 } 95 96 } // namespace mhlo 97 } // namespace mlir 98