/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"

namespace mlir {
namespace TFTPU {

constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
constexpr char kPaddingMapAttr[] = "padding_map";

// This pass remaps and assigns padding maps to an encapsulated function's
// arguments from a `tf_device.cluster_func` `padding_map` attribute. The
// remapping is from a replicated input index (of the enclosing
// `tf_device.replicate`) to the index of the `tf_device.cluster_func` operand
// that uses it, which matches the encapsulated function's argument index.
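//
// For example (an illustrative sketch; all indices below are hypothetical): if
// replicated input 2 feeds `tf_device.cluster_func` operand 0 and replicated
// input 3 feeds operand 1, a `padding_map` entry with arg_index = 2,
// shape_index = 0 and padding_arg_index = 3 is rewritten to arg_index = 0,
// padding_arg_index = 1, and surfaces on the encapsulated function as
//   %arg0: ... {mhlo.padding_map = {shape_indices = [0],
//                                   padding_arg_indices = [1]}}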

namespace {
struct TPUDynamicPaddingMapper
    : public PassWrapper<TPUDynamicPaddingMapper, OperationPass<ModuleOp>> {
  void runOnOperation() override;
};

// Creates a mapping from replicated input index (in `tf_device.replicate` op)
// to `tf_device.cluster_func` operand index.
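// For example (hypothetical attribute contents): with
// `_replicated_input_indices = [0, -1, 3]`, a `tf_device.cluster_func` operand
// that is block argument 2 of the replicate body contributes the entry
// {3 -> operand index}; block argument 1 contributes nothing because its
// replicated input index is -1.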
llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
    tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate,
    ArrayAttr replicated_input_indices_attr) {
  Block* replicate_block = &replicate.GetBody();

  llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
  for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
    if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>()) {
      if (block_arg.getOwner() == replicate_block) {
        int64_t replicated_input_index =
            replicated_input_indices_attr[block_arg.getArgNumber()]
                .cast<IntegerAttr>()
                .getInt();
        if (replicated_input_index != -1)
          remapped_indices[replicated_input_index] = operand_and_idx.index();
      }
    }
  }

  return remapped_indices;
}

// Extracts `padding_map` from `tf_device.cluster_func` and remaps the
// associated replicated input indices to the encapsulated function's operand
// indices. Errors are reported for malformed attributes (non-string entries,
// unparsable protos, or negative indices); entries whose `arg_index` is unused
// are skipped, and an unused `padding_arg_index` only emits a warning.
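// For example (hypothetical values): a serialized PaddingMap with
// arg_index = 2 and padding_arg_index = 5 is kept only when index 2 appears in
// `remapped_indices`; if index 5 is missing, the entry is dropped with a
// warning, and otherwise both indices are rewritten to the corresponding
// function operand indices before being returned.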
LogicalResult GetRemappedPaddings(
    tf_device::ClusterFuncOp cluster_func,
    const llvm::SmallDenseMap<int32_t, int32_t>& remapped_indices,
    llvm::SmallVectorImpl<tensorflow::tpu::PaddingMap>* remapped_paddings) {
  auto bad_index_msg = [](int32_t index, llvm::StringRef arg_type,
                          int32_t arg_index) {
    return llvm::formatv(
               "bad '{0}' attribute at index {1}, {2} must be nonnegative, but "
               "got {3}",
               kPaddingMapAttr, index, arg_type, arg_index)
        .str();
  };

  Attribute padding_map_attr = cluster_func->getAttr(kPaddingMapAttr);
  if (!padding_map_attr) return success();

  auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>();
  if (!padding_map)
    return cluster_func.emitOpError()
           << "requires '" << kPaddingMapAttr << "' array attribute";

  for (auto padding_attr_and_idx : llvm::enumerate(padding_map)) {
    int idx = padding_attr_and_idx.index();
    auto& padding_attr = padding_attr_and_idx.value();
    auto padding = padding_attr.dyn_cast<StringAttr>();
    if (!padding)
      return cluster_func.emitOpError(
          llvm::formatv("bad '{0}' attribute at index {1}, not a string",
                        kPaddingMapAttr, idx));

    tensorflow::tpu::PaddingMap padding_proto;
    if (!padding_proto.ParseFromString(padding.getValue().str()))
      return cluster_func.emitOpError(llvm::formatv(
          "bad '{0}' attribute at index {1}, failed to parse '{2}' as "
          "tensorflow::tpu::PaddingMap",
          kPaddingMapAttr, idx, padding.getValue()));

    const int32_t arg_index = padding_proto.arg_index();
    if (arg_index < 0)
      return cluster_func.emitOpError()
             << bad_index_msg(idx, "arg_index", arg_index);

    const int32_t padding_arg_index = padding_proto.padding_arg_index();
    if (padding_arg_index < 0)
      return cluster_func.emitOpError()
             << bad_index_msg(idx, "padding_arg_index", padding_arg_index);

    auto arg_index_it = remapped_indices.find(arg_index);
    // Skip unused arguments.
    if (arg_index_it == remapped_indices.end()) continue;

    auto padding_arg_index_it = remapped_indices.find(padding_arg_index);
    if (padding_arg_index_it == remapped_indices.end()) {
      cluster_func.emitWarning(llvm::formatv(
          "bad '{0}' attribute at index {1}, unused padding_arg_index {2}",
          kPaddingMapAttr, idx, padding_arg_index));
      continue;
    }

    padding_proto.set_arg_index(arg_index_it->second);
    padding_proto.set_padding_arg_index(padding_arg_index_it->second);
    remapped_paddings->push_back(std::move(padding_proto));
  }

  return success();
}

// Inserts padding maps for relevant arguments as argument attributes on the
// encapsulated function. The padding maps will be in the form of:
//   %arg0 : type {mhlo.padding_map = {shape_indices = [...],
//                                     padding_arg_indices = [...]}}
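// For instance (hypothetical values), two remapped paddings
// {arg_index = 0, shape_index = 1, padding_arg_index = 2} and
// {arg_index = 0, shape_index = 3, padding_arg_index = 4} are grouped into a
// single attribute on %arg0:
//   {mhlo.padding_map = {shape_indices = [1, 3],
//                        padding_arg_indices = [2, 4]}}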
void AnnotateFunctionArgumentsWithPaddings(
    FuncOp func,
    llvm::ArrayRef<tensorflow::tpu::PaddingMap> remapped_paddings) {
  // Group paddings by arg index.
  llvm::SmallDenseMap<int32_t, std::pair<llvm::SmallVector<int32_t, 4>,
                                         llvm::SmallVector<int32_t, 4>>>
      paddings;
  for (const auto& padding : remapped_paddings) {
    auto& it = paddings[padding.arg_index()];
    it.first.push_back(padding.shape_index());
    it.second.push_back(padding.padding_arg_index());
  }

  Builder builder(func.getContext());
  for (const auto& padding : paddings) {
    auto shape_indices = builder.getNamedAttr(
        "shape_indices", builder.getI32ArrayAttr(padding.getSecond().first));
    auto padding_arg_indices = builder.getNamedAttr(
        "padding_arg_indices",
        builder.getI32ArrayAttr(padding.getSecond().second));
    func.setArgAttr(
        padding.getFirst(), "mhlo.padding_map",
        builder.getDictionaryAttr({shape_indices, padding_arg_indices}));
  }
}

LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
                                        SymbolTable* symbol_table) {
  auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
  // The cluster_func is not replicated, so there is no padding to map.
  if (!replicate) return success();

  auto func = symbol_table->lookup<FuncOp>(cluster_func.func());
  if (!func) return success();

  auto replicated_input_indices_attr =
      replicate->getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
  if (!replicated_input_indices_attr) return success();

  llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
      GetRemappedReplicatedInputIndices(cluster_func, replicate,
                                        replicated_input_indices_attr);

  llvm::SmallVector<tensorflow::tpu::PaddingMap, 4> remapped_paddings;
  if (failed(GetRemappedPaddings(cluster_func, remapped_indices,
                                 &remapped_paddings)))
    return failure();

  AnnotateFunctionArgumentsWithPaddings(func, remapped_paddings);

  return success();
}

void TPUDynamicPaddingMapper::runOnOperation() {
  ModuleOp module = getOperation();
  SymbolTable symbol_table(module);
  module.walk([&](tf_device::ClusterFuncOp cluster_func) {
    (void)RemapAndAssignPaddingMaps(cluster_func, &symbol_table);
  });
}
}  // anonymous namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass() {
  return std::make_unique<TPUDynamicPaddingMapper>();
}

static PassRegistration<TPUDynamicPaddingMapper> pass(
    "tf-tpu-dynamic-padding",
    "Remaps padding map from replicated inputs to argument ordering on "
    "encapsulated function");

}  // namespace TFTPU
}  // namespace mlir