/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <tuple>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kReplicateSharding[] = "";

struct TPUResourceReadsWritesPartitioningPass
    : public TF::TPUResourceReadsWritesPartitioningPassBase<
          TPUResourceReadsWritesPartitioningPass> {
  void runOnFunction() override;
};

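// Returns true if each resource type in `resources` has exactly one subtype
// (i.e. the shape and element type of the value held by the resource are
// known).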
bool AllResourceTypesHaveSubtypes(TypeRange resources) {
  for (Type resource : resources)
    if (!llvm::hasSingleElement(resource.cast<TensorType>()
                                    .getElementType()
                                    .cast<TF::ResourceType>()
                                    .getSubtypes()))
      return false;

  return true;
}

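// Returns the subtype of a resource type, i.e. the type of the value held by
// the resource variable.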
Type GetResourceSubtype(Type type) {
  return type.cast<TensorType>()
      .getElementType()
      .cast<TF::ResourceType>()
      .getSubtypes()
      .front();
}

Type GetResourceSubtype(Value resource) {
  return GetResourceSubtype(resource.getType());
}

// Rewrites unpartitioned resource reads and writes to partitioned resource
// reads and writes. The TPU computation from the frontend is generated in such
// a way that resource operations operate on the unpartitioned resource handle
// (from a `tf.TPUReplicatedInput`). This results in resource reads and writes
// on the unpartitioned resource handle post resource op decomposition/lifting.
// Here each unpartitioned resource read and write is expanded into individual
// resource reads and writes per associated partitioned resource handle.
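//
// As a rough sketch (illustrative pseudo-IR only, not exact syntax), a write
// of a cluster result to a partitioned resource
//
//   %rvar = "tf.TPUPartitionedInput"(%rvar0, %rvar1) {...}
//   %result = "tf_device.cluster_func"(...) {...}
//   "tf.AssignVariableOp"(%rvar, %result)
//
// is rewritten so that each partition is written directly:
//
//   %out:2 = "tf.TPUPartitionedOutput"(%result) {...}
//   "tf.AssignVariableOp"(%rvar0, %out#0)
//   "tf.AssignVariableOp"(%rvar1, %out#1)
//
// Reads are rewritten symmetrically: per-partition `tf.ReadVariableOp`s feed
// a value-based `tf.TPUPartitionedInput` that replaces the original read.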
void PartitionResourceReadsWrites(tf_device::ClusterFuncOp cluster_func) {
  bool use_spmd = false;
  if (auto use_spmd_attr = cluster_func->getAttrOfType<BoolAttr>(
          "use_spmd_for_xla_partitioning"))
    use_spmd = use_spmd_attr.getValue();

  if (!use_spmd) return;

  OpBuilder builder(cluster_func);
  // Rewrite results before rewriting operands, as a `tf.TPUPartitionedInput`
  // producing a resource handle indicates a partitioned resource variable.
  // These `tf.TPUPartitionedInput` ops will be removed when the operands are
  // rewritten.
  for (Value result : cluster_func.results()) {
    if (!result.hasOneUse()) continue;
    auto assign_var =
        llvm::dyn_cast<TF::AssignVariableOp>(*result.getUsers().begin());
    if (!assign_var || assign_var.value() != result) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        assign_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

    builder.setInsertionPoint(assign_var);
    llvm::SmallVector<Type, 4> partitioned_output_types;
    partitioned_output_types.reserve(partitioned_input.N());
    for (Type input_type : partitioned_input.inputs().getTypes())
      partitioned_output_types.push_back(GetResourceSubtype(input_type));
    auto partitioned_output = builder.create<TF::TPUPartitionedOutputOp>(
        cluster_func->getLoc(), partitioned_output_types, result,
        partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    for (auto resource_write :
         llvm::zip(partitioned_input.inputs(), partitioned_output.output()))
      builder.create<TF::AssignVariableOp>(
          assign_var->getLoc(), /*resource=*/std::get<0>(resource_write),
          /*value=*/std::get<1>(resource_write));
    assign_var.erase();
  }

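  // Rewrite operands: a read of the unpartitioned resource handle produced by
  // a `tf.TPUPartitionedInput` is replaced with per-partition
  // `tf.ReadVariableOp`s feeding a new, value-based `tf.TPUPartitionedInput`.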
  for (OpOperand& operand : cluster_func->getOpOperands()) {
    auto read_var = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
        operand.get().getDefiningOp());
    if (!read_var || !read_var.value().hasOneUse()) continue;
    auto partitioned_input = llvm::dyn_cast_or_null<TF::TPUPartitionedInputOp>(
        read_var.resource().getDefiningOp());
    if (!partitioned_input ||
        !AllResourceTypesHaveSubtypes(partitioned_input.inputs().getTypes()))
      continue;

    builder.setInsertionPoint(partitioned_input);
    llvm::SmallVector<Value, 4> partitioned_reads;
    for (Value input : partitioned_input.inputs()) {
      auto partitioned_read = builder.create<TF::ReadVariableOp>(
          read_var->getLoc(), GetResourceSubtype(input), input);
      partitioned_reads.push_back(partitioned_read.value());
    }
    auto partitioned_read = builder.create<TF::TPUPartitionedInputOp>(
        partitioned_input->getLoc(), read_var.value().getType(),
        partitioned_reads, partitioned_input.partition_dimAttr(),
        partitioned_input._XlaShardingAttr());
    operand.set(partitioned_read);
    read_var->erase();
    if (partitioned_input->use_empty()) partitioned_input->erase();
  }
}

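// Collects all `tf_device.cluster_func` ops in the function and partitions the
// resource reads and writes surrounding each one.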
void TPUResourceReadsWritesPartitioningPass::runOnFunction() {
  llvm::SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
  getFunction()->walk([&cluster_funcs](tf_device::ClusterFuncOp cluster_func) {
    cluster_funcs.push_back(cluster_func);
  });
  for (tf_device::ClusterFuncOp cluster_func : cluster_funcs)
    PartitionResourceReadsWrites(cluster_func);
}

}  // namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateTPUResourceReadsWritesPartitioningPass() {
  return std::make_unique<TPUResourceReadsWritesPartitioningPass>();
}

}  // namespace TFTPU
}  // namespace mlir