• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
23 #include "mlir/Pass/Pass.h"  // from @llvm-project
24 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
25 #include "mlir/Support/LLVM.h"  // from @llvm-project
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
29 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
30 
31 namespace mlir {
32 namespace TFTPU {
33 
34 // A pass that finds TPU clusters with write only resource access and adds an
35 // associated resource read, so the resource can later be fused into TPUExecute.
36 namespace {
37 struct TPUResourceReadForWritePass
38     : public TF::TPUResourceReadForWritePassBase<TPUResourceReadForWritePass> {
39   void runOnOperation() override;
40 };
41 
42 // Helper struct holding a resource value and its associated type.
43 struct ResourceValueAndSubtype {
44   Value resource;
45   Type subtype;
46 };
47 
48 // Finds resource handle and type for result if result writes to a resource.
GetResourceWriteResult(tf_device::ClusterFuncOp cluster_func,Value result)49 ResourceValueAndSubtype GetResourceWriteResult(
50     tf_device::ClusterFuncOp cluster_func, Value result) {
51   ResourceValueAndSubtype resource;
52   if (!result.hasOneUse()) return resource;
53   Operation* result_user = *result.getUsers().begin();
54   auto assign_var = dyn_cast<TF::AssignVariableOp>(result_user);
55   if (!assign_var) return resource;
56 
57   auto handle = assign_var.resource();
58   // Skip result if cluster writes to the same variable via multiple results.
59   for (Operation* handle_user : handle.getUsers()) {
60     if (handle_user == assign_var) continue;
61     auto assign_var_user = dyn_cast<TF::AssignVariableOp>(handle_user);
62     if (!assign_var_user) continue;
63     if (assign_var_user.value().getDefiningOp() == cluster_func)
64       return resource;
65   }
66 
67   resource.resource = assign_var.resource();
68   resource.subtype = assign_var.value().getType();
69   return resource;
70 }
71 
72 // Checks if resource is read by TPU cluster.
ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,Value resource)73 bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
74                                 Value resource) {
75   for (Operation* resource_user : resource.getUsers())
76     if (auto read = dyn_cast<TF::ReadVariableOp>(resource_user))
77       for (Operation* read_user : read.value().getUsers())
78         if (read_user == cluster_func) return true;
79 
80   return false;
81 }
82 
runOnOperation()83 void TPUResourceReadForWritePass::runOnOperation() {
84   SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
85   getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
86     cluster_funcs.push_back(cluster_func);
87   });
88 
89   OpBuilder builder(&getContext());
90   // Add resource reads for resource writes from TPU cluster where for such
91   // resources the TPU cluster does not read from.
92   for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
93     builder.setInsertionPoint(cluster_func);
94 
95     SmallVector<Value, 4> read_operands;
96     for (Value result : cluster_func.getResults()) {
97       // TODO(lyandy): Update pass to use resource alias analysis.
98       auto resource_and_type = GetResourceWriteResult(cluster_func, result);
99       if (!resource_and_type.resource) continue;
100       if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
101         continue;
102       auto new_read = builder.create<TF::ReadVariableOp>(
103           resource_and_type.resource.getLoc(), resource_and_type.subtype,
104           resource_and_type.resource);
105       read_operands.push_back(new_read.value());
106     }
107 
108     if (read_operands.empty()) continue;
109 
110     // Update caller and function types with new read operands.
111     auto operands = llvm::to_vector<4>(cluster_func.getOperands());
112     operands.append(read_operands.begin(), read_operands.end());
113 
114     auto loc = cluster_func.getLoc();
115     auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
116         loc, cluster_func.getResultTypes(), operands, cluster_func->getAttrs());
117     cluster_func.replaceAllUsesWith(new_cluster_func);
118     func::FuncOp func = cluster_func.getFunc();
119     Block& block = func.front();
120     for (Value read_operand : read_operands)
121       block.addArgument(read_operand.getType(), loc);
122 
123     func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
124                                    func.getCallableResults()));
125     cluster_func.erase();
126   }
127 }
128 
129 }  // namespace
130 
CreateTPUResourceReadForWritePass()131 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
132   return std::make_unique<TPUResourceReadForWritePass>();
133 }
134 
135 }  // namespace TFTPU
136 }  // namespace mlir
137