1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
// This pass hoists the body of each `tf_device.replicate` op, cloning every op
// in the body once per replica, and sets the `device` attribute of cloned
// TensorFlow dialect ops from the `devices` attribute on the
// `tf_device.replicate`.
19
#include <memory>

#include "llvm/ADT/DenseMap.h"  // from @llvm-project
#include "llvm/ADT/Optional.h"  // from @llvm-project
#include "llvm/ADT/Sequence.h"  // from @llvm-project
#include "llvm/ADT/SmallVector.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25
26 namespace mlir {
27 namespace TFDevice {
28 namespace {
29
30 constexpr char kDeviceAttr[] = "device";
31
32 class TFDeviceReplicationPass
33 : public PassWrapper<TFDeviceReplicationPass, OperationPass<ModuleOp>> {
34 public:
getDependentDialects(DialectRegistry & registry) const35 void getDependentDialects(DialectRegistry ®istry) const override {
36 registry.insert<TF::TensorFlowDialect>();
37 }
38
getArgument() const39 StringRef getArgument() const final { return "tf-device-replication"; }
40
getDescription() const41 StringRef getDescription() const final {
42 return "Hoists and replicates the tf_device.replicate inner ops once for "
43 "each associated device.";
44 }
45
runOnOperation()46 void runOnOperation() override {
47 ModuleOp module = getOperation();
48 const Dialect *tf_dialect = getContext().getLoadedDialect("tf");
49 module.walk([&](tf_device::ReplicateOp replicate_op) {
50 OpBuilder builder(replicate_op);
51 // Map from the existing operation in ReplicateOp's region to a list of
52 // its replicated operations.
53 llvm::DenseMap<Operation *, llvm::SmallVector<Operation *, 4>>
54 operation_map;
55 llvm::Optional<DictionaryAttr> devices = replicate_op.devices();
56 const int replicate_num = replicate_op.n();
57
58 // Replicates every operation in the region of the ReplicateOp to match
59 // the number of devices.
60 for (int i : llvm::seq<int>(0, replicate_num)) {
61 // Gets the mapping from the packed and replicated block arguments to
62 // the actual value. This mapping is used to replace the arguments used
63 // by the cloned operations.
64 BlockAndValueMapping mapping;
65 for (BlockArgument &arg : replicate_op.GetBody().getArguments()) {
66 Value new_arg =
67 replicate_op.GetReplicaOperandForBlockArgument(arg, i);
68 mapping.map(arg, new_arg);
69 }
70 for (Operation &op : replicate_op.GetBody().without_terminator()) {
71 // Clones the operation and places it outside the replicate_op's body.
72 llvm::SmallVector<Operation *, 4> &new_ops = operation_map[&op];
73 Operation *new_op = builder.clone(op, mapping);
74 new_ops.push_back(new_op);
75 // If the op is a TF op, it has a string-valued device attribute and
76 // the replicate_op has a list of devices corresponding to this device
77 // attribute's value, updates the device attribute for this op.
78 if (!devices) continue;
79
80 if (op.getDialect() != tf_dialect) continue;
81
82 StringAttr device_alias =
83 new_op->getAttrOfType<StringAttr>(kDeviceAttr);
84 if (!device_alias) continue;
85
86 Attribute new_devices = devices->get(device_alias.getValue());
87 if (!new_devices) continue;
88
89 ArrayAttr new_devices_array = new_devices.cast<ArrayAttr>();
90 new_op->setAttr(kDeviceAttr, new_devices_array[i].cast<StringAttr>());
91 }
92 }
93 // Replaces usages of the existing results of the tf_device.replicate
94 // op with the results of the newly replicated operations.
95 llvm::SmallVector<Value, 4> new_results;
96 for (Value v : replicate_op.GetBody().getTerminator()->getOperands()) {
97 OpResult result = v.dyn_cast<OpResult>();
98 // Uses the original value if the value is not an OpResult.
99 if (!result) {
100 for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
101 continue;
102 }
103 // Uses the original value if the value is defined by an op outside the
104 // tf_device.replicate's body.
105 Operation *op = result.getDefiningOp();
106 if (operation_map.find(op) == operation_map.end()) {
107 for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
108 continue;
109 }
110 // Uses the values defined by the newly replicated operations.
111 int result_num = result.getResultNumber();
112 for (Operation *new_op : operation_map[op]) {
113 new_results.push_back(new_op->getResult(result_num));
114 }
115 }
116 replicate_op.replaceAllUsesWith(new_results);
117 replicate_op.erase();
118 });
119 }
120 };
121
122 } // namespace
123
CreateTFDeviceReplicationPass()124 std::unique_ptr<OperationPass<ModuleOp>> CreateTFDeviceReplicationPass() {
125 return std::make_unique<TFDeviceReplicationPass>();
126 }
127
128 static PassRegistration<TFDeviceReplicationPass> pass;
129
130 } // namespace TFDevice
131 } // namespace mlir
132