• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This pass hoists a `tf_device.replicate` body out of the op, cloning every op
// in the body once per replica. For each cloned TensorFlow dialect op whose
// `device` attribute names an alias in the `devices` attribute of the
// `tf_device.replicate`, the `device` attribute is set to that replica's device.
19 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25 
26 namespace mlir {
27 namespace TFDevice {
28 namespace {
29 
30 constexpr char kDeviceAttr[] = "device";
31 
32 class TFDeviceReplicationPass
33     : public PassWrapper<TFDeviceReplicationPass, OperationPass<ModuleOp>> {
34  public:
getDependentDialects(DialectRegistry & registry) const35   void getDependentDialects(DialectRegistry &registry) const override {
36     registry.insert<TF::TensorFlowDialect>();
37   }
38 
getArgument() const39   StringRef getArgument() const final { return "tf-device-replication"; }
40 
getDescription() const41   StringRef getDescription() const final {
42     return "Hoists and replicates the tf_device.replicate inner ops once for "
43            "each associated device.";
44   }
45 
runOnOperation()46   void runOnOperation() override {
47     ModuleOp module = getOperation();
48     const Dialect *tf_dialect = getContext().getLoadedDialect("tf");
49     module.walk([&](tf_device::ReplicateOp replicate_op) {
50       OpBuilder builder(replicate_op);
51       // Map from the existing operation in ReplicateOp's region to a list of
52       // its replicated operations.
53       llvm::DenseMap<Operation *, llvm::SmallVector<Operation *, 4>>
54           operation_map;
55       llvm::Optional<DictionaryAttr> devices = replicate_op.devices();
56       const int replicate_num = replicate_op.n();
57 
58       // Replicates every operation in the region of the ReplicateOp to match
59       // the number of devices.
60       for (int i : llvm::seq<int>(0, replicate_num)) {
61         // Gets the mapping from the packed and replicated block arguments to
62         // the actual value. This mapping is used to replace the arguments used
63         // by the cloned operations.
64         BlockAndValueMapping mapping;
65         for (BlockArgument &arg : replicate_op.GetBody().getArguments()) {
66           Value new_arg =
67               replicate_op.GetReplicaOperandForBlockArgument(arg, i);
68           mapping.map(arg, new_arg);
69         }
70         for (Operation &op : replicate_op.GetBody().without_terminator()) {
71           // Clones the operation and places it outside the replicate_op's body.
72           llvm::SmallVector<Operation *, 4> &new_ops = operation_map[&op];
73           Operation *new_op = builder.clone(op, mapping);
74           new_ops.push_back(new_op);
75           // If the op is a TF op, it has a string-valued device attribute and
76           // the replicate_op has a list of devices corresponding to this device
77           // attribute's value, updates the device attribute for this op.
78           if (!devices) continue;
79 
80           if (op.getDialect() != tf_dialect) continue;
81 
82           StringAttr device_alias =
83               new_op->getAttrOfType<StringAttr>(kDeviceAttr);
84           if (!device_alias) continue;
85 
86           Attribute new_devices = devices->get(device_alias.getValue());
87           if (!new_devices) continue;
88 
89           ArrayAttr new_devices_array = new_devices.cast<ArrayAttr>();
90           new_op->setAttr(kDeviceAttr, new_devices_array[i].cast<StringAttr>());
91         }
92       }
93       // Replaces usages of the existing results of the tf_device.replicate
94       // op with the results of the newly replicated operations.
95       llvm::SmallVector<Value, 4> new_results;
96       for (Value v : replicate_op.GetBody().getTerminator()->getOperands()) {
97         OpResult result = v.dyn_cast<OpResult>();
98         // Uses the original value if the value is not an OpResult.
99         if (!result) {
100           for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
101           continue;
102         }
103         // Uses the original value if the value is defined by an op outside the
104         // tf_device.replicate's body.
105         Operation *op = result.getDefiningOp();
106         if (operation_map.find(op) == operation_map.end()) {
107           for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
108           continue;
109         }
110         // Uses the values defined by the newly replicated operations.
111         int result_num = result.getResultNumber();
112         for (Operation *new_op : operation_map[op]) {
113           new_results.push_back(new_op->getResult(result_num));
114         }
115       }
116       replicate_op.replaceAllUsesWith(new_results);
117       replicate_op.erase();
118     });
119   }
120 };
121 
122 }  // namespace
123 
CreateTFDeviceReplicationPass()124 std::unique_ptr<OperationPass<ModuleOp>> CreateTFDeviceReplicationPass() {
125   return std::make_unique<TFDeviceReplicationPass>();
126 }
127 
128 static PassRegistration<TFDeviceReplicationPass> pass;
129 
130 }  // namespace TFDevice
131 }  // namespace mlir
132