• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <string>
18 #include <utility>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/ADT/StringRef.h"
25 #include "llvm/Support/Casting.h"
26 #include "llvm/Support/FormatVariadic.h"
27 #include "mlir/IR/Attributes.h"  // from @llvm-project
28 #include "mlir/IR/Block.h"  // from @llvm-project
29 #include "mlir/IR/Builders.h"  // from @llvm-project
30 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
31 #include "mlir/IR/Value.h"  // from @llvm-project
32 #include "mlir/Pass/Pass.h"  // from @llvm-project
33 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
34 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
38 
39 namespace mlir {
40 namespace TFTPU {
41 
// Array attribute on `tf_device.replicate`, indexed by the argument number of
// the replicate region's block arguments. Each element is an integer
// replicated input index; elements equal to -1 are ignored by this pass.
static constexpr char kReplicatedInputIndicesAttr[] =
    "_replicated_input_indices";
// Array attribute on `tf_device.cluster_func` whose string elements are
// serialized `tensorflow::tpu::PaddingMap` protos.
static constexpr char kPaddingMapAttr[] = "padding_map";
44 
// This pass remaps and assigns padding maps to an encapsulated function's
// arguments from a `tf_device.cluster_func` `padding_map` attribute. Remapping
// is from the replicated input index to the index of the
// `tf_device.cluster_func` operand that uses that input, which is also the
// encapsulated function's argument index.
49 
50 namespace {
51 struct TPUDynamicPaddingMapper
52     : public PassWrapper<TPUDynamicPaddingMapper, OperationPass<ModuleOp>> {
53   void runOnOperation() override;
54 };
55 
56 // Creates a mapping from replicated input index (in `tf_device.replicate` op)
57 // to `tf_device.cluster_func` operand index.
GetRemappedReplicatedInputIndices(tf_device::ClusterFuncOp cluster_func,tf_device::ReplicateOp replicate,ArrayAttr replicated_input_indices_attr)58 llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
59     tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate,
60     ArrayAttr replicated_input_indices_attr) {
61   Block* replicate_block = &replicate.GetBody();
62 
63   llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
64   for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
65     if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>()) {
66       if (block_arg.getOwner() == replicate_block) {
67         int64_t replicated_input_index =
68             replicated_input_indices_attr[block_arg.getArgNumber()]
69                 .cast<IntegerAttr>()
70                 .getInt();
71         if (replicated_input_index != -1)
72           remapped_indices[replicated_input_index] = operand_and_idx.index();
73       }
74     }
75   }
76 
77   return remapped_indices;
78 }
79 
80 // Extracts `padding_map` from `tf_device.cluster_func` and remaps the
81 // associated replicated input indices to the encapsulated function operand
82 // indices. An error will be returned if an index is not found or parsing
83 // failed.
GetRemappedPaddings(tf_device::ClusterFuncOp cluster_func,const llvm::SmallDenseMap<int32_t,int32_t> & remapped_indices,llvm::SmallVectorImpl<tensorflow::tpu::PaddingMap> * remapped_paddings)84 LogicalResult GetRemappedPaddings(
85     tf_device::ClusterFuncOp cluster_func,
86     const llvm::SmallDenseMap<int32_t, int32_t>& remapped_indices,
87     llvm::SmallVectorImpl<tensorflow::tpu::PaddingMap>* remapped_paddings) {
88   auto bad_index_msg = [](int32_t index, llvm::StringRef arg_type,
89                           int32_t arg_index) {
90     return llvm::formatv(
91                "bad '{0}' attribute at index {1}, {2} must be nonnegative, but "
92                "got {3}",
93                kPaddingMapAttr, index, arg_type, arg_index)
94         .str();
95   };
96 
97   Attribute padding_map_attr = cluster_func->getAttr(kPaddingMapAttr);
98   if (!padding_map_attr) return success();
99 
100   auto padding_map = padding_map_attr.dyn_cast<ArrayAttr>();
101   if (!padding_map)
102     return cluster_func.emitOpError()
103            << "requires '" << kPaddingMapAttr << "' array attribute";
104 
105   for (auto padding_attr_and_idx : llvm::enumerate(padding_map)) {
106     int idx = padding_attr_and_idx.index();
107     auto& padding_attr = padding_attr_and_idx.value();
108     auto padding = padding_attr.dyn_cast<StringAttr>();
109     if (!padding)
110       return cluster_func.emitOpError(
111           llvm::formatv("bad '{0}' attribute at index {1}, not a string",
112                         kPaddingMapAttr, padding_attr_and_idx.index()));
113 
114     tensorflow::tpu::PaddingMap padding_proto;
115     if (!padding_proto.ParseFromString(padding.getValue().str()))
116       return cluster_func.emitOpError(llvm::formatv(
117           "bad '{0}' attribute at index {1}, failed to parse '{2}' as "
118           "tensorflow::tpu::PaddingMap",
119           kPaddingMapAttr, idx, padding.getValue()));
120 
121     const int32_t arg_index = padding_proto.arg_index();
122     if (arg_index < 0)
123       return cluster_func.emitOpError()
124              << bad_index_msg(idx, "arg_index", arg_index);
125 
126     const int32_t padding_arg_index = padding_proto.padding_arg_index();
127     if (padding_arg_index < 0)
128       return cluster_func.emitOpError()
129              << bad_index_msg(idx, "padding_arg_index", padding_arg_index);
130 
131     auto arg_index_it = remapped_indices.find(arg_index);
132     // Skip unused arguments.
133     if (arg_index_it == remapped_indices.end()) continue;
134 
135     auto padding_arg_index_it = remapped_indices.find(padding_arg_index);
136     if (padding_arg_index_it == remapped_indices.end()) {
137       cluster_func.emitWarning(llvm::formatv(
138           "bad '{0}' attribute at index {1}, unused padding_arg_index {2}",
139           kPaddingMapAttr, idx, padding_arg_index));
140       continue;
141     }
142 
143     padding_proto.set_arg_index(arg_index_it->second);
144     padding_proto.set_padding_arg_index(padding_arg_index_it->getSecond());
145     remapped_paddings->push_back(std::move(padding_proto));
146   }
147 
148   return success();
149 }
150 
151 // Inserts padding maps for relevant arguments as argument attributes on the
152 // encapsulated function. The padding maps will be in the form of:
153 //   %arg0 : type {mhlo.padding_map = {shape_indices = [...],
154 //                                        padding_arg_indices = [...]}}
AnnotateFunctionArgumentsWithPaddings(FuncOp func,llvm::ArrayRef<tensorflow::tpu::PaddingMap> remapped_paddings)155 void AnnotateFunctionArgumentsWithPaddings(
156     FuncOp func,
157     llvm::ArrayRef<tensorflow::tpu::PaddingMap> remapped_paddings) {
158   // Group paddings by arg index.
159   llvm::SmallDenseMap<int32_t, std::pair<llvm::SmallVector<int32_t, 4>,
160                                          llvm::SmallVector<int32_t, 4>>>
161       paddings;
162   for (const auto& padding : remapped_paddings) {
163     auto& it = paddings[padding.arg_index()];
164     it.first.push_back(padding.shape_index());
165     it.second.push_back(padding.padding_arg_index());
166   }
167 
168   Builder builder(func.getContext());
169   for (const auto& padding : paddings) {
170     auto shape_indices = builder.getNamedAttr(
171         "shape_indices", builder.getI32ArrayAttr(padding.getSecond().first));
172     auto padding_arg_indices = builder.getNamedAttr(
173         "padding_arg_indices",
174         builder.getI32ArrayAttr(padding.getSecond().second));
175     func.setArgAttr(
176         padding.getFirst(), "mhlo.padding_map",
177         builder.getDictionaryAttr({shape_indices, padding_arg_indices}));
178   }
179 }
180 
RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,SymbolTable * symbol_table)181 LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
182                                         SymbolTable* symbol_table) {
183   auto replicate = cluster_func->getParentOfType<tf_device::ReplicateOp>();
184   // LaunchFunc is not replicated, there will be no padding.
185   if (!replicate) return success();
186 
187   auto func = symbol_table->lookup<FuncOp>(cluster_func.func());
188   if (!func) return success();
189 
190   auto replicated_input_indices_attr =
191       replicate->getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
192   if (!replicated_input_indices_attr) return success();
193 
194   llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
195       GetRemappedReplicatedInputIndices(cluster_func, replicate,
196                                         replicated_input_indices_attr);
197 
198   llvm::SmallVector<tensorflow::tpu::PaddingMap, 4> remapped_paddings;
199   if (failed(GetRemappedPaddings(cluster_func, remapped_indices,
200                                  &remapped_paddings)))
201     return failure();
202 
203   AnnotateFunctionArgumentsWithPaddings(func, remapped_paddings);
204 
205   return success();
206 }
207 
runOnOperation()208 void TPUDynamicPaddingMapper::runOnOperation() {
209   ModuleOp module = getOperation();
210   SymbolTable symbol_table(module);
211   module.walk([&](tf_device::ClusterFuncOp cluster_func) {
212     (void)RemapAndAssignPaddingMaps(cluster_func, &symbol_table);
213   });
214 }
215 }  // anonymous namespace
216 
CreateTPUDynamicPaddingMapperPass()217 std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicPaddingMapperPass() {
218   return std::make_unique<TPUDynamicPaddingMapper>();
219 }
220 
221 static PassRegistration<TPUDynamicPaddingMapper> pass(
222     "tf-tpu-dynamic-padding",
223     "Remaps padding map from replicated inputs to argument ordering on "
224     "encapsulated function");
225 
226 }  // namespace TFTPU
227 }  // namespace mlir
228