1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
22 #include "mlir/Pass/Pass.h" // from @llvm-project
23 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
24 #include "mlir/Support/LLVM.h" // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
29
30 namespace mlir {
31 namespace TFTPU {
32
// A pass that finds TPU clusters that write to a resource without reading it
// (write-only resource access) and adds an associated resource read, so the
// resource can later be fused into TPUExecute.
35 namespace {
36 struct TPUResourceReadForWritePass
37 : public TF::TPUResourceReadForWritePassBase<TPUResourceReadForWritePass> {
38 void runOnOperation() override;
39 };
40
41 // Helper struct holding a resource value and its associated type.
42 struct ResourceValueAndSubtype {
43 Value resource;
44 Type subtype;
45 };
46
47 // Finds resource handle and type for result if result writes to a resource.
GetResourceWriteResult(tf_device::ClusterFuncOp cluster_func,Value result)48 ResourceValueAndSubtype GetResourceWriteResult(
49 tf_device::ClusterFuncOp cluster_func, Value result) {
50 ResourceValueAndSubtype resource;
51 if (!result.hasOneUse()) return resource;
52 Operation* result_user = *result.getUsers().begin();
53 auto assign_var = dyn_cast<TF::AssignVariableOp>(result_user);
54 if (!assign_var) return resource;
55
56 auto handle = assign_var.resource();
57 // Skip result if cluster writes to the same variable via multiple results.
58 for (Operation* handle_user : handle.getUsers()) {
59 if (handle_user == assign_var) continue;
60 auto assign_var_user = dyn_cast<TF::AssignVariableOp>(handle_user);
61 if (!assign_var_user) continue;
62 if (assign_var_user.value().getDefiningOp() == cluster_func)
63 return resource;
64 }
65
66 resource.resource = assign_var.resource();
67 resource.subtype = assign_var.value().getType();
68 return resource;
69 }
70
71 // Checks if resource is read by TPU cluster.
ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,Value resource)72 bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
73 Value resource) {
74 for (Operation* resource_user : resource.getUsers())
75 if (auto read = dyn_cast<TF::ReadVariableOp>(resource_user))
76 for (Operation* read_user : read.value().getUsers())
77 if (read_user == cluster_func) return true;
78
79 return false;
80 }
81
runOnOperation()82 void TPUResourceReadForWritePass::runOnOperation() {
83 SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
84 getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
85 cluster_funcs.push_back(cluster_func);
86 });
87
88 OpBuilder builder(&getContext());
89 // Add resource reads for resource writes from TPU cluster where for such
90 // resources the TPU cluster does not read from.
91 for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
92 builder.setInsertionPoint(cluster_func);
93
94 SmallVector<Value, 4> read_operands;
95 for (Value result : cluster_func.getResults()) {
96 // TODO(lyandy): Update pass to use resource alias analysis.
97 auto resource_and_type = GetResourceWriteResult(cluster_func, result);
98 if (!resource_and_type.resource) continue;
99 if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
100 continue;
101 auto new_read = builder.create<TF::ReadVariableOp>(
102 resource_and_type.resource.getLoc(), resource_and_type.subtype,
103 resource_and_type.resource);
104 read_operands.push_back(new_read.value());
105 }
106
107 if (read_operands.empty()) continue;
108
109 // Update caller and function types with new read operands.
110 auto operands = llvm::to_vector<4>(cluster_func.getOperands());
111 operands.append(read_operands.begin(), read_operands.end());
112
113 auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
114 cluster_func.getLoc(), cluster_func.getResultTypes(), operands,
115 cluster_func.getAttrs());
116 cluster_func.replaceAllUsesWith(new_cluster_func);
117 FuncOp func = cluster_func.getFunc();
118 Block& block = func.front();
119 for (Value read_operand : read_operands)
120 block.addArgument(read_operand.getType());
121
122 func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
123 func.getCallableResults()));
124 cluster_func.erase();
125 }
126 }
127
128 } // namespace
129
CreateTPUResourceReadForWritePass()130 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
131 return std::make_unique<TPUResourceReadForWritePass>();
132 }
133
134 } // namespace TFTPU
135 } // namespace mlir
136