1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering TensorFlow dialect's communication
17 // ops (TF/XLA) to the HLO dialect.
18
19 #include <memory>
20 #include <string>
21
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
34 #include "mlir/IR/Value.h" // from @llvm-project
35 #include "mlir/IR/Visitors.h" // from @llvm-project
36 #include "mlir/Pass/Pass.h" // from @llvm-project
37 #include "mlir/Support/LLVM.h" // from @llvm-project
38 #include "mlir/Support/LogicalResult.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
42 #include "tensorflow/compiler/xla/client/sharding_builder.h"
43 #include "tensorflow/compiler/xla/primitive_util.h"
44
45 namespace mlir {
46 namespace mhlo {
47
48 namespace {
// Attribute names used to annotate generated HLO communication ops.
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
// Frontend attribute keys attached per host transfer. Made `constexpr` for
// consistency with the sibling constants above.
constexpr char kXlaHostTransferRendezvousNameAttr[] =
    "_xla_host_transfer_rendezvous";
constexpr char kXlaHostTransferOriginalTypeAttr[] =
    "_xla_host_transfer_original_type";
55
// A pass that legalizes TF/XLA communication ops, propagates their respective
// tokens (for ordering), and rewrites their respective functions and control
// flow ops when necessary.
// Note, this currently does not handle nested modules/functions or region based
// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
class LegalizeTFCommunication
    : public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
  // The rewrites emit MHLO ops (`mhlo.send`/`mhlo.recv`/tokens), so the MHLO
  // dialect must be registered as a dependency.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<mhlo::MhloDialect>();
  }

 public:
  void runOnOperation() override;
};
70
71 // Checks if an op is a TF/XLA communication op.
IsCommunicationOp(Operation * op)72 bool IsCommunicationOp(Operation* op) {
73 return isa<TF::_XlaHostComputeMlirOp, TF::XlaSendToHostOp,
74 TF::XlaRecvFromHostOp>(op);
75 }
76
77 // Checks if an op is a supported HLO control flow op.
IsControlFlowOp(Operation * op)78 bool IsControlFlowOp(Operation* op) { return isa<IfOp, WhileOp>(op); }
79
80 // Collects control flow op ancestors of a given op, up until FuncOp. If any
81 // ancestor is not a control flow op or a FuncOp, or of a single block region,
82 // an error will be returned.
GetControlFlowAncestors(Operation * op,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)83 LogicalResult GetControlFlowAncestors(
84 Operation* op, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
85 llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
86 Block* block = op->getBlock();
87 Operation* parent = block->getParentOp();
88 while (block && parent && !isa<FuncOp>(parent)) {
89 if (!IsControlFlowOp(parent))
90 return op->emitOpError()
91 << "expects ancestor(s) to be of ['" << IfOp::getOperationName()
92 << "', '" << FuncOp::getOperationName() << "']";
93
94 if (!llvm::hasSingleElement(block->getParent()->getBlocks()))
95 return op->emitOpError() << "expects single block region ancestor(s)";
96
97 control_flow_ops.insert(parent);
98 control_flow_blocks.insert(block);
99
100 parent = block->getParentOp();
101 block = parent->getBlock();
102 }
103 return success();
104 }
105
106 // Finds communication ops in a function. `control_flow_ops` and
107 // `control_flow_blocks` will be populated with control flow op ancestors for
108 // every communication op.
FindCommunicationOps(FuncOp func,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool & has_communication_ops)109 LogicalResult FindCommunicationOps(
110 FuncOp func, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
111 llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
112 bool& has_communication_ops) {
113 auto result = func.walk([&](Operation* op) {
114 if (!IsCommunicationOp(op)) return WalkResult::advance();
115 has_communication_ops = true;
116 if (failed(
117 GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks)))
118 return WalkResult::interrupt();
119 return WalkResult::advance();
120 });
121 return failure(result.wasInterrupted());
122 }
123
// Helper struct holding a function to be rewritten, its control flow ops that
// lead to a communication op or function call with a communication op
// (transitively), and an optional clone of itself. If `clone` is set, function
// calls to `original` will be replaced with `clone`.
struct FuncToRewrite {
  // Function as found in the module.
  FuncOp original;
  // Control flow ops (`mhlo.if`/`mhlo.while`) on a path from the function body
  // to a communication op (or to a call that transitively reaches one).
  llvm::SmallPtrSet<Operation*, 4> control_flow_ops;
  // Blocks belonging to `control_flow_ops` that contain such paths.
  llvm::SmallPtrSet<Block*, 4> control_flow_blocks;
  // Private copy of `original`; only set when `original` is public and has
  // known uses in the module (see GetFunctionsToRewrite).
  FuncOp clone;
};
134
// Finds all functions that need to be rewritten with communication ops and
// associated tokens.
LogicalResult GetFunctionsToRewrite(
    ModuleOp module,
    llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
  // Find functions containing communication ops.
  SmallVector<FuncOp, 4> funcs_to_visit;
  for (FuncOp func : module.getOps<FuncOp>()) {
    FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{},
                                  /*control_flow_blocks=*/{},
                                  /*clone=*/nullptr};
    bool has_communication_ops = false;
    if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops,
                                    func_to_rewrite.control_flow_blocks,
                                    has_communication_ops)))
      return failure();

    if (!has_communication_ops) continue;
    funcs_to_rewrite.insert({func.getName(), func_to_rewrite});
    funcs_to_visit.push_back(func);
  }

  // Find functions that call functions with communication ops, transitively.
  // This is a breadth-first traversal over callers: each round visits the
  // callers discovered in the previous round.
  while (!funcs_to_visit.empty()) {
    SmallVector<FuncOp, 4> new_funcs_to_visit;
    for (FuncOp& func : funcs_to_visit) {
      auto uses = func.getSymbolUses(module);
      if (!uses) continue;
      for (auto& use : *uses) {
        // Only `mlir::CallOp` is supported as this requires knowing how to
        // rewrite arguments and results to a function.
        if (!isa<mlir::CallOp>(use.getUser())) continue;
        auto caller_parent_func = use.getUser()->getParentOfType<FuncOp>();
        if (!caller_parent_func) continue;

        FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func,
                                      /*control_flow_ops=*/{},
                                      /*control_flow_blocks=*/{},
                                      /*clone=*/nullptr};
        // Control flow ancestors of the call site must also forward tokens.
        if (failed(GetControlFlowAncestors(
                use.getUser(), func_to_rewrite.control_flow_ops,
                func_to_rewrite.control_flow_blocks)))
          return failure();

        auto it = funcs_to_rewrite.insert(
            {caller_parent_func.getName(), func_to_rewrite});
        if (it.second) {
          // Newly discovered caller; visit its own callers next round.
          new_funcs_to_visit.push_back(caller_parent_func);
        } else {
          // Caller already recorded; merge in ancestors found from this
          // particular call site.
          it.first->getSecond().control_flow_ops.insert(
              func_to_rewrite.control_flow_ops.begin(),
              func_to_rewrite.control_flow_ops.end());
          it.first->getSecond().control_flow_blocks.insert(
              func_to_rewrite.control_flow_blocks.begin(),
              func_to_rewrite.control_flow_blocks.end());
        }
      }
    }

    funcs_to_visit.swap(new_funcs_to_visit);
  }

  // Clone public functions that need to be rewritten. Function calls to this
  // function will be replaced with the cloned function. The clone is made
  // private; the original public function is left in place.
  SymbolTable symbol_table(module);
  for (auto& func : funcs_to_rewrite) {
    if (func.getSecond().original.isPublic() &&
        !func.getSecond().original.symbolKnownUseEmpty(module)) {
      auto clone = func.getSecond().original.clone();
      clone.setPrivate();
      symbol_table.insert(clone);
      func.getSecond().clone = clone;
    }
  }

  return success();
}
212
213 // Assigns op sharding to an op for a given device core.
SetOpSharding(Operation * op,int64_t tpu_core)214 void SetOpSharding(Operation* op, int64_t tpu_core) {
215 std::string sharding_serialized =
216 ::xla::sharding_builder::AssignDevice(tpu_core).SerializeAsString();
217 op->setAttr(kShardingAttr,
218 StringAttr::get(op->getContext(), sharding_serialized));
219 }
220
221 // Assigns frontend attributes holding information about data type and
222 // TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is
223 // handled differently as individual names are used per data send and receive.
SetFrontendAttributes(Operation * op,int32_t index,StringRef key,Type type,bool device_to_host)224 void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
225 Type type, bool device_to_host) {
226 MLIRContext* context = op->getContext();
227
228 std::string formatted_key =
229 device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
230 : llvm::formatv("{0}_htod_{1}", key, index).str();
231
232 auto rendezvous_name = StringAttr::get(context, formatted_key);
233 auto rendezvous_name_attr = NamedAttribute(
234 Identifier::get(kXlaHostTransferRendezvousNameAttr, context),
235 rendezvous_name);
236
237 auto element_type = getElementTypeOrSelf(type);
238 auto xla_element_type = ::xla::TypeToPrimitiveType(element_type);
239 const std::string& xla_element_type_str =
240 ::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type);
241 auto original_type = StringAttr::get(context, xla_element_type_str);
242 auto original_type_attr =
243 NamedAttribute(Identifier::get(kXlaHostTransferOriginalTypeAttr, context),
244 original_type);
245
246 auto frontend_attributes = DictionaryAttr::get(
247 context,
248 ArrayRef<NamedAttribute>{rendezvous_name_attr, original_type_attr});
249 op->setAttr(kFrontendAttributesAttr, frontend_attributes);
250 }
251
252 // Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set,
253 // op sharding for the respective device will be set.
CreateSendOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value operand,StringRef key,size_t index,const Optional<int64_t> & tpu_core,Value token)254 Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
255 Value operand, StringRef key, size_t index,
256 const Optional<int64_t>& tpu_core, Value token) {
257 // type 2 == DEVICE_TO_HOST
258 auto channel_handle = ChannelHandle::get(
259 /*handle=*/builder.getI64IntegerAttr(channel_id++),
260 /*type=*/builder.getI64IntegerAttr(2), builder.getContext());
261 auto send = builder.create<SendOp>(
262 loc, token.getType(), operand, token, channel_handle,
263 /*is_host_transfer=*/builder.getBoolAttr(true));
264
265 SetFrontendAttributes(send, index, key, operand.getType(),
266 /*device_to_host=*/true);
267
268 if (tpu_core) SetOpSharding(send, *tpu_core);
269
270 return send.getResult();
271 }
272
// Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op
// sharding for the respective device will be set. Uses of `result` are
// redirected to the received value, and the fresh token is returned.
Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
                   Value result, StringRef key, size_t index,
                   const Optional<int64_t>& tpu_core, Value token) {
  // type 3 == HOST_TO_DEVICE. `channel_id` is bumped so every transfer gets a
  // unique channel handle.
  auto channel_handle = ChannelHandle::get(
      /*handle=*/builder.getI64IntegerAttr(channel_id++),
      /*type=*/builder.getI64IntegerAttr(3), builder.getContext());
  // `mhlo.recv` produces a (value, token) tuple.
  auto result_type = result.getType();
  auto recv_result_type =
      TupleType::get(builder.getContext(), {result_type, token.getType()});
  auto recv =
      builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
                             /*is_host_transfer=*/builder.getBoolAttr(true));

  SetFrontendAttributes(recv, index, key, result_type,
                        /*device_to_host=*/false);

  if (tpu_core) SetOpSharding(recv, *tpu_core);

  // Unpack the received value (tuple element 0) and forward it to all users of
  // the original `result`.
  auto get_tuple_element =
      builder.create<GetTupleElementOp>(loc, recv.getResult(), /*index=*/0);
  if (tpu_core) SetOpSharding(get_tuple_element, *tpu_core);

  result.replaceAllUsesWith(get_tuple_element);

  // Unpack the new token (tuple element 1) to be threaded into subsequent
  // communication ops.
  auto new_token = builder.create<GetTupleElementOp>(loc, recv.getResult(),
                                                     /*index=*/1);
  if (tpu_core) SetOpSharding(new_token, *tpu_core);

  return new_token.getResult();
}
306
307 // Creates a new token if necessary, acting as a sink to previous tokens. If
308 // there is only one token in `tokens`, the only token is returned. If `tokens`
309 // is empty, `original_token` is returned instead.
CreateSinkToken(OpBuilder & builder,Location loc,ArrayRef<Value> tokens,Value original_token)310 Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef<Value> tokens,
311 Value original_token) {
312 if (tokens.empty()) {
313 return original_token;
314 } else if (llvm::hasSingleElement(tokens)) {
315 return tokens[0];
316 } else {
317 return builder.create<AfterAllOp>(loc, original_token.getType(), tokens)
318 .getResult();
319 }
320 }
321
322 // Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and `mhlo.recv`
323 // ops per operand and result. Unique Channel Id's are assigned per transfer.
324 // Sink tokens are created across all `mhlo.send` ops first and then by
325 // all `mhlo.recv` ops.
RewriteHostComputeOp(OpBuilder & builder,int64_t & channel_id,TF::_XlaHostComputeMlirOp host_compute,Value token)326 Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id,
327 TF::_XlaHostComputeMlirOp host_compute,
328 Value token) {
329 builder.setInsertionPoint(host_compute);
330 Location loc = host_compute.getLoc();
331 int64_t tpu_core = host_compute.tpu_coreAttr().getInt();
332
333 SmallVector<Value, 4> send_tokens;
334 for (auto operand : llvm::enumerate(host_compute.inputs())) {
335 auto send_token =
336 CreateSendOp(builder, channel_id, loc, operand.value(),
337 host_compute.send_key(), operand.index(), tpu_core, token);
338 send_tokens.push_back(send_token);
339 }
340 token = CreateSinkToken(builder, loc, send_tokens, token);
341
342 SmallVector<Value, 4> recv_tokens;
343 for (auto result : llvm::enumerate(host_compute.outputs())) {
344 auto recv_token =
345 CreateRecvOp(builder, channel_id, loc, result.value(),
346 host_compute.recv_key(), result.index(), tpu_core, token);
347 recv_tokens.push_back(recv_token);
348 }
349 token = CreateSinkToken(builder, loc, recv_tokens, token);
350
351 host_compute.erase();
352 return token;
353 }
354
355 // Replaces `tf.XlaSendToHost` with a `mhlo.send`.
RewriteSendToHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaSendToHostOp send_to_host,Value token)356 Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
357 TF::XlaSendToHostOp send_to_host, Value token) {
358 builder.setInsertionPoint(send_to_host);
359 token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
360 send_to_host.input(), send_to_host.key(),
361 /*index=*/0, /*tpu_core=*/llvm::None, token);
362
363 send_to_host.erase();
364 return token;
365 }
366
367 // Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`.
RewriteRecvFromHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaRecvFromHostOp recv_from_host,Value token)368 Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
369 TF::XlaRecvFromHostOp recv_from_host, Value token) {
370 builder.setInsertionPoint(recv_from_host);
371 token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
372 recv_from_host.output(), recv_from_host.key(),
373 /*index=*/0, /*tpu_core=*/llvm::None, token);
374
375 recv_from_host.erase();
376 return token;
377 }
378
379 // Replaces a `mlir::CallOp` with one that has an extra `!mhlo.token` operand
380 // and `!mhlo.token` result. If `new_symbol` is set, the new call will be
381 // updated to call the `new_symbol` instead.
RewriteCallOp(OpBuilder & builder,CallOp call,const Optional<StringRef> & new_symbol,Value token)382 Value RewriteCallOp(OpBuilder& builder, CallOp call,
383 const Optional<StringRef>& new_symbol, Value token) {
384 builder.setInsertionPoint(call);
385 auto new_operands = llvm::to_vector<4>(call.getArgOperands());
386 new_operands.push_back(token);
387 auto new_result_types = llvm::to_vector<4>(call.getResultTypes());
388 new_result_types.push_back(token.getType());
389 auto new_call = builder.create<CallOp>(
390 call.getLoc(), new_result_types, new_symbol ? *new_symbol : call.callee(),
391 new_operands);
392
393 for (auto results : llvm::zip(call.getResults(), new_call.getResults()))
394 std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
395 call.erase();
396 return new_call.getResults().back();
397 }
398
// Helper struct holding state of which op to visit to next. If `op` is in a
// control flow op region, `region_idx` will be set with the respective region
// index. `token` will be current token from the last communication op/control
// flow op transitive communication ops.
struct OpVisitorState {
  // Region of `op` to visit next; unset when `op` itself should be rewritten.
  Optional<unsigned> region_idx;
  // Token to thread into the next rewrite.
  Value token;
  // Op to visit next.
  Operation* op;
};
408
409 // Creates a tuple from a sequence of values.
CreateTuple(OpBuilder & builder,Location loc,ArrayRef<Value> operands)410 Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef<Value> operands) {
411 return builder.create<TupleOp>(loc, operands).getResult();
412 }
413
// Replaces a value `value` with a new value but the token attached. If `value`
// is not a tuple, a new tuple is formed with `token`. If `value` is a tuple,
// `value` is extended instead. New tuple values created are cached in
// `rewritten_values` so extending the same value twice reuses one tuple.
Value GetValueWithToken(OpBuilder& builder, Value value, Value token,
                        llvm::SmallDenseMap<Value, Value>& rewritten_values) {
  // If value with token already exists, reuse it.
  auto it = rewritten_values.find(value);
  if (it != rewritten_values.end()) return it->getSecond();

  // Builds the tuple and records it in the cache before returning.
  auto create_tuple = [&](ArrayRef<Value> operands) {
    auto new_result = CreateTuple(builder, value.getLoc(), operands);
    rewritten_values.insert({value, new_result});
    return new_result;
  };

  auto tuple_type = value.getType().dyn_cast<TupleType>();
  // `value` is not a tuple, create a new tuple.
  if (!tuple_type) return create_tuple({value, token});

  // Extend tuple if `value` is a tuple.
  // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack
  // the tuple.
  if (auto tuple_op = value.getDefiningOp<TupleOp>()) {
    auto tuple_operands = llvm::to_vector<4>(tuple_op.getOperands());
    tuple_operands.push_back(token);
    return create_tuple(tuple_operands);
  }

  // `value` is not created via a `mhlo.tuple` directly, unpack individual
  // elements directly with `mhlo.get_tuple_element`.
  SmallVector<Value, 4> tuple_operands;
  for (auto idx : llvm::seq<int32_t>(0, tuple_type.getTypes().size()))
    tuple_operands.push_back(
        builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
            .getResult());

  tuple_operands.push_back(token);
  return create_tuple(tuple_operands);
}
453
454 // Extends a type to include a `mhlo.token` type. If `type` is not a tuple type,
455 // a new tuple type with `type` and `mhlo.token` type is created instead.
GetTypeWithToken(OpBuilder & builder,Type type)456 TupleType GetTypeWithToken(OpBuilder& builder, Type type) {
457 auto token_type = TokenType::get(builder.getContext());
458 if (auto tuple_type = type.dyn_cast<TupleType>()) {
459 auto result_types = llvm::to_vector<4>(tuple_type.getTypes());
460 result_types.push_back(token_type);
461 return builder.getTupleType(result_types);
462 }
463
464 return builder.getTupleType({type, token_type});
465 }
466
467 // Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0
468 // to `end`, exclusive.
CreateSubTuple(OpBuilder & builder,Value value,size_t end)469 Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) {
470 SmallVector<Value, 4> tuple_operands;
471 for (auto idx : llvm::seq<int32_t>(0, end))
472 tuple_operands.push_back(
473 builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
474 .getResult());
475
476 return CreateTuple(builder, value.getLoc(), tuple_operands);
477 }
478
// Replaces uses of `value` with `replacement`. If `value` is not a tuple type,
// an explicit `mhlo.get_tuple_element` is created to unpack the tuple and
// return the first element. Otherwise, `mhlo.get_tuple_element` users are
// simply updated with `replacement`, and all other users are updated with a
// slice of `replacement`.
void ReplaceWithTupleResult(OpBuilder& builder, Value value,
                            Value replacement) {
  auto tuple_type = value.getType().dyn_cast<TupleType>();
  if (!tuple_type) {
    // Only materialize the unpack when there is a user to consume it.
    if (!value.use_empty()) {
      auto new_element = builder.create<GetTupleElementOp>(replacement.getLoc(),
                                                           replacement, 0);
      value.replaceAllUsesWith(new_element.getResult());
    }
    return;
  }

  // Sub tuple (slice of `replacement` at the original tuple's width) is
  // created lazily and shared by all non-`get_tuple_element` users.
  Value sub_tuple;
  // early_inc_range: `use.set` mutates the use list while iterating over it.
  for (auto& use : llvm::make_early_inc_range(value.getUses())) {
    if (isa<GetTupleElementOp>(use.getOwner())) {
      use.set(replacement);
      continue;
    }

    if (!sub_tuple)
      sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size());

    use.set(sub_tuple);
  }
}
509
// Replaces control flow op block single block argument with new block argument
// of type `token_type` (tuple type). The last element of the new block argument
// (token) is returned.
Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block,
                                         Type token_type) {
  assert(block.getNumArguments() == 1);
  builder.setInsertionPointToStart(&block);
  // Add the token-extended argument first, rewire users of the old argument
  // to it, then erase the old argument so exactly one argument remains.
  auto new_arg = block.addArgument(token_type);
  ReplaceWithTupleResult(builder, block.getArgument(0), new_arg);
  block.eraseArgument(0);
  // The token is the last element of the new tuple-typed argument.
  return builder
      .create<GetTupleElementOp>(new_arg.getLoc(), new_arg,
                                 token_type.cast<TupleType>().size() - 1)
      .getResult();
}
525
// Updates control flow op terminator with an extra element `token`. If the
// original return value is not a tuple, a new tuple is formed. Otherwise the
// tuple is extended.
void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator,
                                  Value token) {
  assert(terminator->getNumOperands() == 1);
  assert(terminator->getBlock()->getNumArguments() == 1);
  // `mhlo.while` cond terminator does not need to be rewritten as it always
  // returns a tensor<i1> predicate value.
  if (auto while_parent = dyn_cast_or_null<WhileOp>(terminator->getParentOp()))
    if (terminator->getParentRegion() == &while_parent.cond()) return;

  builder.setInsertionPoint(terminator);
  // Fresh cache per terminator; GetValueWithToken reuses tuples within it.
  llvm::SmallDenseMap<Value, Value> rewritten_operands;
  Value new_result = GetValueWithToken(builder, terminator->getOperand(0),
                                       token, rewritten_operands);
  terminator->setOperand(0, new_result);
}
544
// Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. Operands to
// the op for all of its regions are extended to have an extra operand `token`.
void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if,
                       SmallVectorImpl<OpVisitorState>& ops_to_visit,
                       Value token) {
  llvm::SmallDenseMap<Value, Value> rewritten_operands;

  // Rewrite all region operands to have an extra operand `token`.
  Value new_true_operand = GetValueWithToken(builder, region_if.true_arg(),
                                             token, rewritten_operands);
  Value new_false_operand = GetValueWithToken(builder, region_if.false_arg(),
                                              token, rewritten_operands);

  auto new_result_type = GetTypeWithToken(builder, region_if.getType());

  // Create new `mhlo.if` op with extra token operands and result.
  auto new_if = builder.create<IfOp>(region_if.getLoc(), new_result_type,
                                     region_if.pred(), new_true_operand,
                                     new_false_operand);

  // Move all regions from the old `mhlo.if` op to its replacement.
  new_if.true_branch().takeBody(region_if.true_branch());
  new_if.false_branch().takeBody(region_if.false_branch());

  // Forward result from old `mhlo.if` with replacement, and unpack result when
  // necessary.
  ReplaceWithTupleResult(builder, region_if.getResult(), new_if.getResult());

  // The token is the last element of the new tuple result.
  auto new_token = builder.create<GetTupleElementOp>(
      new_if.getLoc(), new_if.getResult(),
      new_if.getResult().getType().cast<TupleType>().size() - 1);

  region_if.erase();

  // Remove leftover operands to old `mhlo.if` if they have no uses.
  for (auto& rewritten_operand : rewritten_operands)
    if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
      if (tuple_op.use_empty()) tuple_op.erase();

  // Next op to visit. The replacement is visited but at its first region. The
  // token result of the new region if is propagated.
  ops_to_visit.push_back({/*region_idx=*/0, new_token, new_if});
}
588
589 // Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a
590 // `mhlo.token`. The block argument is updated to have an extra `mhlo.token`
591 // element. If the region block is to be rewritten, the next op to visit is set
592 // to the first op in the block. Otherwise the terminator is updated to forward
593 // `token`.
RewriteControlFlowOpRegion(OpBuilder & builder,Operation * region_op,unsigned region_idx,Type block_arg_type,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)594 void RewriteControlFlowOpRegion(
595 OpBuilder& builder, Operation* region_op, unsigned region_idx,
596 Type block_arg_type, SmallVectorImpl<OpVisitorState>& ops_to_visit,
597 const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
598 ops_to_visit.push_back({region_idx + 1, token, region_op});
599
600 Region& region = region_op->getRegion(region_idx);
601 assert(llvm::hasSingleElement(region));
602
603 auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(),
604 block_arg_type);
605
606 if (control_flow_blocks.contains(®ion.front())) {
607 ops_to_visit.push_back({/*region_idx=*/llvm::None, block_token,
608 block_token.getDefiningOp()->getNextNode()});
609 return;
610 }
611
612 RewriteControlFlowTerminator(builder, region.front().getTerminator(),
613 block_token);
614 }
615
616 // Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op
617 // operands and results are rewritten. If `region_idx` is set, region
618 // `region_idx` is rewritten to take in and return an additional token. Returns
619 // true if the op or its region was rewritten.
ProcessRegionIfOp(OpBuilder & builder,IfOp region_if,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)620 bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if,
621 Optional<unsigned> region_idx,
622 SmallVectorImpl<OpVisitorState>& ops_to_visit,
623 const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
624 Value token) {
625 builder.setInsertionPoint(region_if);
626
627 if (!region_idx) {
628 RewriteRegionIfOp(builder, region_if, ops_to_visit, token);
629 return true;
630 }
631
632 if (*region_idx < region_if.getNumRegions()) {
633 RewriteControlFlowOpRegion(builder, region_if, *region_idx,
634 region_if.getOperand(*region_idx + 1).getType(),
635 ops_to_visit, control_flow_blocks, token);
636 return true;
637 }
638
639 return false;
640 }
641
// Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to
// the op for all of its regions are extended to have an extra operand `token`.
void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while,
                          SmallVectorImpl<OpVisitorState>& ops_to_visit,
                          Value token) {
  llvm::SmallDenseMap<Value, Value> rewritten_operands;

  // Rewrite region operand to have an extra operand `token`.
  Value new_val_operand =
      GetValueWithToken(builder, region_while.val(), token, rewritten_operands);

  auto new_result_type = GetTypeWithToken(builder, region_while.getType());

  // Create new `mhlo.while` op with extra token operand and result.
  auto new_while = builder.create<WhileOp>(region_while.getLoc(),
                                           new_result_type, new_val_operand);

  // Move all regions from the old `mhlo.while` op to its replacement.
  new_while.cond().takeBody(region_while.cond());
  new_while.body().takeBody(region_while.body());

  // Forward result from old `mhlo.while` with replacement, and unpack result
  // when necessary.
  ReplaceWithTupleResult(builder, region_while.getResult(),
                         new_while.getResult());

  // The token is the last element of the new tuple result.
  auto new_token = builder.create<GetTupleElementOp>(
      new_while.getLoc(), new_while.getResult(),
      new_while.getResult().getType().cast<TupleType>().size() - 1);

  region_while.erase();

  // Remove leftover operands to old `mhlo.while` if they have no uses.
  for (auto& rewritten_operand : rewritten_operands)
    if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
      if (tuple_op.use_empty()) tuple_op.erase();

  // Next op to visit. The replacement is visited but at its first region. The
  // token result of the new region while is propagated.
  ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while});
}
683
684 // Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op
685 // operands and results are rewritten. If `region_idx` is set, region
686 // `region_idx` is rewritten to take in and return an additional token. Returns
687 // true if the op or its region was rewritten.
ProcessRegionWhileOp(OpBuilder & builder,WhileOp region_while,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)688 bool ProcessRegionWhileOp(
689 OpBuilder& builder, WhileOp region_while, Optional<unsigned> region_idx,
690 SmallVectorImpl<OpVisitorState>& ops_to_visit,
691 const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
692 builder.setInsertionPoint(region_while);
693
694 if (!region_idx) {
695 RewriteRegionWhileOp(builder, region_while, ops_to_visit, token);
696 return true;
697 }
698
699 if (*region_idx < region_while.getNumRegions()) {
700 RewriteControlFlowOpRegion(builder, region_while, *region_idx,
701 region_while.val().getType(), ops_to_visit,
702 control_flow_blocks, token);
703 return true;
704 }
705
706 return false;
707 }
708
709 // Updates function type based on current function body block arguments and
710 // terminator operand types.
UpdateFunctionType(OpBuilder & builder,FuncOp func,Block & func_body)711 void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) {
712 auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
713 auto new_result_types =
714 llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
715 func.setType(FunctionType::get(builder.getContext(), new_argument_types,
716 new_result_types));
717 }
718
719 // Replaces a function terminator `return` with another `return` that has an
720 // extra `mhlo.token` operand.
RewriteFunctionTerminator(OpBuilder & builder,mlir::ReturnOp terminator,Value token)721 void RewriteFunctionTerminator(OpBuilder& builder, mlir::ReturnOp terminator,
722 Value token) {
723 auto new_results = llvm::to_vector<4>(terminator.getOperands());
724 new_results.push_back(token);
725 builder.setInsertionPoint(terminator);
726 builder.create<mlir::ReturnOp>(terminator.getLoc(), new_results);
727 terminator.erase();
728 }
729
// Rewrites a function body and communication ops inside. Region control flow
// are updated when necessary, to propagate tokens. The function may either be
// rewritten to create a token or take in and return a token, depending on its
// visibility and if there are any callers.
LogicalResult RewriteFunction(
    OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func,
    const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs,
    const llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
    const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, bool is_clone) {
  MLIRContext* context = module.getContext();
  // Token threading below walks one linear block; multi-block functions are
  // rejected up front.
  if (!llvm::hasSingleElement(func.getBody()))
    return func.emitError()
           << "'" << FuncOp::getOperationName()
           << "' ops with more than one block are not supported";

  // Clones, and private functions with at least one known use, get their
  // signature rewritten (token block argument + token result); other
  // functions keep their signature untouched.
  bool rewrite_block =
      is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module));
  Block& func_body = func.front();

  builder.setInsertionPointToStart(&func_body);
  auto token_type = TokenType::get(context);
  // If a function is public, its signature should not be modified, and instead
  // a token will be created. Otherwise a token block argument is inserted.
  Value init_token =
      rewrite_block ? func_body.addArgument(token_type)
                    : builder.create<CreateTokenOp>(func.getLoc(), token_type)
                          .getResult();

  // Stack to keep track of region based control flow op nesting and current
  // op to visit. Each entry carries the token live at that point so
  // communication ops are chained in program order.
  SmallVector<OpVisitorState, 4> ops_to_visit{
      {/*region_idx=*/llvm::None, init_token, &func_body.front()}};

  while (!ops_to_visit.empty()) {
    OpVisitorState op_to_visit = ops_to_visit.pop_back_val();
    Operation* curr_op = op_to_visit.op;

    Value token = op_to_visit.token;
    // Ops may be removed, so the next op is kept track of beforehand.
    Operation* next_op = curr_op->getNextNode();

    // Communication ops: each rewrite consumes the current token and produces
    // the token to thread into the following op.
    if (auto host_compute = dyn_cast<TF::_XlaHostComputeMlirOp>(curr_op)) {
      token = RewriteHostComputeOp(builder, channel_id, host_compute, token);
    } else if (auto send_to_host = dyn_cast<TF::XlaSendToHostOp>(curr_op)) {
      token = RewriteSendToHostOp(builder, channel_id, send_to_host, token);
    } else if (auto recv_from_host = dyn_cast<TF::XlaRecvFromHostOp>(curr_op)) {
      token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token);
    } else if (auto call = dyn_cast<mlir::CallOp>(curr_op)) {
      // Only `mlir::CallOp` is supported as this requires knowing how to
      // rewrite arguments and results to a function.
      auto it = funcs.find(call.getCallee());
      if (it != funcs.end()) {
        FuncOp clone = it->getSecond().clone;
        Optional<StringRef> symbol_name =
            clone ? Optional<StringRef>(clone.getName()) : llvm::None;
        // If the function being called is to be cloned, update the call to also
        // point to the cloned function.
        token = RewriteCallOp(builder, call, symbol_name, token);
      }
    } else if (auto region_if = dyn_cast<IfOp>(curr_op)) {
      // Region control flow ops are only rewritten when they (transitively)
      // contain communication, or when revisited for one of their regions.
      if (op_to_visit.region_idx || control_flow_ops.contains(region_if))
        if (ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx,
                              ops_to_visit, control_flow_blocks, token))
          continue;
    } else if (auto region_while = dyn_cast<WhileOp>(curr_op)) {
      if (op_to_visit.region_idx || control_flow_ops.contains(region_while))
        if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx,
                                 ops_to_visit, control_flow_blocks, token))
          continue;
    } else if (auto region_terminator = dyn_cast<mhlo::ReturnOp>(curr_op)) {
      RewriteControlFlowTerminator(builder, region_terminator, token);
      // There is no next op after the control flow op terminator, simply let
      // stack have one less element.
      continue;
    } else if (auto func_terminator = dyn_cast<mlir::ReturnOp>(curr_op)) {
      if (rewrite_block)
        RewriteFunctionTerminator(builder, func_terminator, token);

      // There is no next op after the function terminator, simply let stack
      // have one less element/be empty.
      continue;
    }

    // Visit next op.
    ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op});
  }

  // Signature was extended with a token argument/result above; reflect that
  // in the function type.
  if (rewrite_block) UpdateFunctionType(builder, func, func_body);

  return success();
}
821
822 // Checks if a function call is pointing to a function with communication ops.
IsFunctionCallWithCommunication(Operation * op,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)823 bool IsFunctionCallWithCommunication(
824 Operation* op,
825 const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
826 if (auto call = dyn_cast<mlir::CallOp>(op))
827 return funcs_to_rewrite.count(call.callee());
828
829 return false;
830 }
831
832 // Collects all control flow op ancestors of communication ops or function calls
833 // with communication ops (transitively).
GetCommunicationControlFlowOps(FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)834 void GetCommunicationControlFlowOps(
835 FuncOp func,
836 const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite,
837 llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
838 llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
839 func.walk([&](Operation* op) {
840 if (IsCommunicationOp(op) ||
841 IsFunctionCallWithCommunication(op, funcs_to_rewrite))
842 if (failed(GetControlFlowAncestors(op, control_flow_ops,
843 control_flow_blocks)))
844 llvm_unreachable(
845 "checking original function for control flow ancestors should have "
846 "errored first");
847 });
848 }
849
runOnOperation()850 void LegalizeTFCommunication::runOnOperation() {
851 auto module = getOperation();
852 llvm::SmallDenseMap<StringRef, FuncToRewrite> funcs_to_rewrite;
853 if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite)))
854 return signalPassFailure();
855
856 // Module level counter to make sure Channel Id's are unique.
857 int64_t channel_id = 1;
858 OpBuilder builder(&getContext());
859 for (const auto& func_and_name : funcs_to_rewrite) {
860 const auto& func_to_rewrite = func_and_name.getSecond();
861 FuncOp func = func_to_rewrite.original;
862 if (failed(RewriteFunction(builder, channel_id, module, func,
863 funcs_to_rewrite,
864 func_to_rewrite.control_flow_ops,
865 func_to_rewrite.control_flow_blocks,
866 /*is_clone=*/false)))
867 return signalPassFailure();
868
869 FuncOp clone = func_and_name.getSecond().clone;
870 if (!clone) continue;
871 llvm::SmallPtrSet<Operation*, 4> clone_control_flow_ops;
872 llvm::SmallPtrSet<Block*, 4> clone_control_flow_blocks;
873 GetCommunicationControlFlowOps(clone, funcs_to_rewrite,
874 clone_control_flow_ops,
875 clone_control_flow_blocks);
876 if (failed(RewriteFunction(builder, channel_id, module, clone,
877 funcs_to_rewrite, clone_control_flow_ops,
878 clone_control_flow_blocks,
879 /*is_clone=*/true)))
880 llvm_unreachable(
881 "rewriting of original function should have errored first");
882 }
883 }
884
// Registers the pass under the `-xla-legalize-tf-communication` flag so it is
// available to `mlir-opt`-style driver tools.
static PassRegistration<LegalizeTFCommunication> pass(
    "xla-legalize-tf-communication",
    "Legalize TF/XLA communication ops (TensorFlow dialect) to the HLO "
    "dialect");
889 } // namespace
890
// Public factory for the TF/XLA communication legalization pass; used by pass
// pipelines defined outside this file.
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCommunicationPass() {
  return std::make_unique<LegalizeTFCommunication>();
}
894
895 } // namespace mhlo
896 } // namespace mlir
897