/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"

namespace mlir {
namespace TF {
namespace {

// Returns true if the given op is a TF/XLA communication op in the old bridge.
bool IsCommunicationOp(Operation* op) {
  return isa<TF::XlaHostComputeOp, TF::XlaSendToHostOp, TF::XlaRecvFromHostOp>(
      op);
}

// Returns true if the given op is one of the ops supported to have a
// communication subcomputation in the TF/XLA bridge.
bool SupportsCommunicationComputation(Operation* op) {
  return isa<TF::IfRegionOp, TF::WhileRegionOp, TF::CaseRegionOp,
             TF::StatefulPartitionedCallOp, TF::PartitionedCallOp,
             TF::LegacyCallOp>(op);
}

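// Pass that prepares TPU computations for export to a TensorFlow graph:
// materializes argument sharding as XlaSharding ops, rewrites communication
// ops used by the new bridge into their old bridge equivalents, and sets the
// token input attributes used to order communication ops.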
class PrepareTpuComputationForTfExportPass
    : public PrepareTpuComputationForTfExportPassBase<
          PrepareTpuComputationForTfExportPass> {
  void runOnOperation() override;
};

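// Rewrites `tf._XlaHostComputeMlir` ops used by the new bridge into
// `tf.XlaHostCompute` ops understood by the old bridge. If the op carries a
// `host_mlir_module`, its host function is cloned into the module, wrapped
// with `_XlaRecvAtHost`/`_XlaSendFromHost` ops, and referenced through the
// `shape_inference_graph` attribute of the new op.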
class RewriteXlaHostComputeMlir
    : public OpRewritePattern<TF::_XlaHostComputeMlirOp> {
 public:
  using OpRewritePattern<TF::_XlaHostComputeMlirOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op,
                                PatternRewriter& rewriter) const override {
    llvm::SmallVector<Attribute> shape_attrs;
    shape_attrs.reserve(op.getNumResults());
    for (Type ty : op.getResultTypes()) {
      shape_attrs.push_back(
          TF::ShapeAttr::get(rewriter.getContext(), ty.cast<ShapedType>()));
    }

    // Clone the `host_func` in the `host_mlir_module` attribute if it exists
    // and use it for the `shape_inference_graph` attribute on XlaHostCompute.
    FuncOp cloned_func;
    SymbolTable manager(op->getParentOfType<ModuleOp>());
    StringRef host_module = op.host_mlir_module();
    if (!host_module.empty()) {
      mlir::OwningModuleRef module_for_func;

      FuncOp func = op.GetHostFunc(&module_for_func);

      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfter(op->getParentOfType<FuncOp>());
      cloned_func =
          llvm::dyn_cast_or_null<FuncOp>(rewriter.clone(*func.getOperation()));
      manager.insert(cloned_func);
      rewriter.setInsertionPointToStart(&cloned_func.body().front());
      auto result_type =
          RankedTensorType::get({3}, rewriter.getType<TF::StringType>());
      auto dynamic_key =
          rewriter.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
              func.getLoc(), /*program=*/result_type, llvm::ArrayRef<Value>{});

      auto recv_at_host = rewriter.create<TF::_XlaRecvAtHostOp>(
          func.getLoc(), op.getOperandTypes(), /*dynamic_key=*/dynamic_key,
          op.send_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
      for (auto result :
           llvm::zip(cloned_func.getArguments(), recv_at_host->getResults())) {
        std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }

      rewriter.setInsertionPoint(cloned_func.body().front().getTerminator());
      rewriter.create<TF::_XlaSendFromHostOp>(
          func.getLoc(),
          cloned_func.body().front().getTerminator()->getOperands(),
          /*dynamic_key=*/dynamic_key, op.recv_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
    }

    constexpr int64_t kDefaultCostEstimate = 1000000;
    rewriter.replaceOpWithNewOp<TF::XlaHostComputeOp>(
        op, op.getResultTypes(), op.inputs(),
        /*ancestors=*/rewriter.getArrayAttr({}),
        rewriter.getArrayAttr(shape_attrs),
        /*shape_inference_graph=*/
        cloned_func ? rewriter.getSymbolRefAttr(cloned_func) : SymbolRefAttr(),
        /*key=*/rewriter.getStringAttr(""), op.send_keyAttr(),
        op.recv_keyAttr(),
        /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate),
        op.tpu_coreAttr());
    return success();
  }
};

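// For each function argument carrying a non-empty `mhlo.sharding` attribute,
// materializes the sharding as a `tf.XlaSharding` op on the argument and then
// removes the attribute.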
void UpdateArgAttributes(mlir::FuncOp func) {
  OpBuilder builder(func.getBody());
  for (int i = 0; i < func.getNumArguments(); ++i) {
    constexpr char kShardingAttr[] = "mhlo.sharding";
    if (auto sharding =
            func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) {
      if (!sharding.getValue().empty()) {
        BlockArgument arg = func.getArgument(i);
        // TODO(hinsu): Instead of setting both 'sharding' and '_XlaSharding'
        // attributes, only set the 'sharding' attribute. Both attributes are
        // currently required as the XlaSharding xla op kernel doesn't use the
        // 'sharding' attribute.
        auto updated_arg = builder.create<TF::XlaShardingOp>(
            func.getLoc(), arg.getType(), arg, sharding, sharding);
        func.getArgument(i).replaceAllUsesExcept(
            updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg}));
      }

      func.removeArgAttr(i, builder.getIdentifier(kShardingAttr));
    }
  }
}

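// Rewrites communication ops used by the new bridge to match old bridge
// semantics: applies the RewriteXlaHostComputeMlir pattern above and appends
// the "_dtoh_0"/"_htod_0" suffixes expected by the old bridge to the keys of
// XlaSendToHost/XlaRecvFromHost ops.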
LogicalResult RewriteCommunicationOps(ModuleOp module) {
  MLIRContext* ctx = module.getContext();
  mlir::OwningRewritePatternList patterns(ctx);
  patterns.insert<RewriteXlaHostComputeMlir>(ctx);
  if (failed(mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    return module.emitError("failed to apply tf export preparation patterns");
  }

  // TODO(hinsu): Investigate if the semantics of keys for these communication
  // ops between the old bridge and new bridge can be reconciled.
  module.walk([&](Operation* op) {
    if (isa<TF::XlaSendToHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_dtoh_0");
      op->setAttr("key", new_key);
    } else if (isa<TF::XlaRecvFromHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_htod_0");
      op->setAttr("key", new_key);
    }
  });
  return success();
}

// Sets the token input node names attribute and the corresponding original
// node name for TF/XLA communication related ops. These attributes are used to
// order operations on device: the first op in a region gets the special
// argument token, and each remaining op gets the node name of the previous
// communication op.
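// For example, in a region with two XlaHostCompute ops, the first op is
// annotated with the special argument token name as its token input node and
// the second op is annotated with the first op's node name.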
LogicalResult SetTokenInputAttrs(ModuleOp module) {
  // Collect all the ops that need to have token input name attributes. These
  // ops are communication ops and all of their parent ops, reached via nesting
  // or function calls. For example, IfRegion op and PartitionedCall op.
  std::vector<Operation*> worklist;
  absl::flat_hash_set<Operation*> ops_with_tokens;
  module.walk([&](Operation* op) {
    if (IsCommunicationOp(op)) {
      ops_with_tokens.insert(op);
      worklist.push_back(op);
    }
  });

  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);

  // Regions that contain ops requiring token input attributes.
  absl::flat_hash_set<Region*> regions_with_token;
  while (!worklist.empty()) {
    Operation* op = worklist.back();
    worklist.pop_back();

    Region* region = op->getParentRegion();
    regions_with_token.insert(region);

    // If the parent is not a FuncOp, then add the parent op containing the
    // region to the worklist.
    Operation* parent = region->getParentOp();
    if (!isa<FuncOp>(parent)) {
      if (ops_with_tokens.insert(parent).second) {
        worklist.push_back(parent);
      }
      continue;
    }

    // For functions, get all the users and add them to the worklist.
    for (auto& user : symbol_map.getUsers(parent)) {
      if (ops_with_tokens.insert(user).second) {
        worklist.push_back(user);
      }
    }
  }

  // Use a name mapper to uniquely name all ops in the module, as export to a
  // TensorFlow graph may change node names. The op names here don't need to
  // match the actual names in the graph because this only sets the original
  // node name attribute for the relevant nodes.
  tensorflow::OpOrArgLocNameMapper name_mapper;
  MLIRContext* ctx = module.getContext();
  for (Region* region : regions_with_token) {
    // Initialize the token with the special argument token. This gets mapped
    // to the input token in the parent op or to a new token for the entry
    // computation.
    auto token = StringAttr::get(ctx, tensorflow::kXlaTokenArgNodeName);
    for (Operation& op : region->getOps()) {
      // Only communication related ops that need a token should have the
      // extra attribute.
      if (!ops_with_tokens.contains(&op)) continue;

      if (!IsCommunicationOp(&op) && !SupportsCommunicationComputation(&op)) {
        return op.emitOpError(
            "does not support subcomputations with tf/xla communication ops");
      }

      op.setAttr(tensorflow::kXlaTokenInputNodesAttrName,
                 ArrayAttr::get(ctx, {token}));

      auto node_name = StringAttr::get(ctx, name_mapper.GetUniqueName(&op));
      op.setAttr(tensorflow::kXlaOriginalOutsideCompilationNodeName, node_name);
      token = node_name;
    }
  }
  return success();
}

void PrepareTpuComputationForTfExportPass::runOnOperation() {
  ModuleOp module = getOperation();

  for (FuncOp func : module.getOps<FuncOp>()) {
    UpdateArgAttributes(func);
  }

  // First rewrite communication ops used in the new bridge to match old bridge
  // semantics, and then set token input node name attributes on the supported
  // ops.
  if (failed(RewriteCommunicationOps(module)) ||
      failed(SetTokenInputAttrs(module))) {
    signalPassFailure();
    return;
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreatePrepareTpuComputationForTfExportPass() {
  return std::make_unique<PrepareTpuComputationForTfExportPass>();
}

}  // namespace TF
}  // namespace mlir