/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <tuple>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kReplicateSharding[] = "";

struct TPUResourceReadsWritesPartitioningPass
    : public TF::TPUResourceReadsWritesPartitioningPassBase<
          TPUResourceReadsWritesPartitioningPass> {
  void runOnFunction() override;
};

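// Returns true if all resource-typed values in `resources` carry exactly one
// subtype (the type of the value stored in the variable). A single subtype is
// required to derive the per-partition read/write types below.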
bool AllResourceTypesHaveSubtypes(TypeRange resources) {
  for (Type resource : resources)
    if (!llvm::hasSingleElement(resource.cast<TensorType>()
                                    .getElementType()
                                    .cast<TF::ResourceType>()
                                    .getSubtypes()))
      return false;

  return true;
}

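// Returns the single subtype of a resource tensor type, i.e. the type of the
// value stored in the resource variable.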
Type GetResourceSubtype(Type type) {
  return type.cast<TensorType>()
      .getElementType()
      .cast<TF::ResourceType>()
      .getSubtypes()
      .front();
}

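// Returns the single subtype of the resource handle `resource`.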
Type GetResourceSubtype(Value resource) {
  return GetResourceSubtype(resource.getType());
}

// Rewrites unpartitioned resource reads and writes to partitioned resource
// reads and writes. The TPU computation from the frontend is generated in such
// a way that resource operations operate on the unpartitioned resource handle
// (from a `tf.TPUReplicatedInput`). This results in resource reads and writes
// on the unpartitioned resource handle post resource op decomposition/lifting.
// Here each unpartitioned resource read and write is expanded into individual
// resource reads and writes per associated partitioned resource handle.
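//
// As an illustrative sketch (attributes, shapes, and exact types elided), a
// write through a partitioned resource handle
//
//   %handle = "tf.TPUPartitionedInput"(%rvar0, %rvar1)
//   %result = "tf_device.cluster_func"(...)
//   "tf.AssignVariableOp"(%handle, %result)
//
// is rewritten into per-partition writes
//
//   %result = "tf_device.cluster_func"(...)
//   %out:2 = "tf.TPUPartitionedOutput"(%result)
//   "tf.AssignVariableOp"(%rvar0, %out#0)
//   "tf.AssignVariableOp"(%rvar1, %out#1)
//
// and a read through a partitioned resource handle
//
//   %handle = "tf.TPUPartitionedInput"(%rvar0, %rvar1)
//   %read = "tf.ReadVariableOp"(%handle)
//   "tf_device.cluster_func"(%read)
//
// is rewritten into per-partition reads merged by a value-typed
// `tf.TPUPartitionedInput`
//
//   %read0 = "tf.ReadVariableOp"(%rvar0)
//   %read1 = "tf.ReadVariableOp"(%rvar1)
//   %merged = "tf.TPUPartitionedInput"(%read0, %read1)
//   "tf_device.cluster_func"(%merged)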
void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
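  // Only clusters that enable SPMD partitioning (per the
  // `use_spmd_for_xla_partitioning` attribute) operate on partitioned
  // variables; other clusters are left unchanged.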
  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
          "use_spmd_for_xla_partitioning"))
    use_spmd = use_spmd_attr.getValue();

  if (!use_spmd) return;

  OpBuilder builder(cluster_func);
  // Rewrite results before rewriting operands, as a resource handle produced
  // by a `tf.TPUPartitionedInput` op is the indicator of a partitioned
  // resource variable. These `tf.TPUPartitionedInput` ops will be removed when
  // rewriting the operands.
  for (Value result : cluster_func.results()) {
    if (!result.hasOneUse()) continue;
    auto assign_var =
        llvm::dyn_cast<TF::AssignVariableOp>(*result.getUsers().begin());
    if (!assign_var || assign_var.value() != result) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        assign_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

    builder.setInsertionPoint(assign_var);
    llvm::SmallVector<Type, 4> partitioned_output_types;
    partitioned_output_types.reserve(partitioned_input.N());
    for (Type input_type : partitioned_input.inputs().getTypes())
      partitioned_output_types.push_back(GetResourceSubtype(input_type));
    auto partitioned_output = builder.create<TF::TPUPartitionedOutputOp>(
        cluster_func->getLoc(), partitioned_output_types, result,
        partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    for (auto resource_write :
         llvm::zip(partitioned_input.inputs(), partitioned_output.output()))
      builder.create<TF::AssignVariableOp>(
          assign_var->getLoc(), /*resource=*/std::get<0>(resource_write),
          /*value=*/std::get<1>(resource_write));
    assign_var.erase();
  }

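  // Rewrite reads of partitioned resource variables: a `tf.ReadVariableOp` of
  // a `tf.TPUPartitionedInput` resource handle is replaced by per-partition
  // reads whose results are merged with a value-typed `tf.TPUPartitionedInput`
  // feeding the cluster.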
  for (OpOperand& operand : cluster_func->getOpOperands()) {
    auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.get().getDefiningOp());
    if (!read_var || !read_var.value().hasOneUse()) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        read_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

    builder.setInsertionPoint(partitioned_input);
    llvm::SmallVector<Value, 4> partitioned_reads;
    for (Value input : partitioned_input.inputs()) {
      auto partitioned_read = builder.create<TF::ReadVariableOp>(
          read_var->getLoc(), GetResourceSubtype(input), input);
      partitioned_reads.push_back(partitioned_read.value());
    }
    auto partitioned_read = builder.create<TF::TPUPartitionedInputOp>(
        partitioned_input->getLoc(), read_var.value().getType(),
        partitioned_reads, partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    operand.set(partitioned_read);
    read_var->erase();
    if (partitioned_input->use_empty()) partitioned_input->erase();
  }
}

void TPUResourceReadsWritesPartitioningPass::runOnFunction() {
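  // Collect the `tf_device.cluster_func` ops first so the rewrites below do
  // not mutate IR while it is being walked.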
  llvm::SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
  getFunction()->walk([&cluster_funcs](tf_device::ClusterFuncOp cluster_func) {
    cluster_funcs.push_back(cluster_func);
  });
  for (tf_device::ClusterFuncOp cluster_func : cluster_funcs)
    PartitionResourceReadsWrites(cluster_func);
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateTPUResourceReadsWritesPartitioningPass() {
  return std::make_unique<TPUResourceReadsWritesPartitioningPass>();
}

}  // namespace TFTPU
}  // namespace mlir