• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This pass hoists a `tf_device.replicate` body and replicates each TensorFlow
17 // dialect op in the body based on its `device` attribute and the `devices`
18 // attribute on the `tf_device.replicate`.
19 
20 #include "mlir/IR/Builders.h"
21 #include "mlir/Pass/Pass.h"
22 #include "mlir/IR/BlockAndValueMapping.h"  // from @llvm-project
23 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
24 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25 
26 namespace mlir {
27 namespace TFDevice {
28 namespace {
29 
30 constexpr char kDeviceAttr[] = "device";
31 
32 class TFDeviceReplicationPass
33     : public PassWrapper<TFDeviceReplicationPass, OperationPass<ModuleOp>> {
34  public:
getDependentDialects(DialectRegistry & registry) const35   void getDependentDialects(DialectRegistry &registry) const override {
36     registry.insert<TF::TensorFlowDialect>();
37   }
38 
runOnOperation()39   void runOnOperation() override {
40     ModuleOp module = getOperation();
41     const Dialect *tf_dialect = getContext().getLoadedDialect("tf");
42     module.walk([&](tf_device::ReplicateOp replicate_op) {
43       OpBuilder builder(replicate_op);
44       // Map from the existing operation in ReplicateOp's region to a list of
45       // its replicated operations.
46       llvm::DenseMap<Operation *, llvm::SmallVector<Operation *, 4>>
47           operation_map;
48       llvm::Optional<DictionaryAttr> devices = replicate_op.devices();
49       const int replicate_num = replicate_op.n();
50 
51       // Replicates every operation in the region of the ReplicateOp to match
52       // the number of devices.
53       for (int i : llvm::seq<int>(0, replicate_num)) {
54         // Gets the mapping from the packed and replicated block arguments to
55         // the actual value. This mapping is used to replace the arguments used
56         // by the cloned operations.
57         BlockAndValueMapping mapping;
58         for (BlockArgument &arg : replicate_op.GetBody().getArguments()) {
59           Value new_arg =
60               replicate_op.GetReplicaOperandForBlockArgument(arg, i);
61           mapping.map(arg, new_arg);
62         }
63         for (Operation &op : replicate_op.GetBody().without_terminator()) {
64           // Clones the operation and places it outside the replicate_op's body.
65           llvm::SmallVector<Operation *, 4> &new_ops = operation_map[&op];
66           Operation *new_op = builder.clone(op, mapping);
67           new_ops.push_back(new_op);
68           // If the op is a TF op, it has a string-valued device attribute and
69           // the replicate_op has a list of devices corresponding to this device
70           // attribute's value, updates the device attribute for this op.
71           if (!devices) continue;
72 
73           if (op.getDialect() != tf_dialect) continue;
74 
75           StringAttr device_alias =
76               new_op->getAttrOfType<StringAttr>(kDeviceAttr);
77           if (!device_alias) continue;
78 
79           Attribute new_devices = devices->get(device_alias.getValue());
80           if (!new_devices) continue;
81 
82           ArrayAttr new_devices_array = new_devices.cast<ArrayAttr>();
83           new_op->setAttr(kDeviceAttr, new_devices_array[i].cast<StringAttr>());
84         }
85       }
86       // Replaces usages of the existing results of the tf_device.replicate
87       // op with the results of the newly replicated operations.
88       llvm::SmallVector<Value, 4> new_results;
89       for (Value v : replicate_op.GetBody().getTerminator()->getOperands()) {
90         OpResult result = v.dyn_cast<OpResult>();
91         // Uses the original value if the value is not an OpResult.
92         if (!result) {
93           for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
94           continue;
95         }
96         // Uses the original value if the value is defined by an op outside the
97         // tf_device.replicate's body.
98         Operation *op = result.getDefiningOp();
99         if (operation_map.find(op) == operation_map.end()) {
100           for (int i = 0; i < replicate_num; ++i) new_results.push_back(v);
101           continue;
102         }
103         // Uses the values defined by the newly replicated operations.
104         int result_num = result.getResultNumber();
105         for (Operation *new_op : operation_map[op]) {
106           new_results.push_back(new_op->getResult(result_num));
107         }
108       }
109       replicate_op.replaceAllUsesWith(new_results);
110       replicate_op.erase();
111     });
112   }
113 };
114 
115 }  // namespace
116 
CreateTFDeviceReplicationPass()117 std::unique_ptr<OperationPass<ModuleOp>> CreateTFDeviceReplicationPass() {
118   return std::make_unique<TFDeviceReplicationPass>();
119 }
120 
121 static PassRegistration<TFDeviceReplicationPass> pass(
122     "tf-device-replication",
123     "Hoists and replicates the tf_device.replicate "
124     "inner ops once for each associated device.");
125 
126 }  // namespace TFDevice
127 }  // namespace mlir
128