• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 
18 #include "llvm/ADT/STLExtras.h"
19 #include "llvm/ADT/SmallVector.h"
20 #include "mlir/IR/Builders.h"  // from @llvm-project
21 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
22 #include "mlir/Pass/Pass.h"  // from @llvm-project
23 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
24 #include "mlir/Support/LLVM.h"  // from @llvm-project
25 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
26 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
27 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
28 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
29 
30 namespace mlir {
31 namespace TFTPU {
32 
// A pass that finds TPU clusters with write-only resource access and adds an
// associated resource read, so the resource can later be fused into TPUExecute.
35 namespace {
36 struct TPUResourceReadForWritePass
37     : public TF::TPUResourceReadForWritePassBase<TPUResourceReadForWritePass> {
38   void runOnOperation() override;
39 };
40 
41 // Helper struct holding a resource value and its associated type.
42 struct ResourceValueAndSubtype {
43   Value resource;
44   Type subtype;
45 };
46 
47 // Finds resource handle and type for result if result writes to a resource.
GetResourceWriteResult(tf_device::ClusterFuncOp cluster_func,Value result)48 ResourceValueAndSubtype GetResourceWriteResult(
49     tf_device::ClusterFuncOp cluster_func, Value result) {
50   ResourceValueAndSubtype resource;
51   if (!result.hasOneUse()) return resource;
52   Operation* result_user = *result.getUsers().begin();
53   auto assign_var = dyn_cast<TF::AssignVariableOp>(result_user);
54   if (!assign_var) return resource;
55 
56   auto handle = assign_var.resource();
57   // Skip result if cluster writes to the same variable via multiple results.
58   for (Operation* handle_user : handle.getUsers()) {
59     if (handle_user == assign_var) continue;
60     auto assign_var_user = dyn_cast<TF::AssignVariableOp>(handle_user);
61     if (!assign_var_user) continue;
62     if (assign_var_user.value().getDefiningOp() == cluster_func)
63       return resource;
64   }
65 
66   resource.resource = assign_var.resource();
67   resource.subtype = assign_var.value().getType();
68   return resource;
69 }
70 
71 // Checks if resource is read by TPU cluster.
ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,Value resource)72 bool ClusterFuncHasResourceRead(tf_device::ClusterFuncOp cluster_func,
73                                 Value resource) {
74   for (Operation* resource_user : resource.getUsers())
75     if (auto read = dyn_cast<TF::ReadVariableOp>(resource_user))
76       for (Operation* read_user : read.value().getUsers())
77         if (read_user == cluster_func) return true;
78 
79   return false;
80 }
81 
runOnOperation()82 void TPUResourceReadForWritePass::runOnOperation() {
83   SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
84   getOperation().walk([&](tf_device::ClusterFuncOp cluster_func) {
85     cluster_funcs.push_back(cluster_func);
86   });
87 
88   OpBuilder builder(&getContext());
89   // Add resource reads for resource writes from TPU cluster where for such
90   // resources the TPU cluster does not read from.
91   for (tf_device::ClusterFuncOp cluster_func : cluster_funcs) {
92     builder.setInsertionPoint(cluster_func);
93 
94     SmallVector<Value, 4> read_operands;
95     for (Value result : cluster_func.getResults()) {
96       // TODO(lyandy): Update pass to use resource alias analysis.
97       auto resource_and_type = GetResourceWriteResult(cluster_func, result);
98       if (!resource_and_type.resource) continue;
99       if (ClusterFuncHasResourceRead(cluster_func, resource_and_type.resource))
100         continue;
101       auto new_read = builder.create<TF::ReadVariableOp>(
102           resource_and_type.resource.getLoc(), resource_and_type.subtype,
103           resource_and_type.resource);
104       read_operands.push_back(new_read.value());
105     }
106 
107     if (read_operands.empty()) continue;
108 
109     // Update caller and function types with new read operands.
110     auto operands = llvm::to_vector<4>(cluster_func.getOperands());
111     operands.append(read_operands.begin(), read_operands.end());
112 
113     auto new_cluster_func = builder.create<tf_device::ClusterFuncOp>(
114         cluster_func.getLoc(), cluster_func.getResultTypes(), operands,
115         cluster_func.getAttrs());
116     cluster_func.replaceAllUsesWith(new_cluster_func);
117     FuncOp func = cluster_func.getFunc();
118     Block& block = func.front();
119     for (Value read_operand : read_operands)
120       block.addArgument(read_operand.getType());
121 
122     func.setType(FunctionType::get(&getContext(), block.getArgumentTypes(),
123                                    func.getCallableResults()));
124     cluster_func.erase();
125   }
126 }
127 
128 }  // namespace
129 
CreateTPUResourceReadForWritePass()130 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUResourceReadForWritePass() {
131   return std::make_unique<TPUResourceReadForWritePass>();
132 }
133 
134 }  // namespace TFTPU
135 }  // namespace mlir
136