1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // This file implements logic for lowering TensorFlow dialect's communication
17 // ops (TF/XLA) to the HLO dialect.
18
19 #include <memory>
20 #include <string>
21
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
34 #include "mlir/IR/Value.h" // from @llvm-project
35 #include "mlir/IR/Visitors.h" // from @llvm-project
36 #include "mlir/Pass/Pass.h" // from @llvm-project
37 #include "mlir/Support/LLVM.h" // from @llvm-project
38 #include "mlir/Support/LogicalResult.h" // from @llvm-project
39 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
42 #include "tensorflow/compiler/xla/client/sharding_builder.h"
43 #include "tensorflow/compiler/xla/primitive_util.h"
44
45 namespace mlir {
46 namespace mhlo {
47
48 namespace {
// Names of the attributes this pass attaches to the ops it creates.
// `kShardingAttr` holds a serialized op sharding and `kFrontendAttributesAttr`
// holds a dictionary of frontend attributes.
constexpr char kShardingAttr[] = "mhlo.sharding";
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
// Keys of the entries stored inside the frontend attributes dictionary.
constexpr char kXlaHostTransferRendezvousNameAttr[] =
    "_xla_host_transfer_rendezvous";
constexpr char kXlaHostTransferOriginalTypeAttr[] =
    "_xla_host_transfer_original_type";
55
// A module pass that legalizes TF/XLA communication ops, propagates their
// respective tokens (for ordering), and rewrites their enclosing functions and
// control flow ops when necessary.
// Note, this currently does not handle nested modules/functions or region based
// ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
class LegalizeTFCommunication
    : public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
  // Registers the MHLO dialect as a dependency of this pass.
  void getDependentDialects(DialectRegistry& registry) const override {
    registry.insert<mhlo::MhloDialect>();
  }

 public:
  // Command line argument used to select this pass.
  StringRef getArgument() const final {
    return "xla-legalize-tf-communication";
  }
  // One-line human readable summary of the pass.
  StringRef getDescription() const final {
    return "Legalize TF/XLA communication ops (TensorFlow dialect) to the HLO "
           "dialect";
  }
  // Pass entry point; defined out of line.
  void runOnOperation() override;
};
77
78 // Checks if an op is a TF/XLA communication op.
IsCommunicationOp(Operation * op)79 bool IsCommunicationOp(Operation* op) {
80 return isa<TF::_XlaHostComputeMlirOp, TF::XlaSendToHostOp,
81 TF::XlaRecvFromHostOp>(op);
82 }
83
84 // Checks if an op is a supported HLO control flow op.
IsControlFlowOp(Operation * op)85 bool IsControlFlowOp(Operation* op) { return isa<IfOp, WhileOp>(op); }
86
87 // Collects control flow op ancestors of a given op, up until FuncOp. If any
88 // ancestor is not a control flow op or a FuncOp, or of a single block region,
89 // an error will be returned.
GetControlFlowAncestors(Operation * op,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)90 LogicalResult GetControlFlowAncestors(
91 Operation* op, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
92 llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
93 Block* block = op->getBlock();
94 Operation* parent = block->getParentOp();
95 while (block && parent && !isa<FuncOp>(parent)) {
96 if (!IsControlFlowOp(parent))
97 return op->emitOpError()
98 << "expects ancestor(s) to be of ['" << IfOp::getOperationName()
99 << "', '" << FuncOp::getOperationName() << "']";
100
101 if (!llvm::hasSingleElement(block->getParent()->getBlocks()))
102 return op->emitOpError() << "expects single block region ancestor(s)";
103
104 control_flow_ops.insert(parent);
105 control_flow_blocks.insert(block);
106
107 parent = block->getParentOp();
108 block = parent->getBlock();
109 }
110 return success();
111 }
112
113 // Finds communication ops in a function. `control_flow_ops` and
114 // `control_flow_blocks` will be populated with control flow op ancestors for
115 // every communication op.
FindCommunicationOps(FuncOp func,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool & has_communication_ops)116 LogicalResult FindCommunicationOps(
117 FuncOp func, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
118 llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
119 bool& has_communication_ops) {
120 auto result = func.walk([&](Operation* op) {
121 if (!IsCommunicationOp(op)) return WalkResult::advance();
122 has_communication_ops = true;
123 if (failed(
124 GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks)))
125 return WalkResult::interrupt();
126 return WalkResult::advance();
127 });
128 return failure(result.wasInterrupted());
129 }
130
// Helper struct holding a function to be rewritten, its control flow ops that
// lead to a communication op or function call with a communication op
// (transitively), and an optional clone of itself. If `clone` is set, function
// calls to `original` will be replaced with `clone`.
struct FuncToRewrite {
  // The function to rewrite.
  FuncOp original;
  // Control flow ops in `original` that are ancestors of a communication op or
  // of a call to a function that is itself rewritten.
  llvm::SmallPtrSet<Operation*, 4> control_flow_ops;
  // Blocks visited while collecting `control_flow_ops`.
  llvm::SmallPtrSet<Block*, 4> control_flow_blocks;
  // Optional private copy of `original`; null when no clone was created.
  FuncOp clone;
};
141
142 // Finds all functions that need to be rewritten with communication ops and
143 // and associated tokens.
GetFunctionsToRewrite(ModuleOp module,llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)144 LogicalResult GetFunctionsToRewrite(
145 ModuleOp module,
146 llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
147 // Find functions containing communication ops.
148 SmallVector<FuncOp, 4> funcs_to_visit;
149 for (FuncOp func : module.getOps<FuncOp>()) {
150 FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{},
151 /*control_flow_blocks=*/{},
152 /*clone=*/nullptr};
153 bool has_communication_ops = false;
154 if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops,
155 func_to_rewrite.control_flow_blocks,
156 has_communication_ops)))
157 return failure();
158
159 if (!has_communication_ops) continue;
160 funcs_to_rewrite.insert({func.getName(), func_to_rewrite});
161 funcs_to_visit.push_back(func);
162 }
163
164 // Find functions that call functions with communication ops, transitively.
165 while (!funcs_to_visit.empty()) {
166 SmallVector<FuncOp, 4> new_funcs_to_visit;
167 for (FuncOp& func : funcs_to_visit) {
168 auto uses = func.getSymbolUses(module);
169 if (!uses) continue;
170 for (auto& use : *uses) {
171 // Only `mlir::CallOp` is supported as this requires knowing how to
172 // rewrite arguments and results to a function.
173 if (!isa<mlir::CallOp>(use.getUser())) continue;
174 auto caller_parent_func = use.getUser()->getParentOfType<FuncOp>();
175 if (!caller_parent_func) continue;
176
177 FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func,
178 /*control_flow_ops=*/{},
179 /*control_flow_blocks=*/{},
180 /*clone=*/nullptr};
181 if (failed(GetControlFlowAncestors(
182 use.getUser(), func_to_rewrite.control_flow_ops,
183 func_to_rewrite.control_flow_blocks)))
184 return failure();
185
186 auto it = funcs_to_rewrite.insert(
187 {caller_parent_func.getName(), func_to_rewrite});
188 if (it.second) {
189 new_funcs_to_visit.push_back(caller_parent_func);
190 } else {
191 it.first->getSecond().control_flow_ops.insert(
192 func_to_rewrite.control_flow_ops.begin(),
193 func_to_rewrite.control_flow_ops.end());
194 it.first->getSecond().control_flow_blocks.insert(
195 func_to_rewrite.control_flow_blocks.begin(),
196 func_to_rewrite.control_flow_blocks.end());
197 }
198 }
199 }
200
201 funcs_to_visit.swap(new_funcs_to_visit);
202 }
203
204 // Clone public functions that need to be rewritten. Function calls to this
205 // function will be replaced with the cloned function.
206 SymbolTable symbol_table(module);
207 for (auto& func : funcs_to_rewrite) {
208 if (func.getSecond().original.isPublic() &&
209 !func.getSecond().original.symbolKnownUseEmpty(module)) {
210 auto clone = func.getSecond().original.clone();
211 clone.setPrivate();
212 symbol_table.insert(clone);
213 func.getSecond().clone = clone;
214 }
215 }
216
217 return success();
218 }
219
220 // Assigns op sharding to an op for a given device core.
SetOpSharding(Operation * op,int64_t tpu_core)221 void SetOpSharding(Operation* op, int64_t tpu_core) {
222 std::string sharding_serialized =
223 ::xla::sharding_builder::AssignDevice(tpu_core).SerializeAsString();
224 op->setAttr(kShardingAttr,
225 StringAttr::get(op->getContext(), sharding_serialized));
226 }
227
228 // Assigns frontend attributes holding information about data type and
229 // TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is
230 // handled differently as individual names are used per data send and receive.
SetFrontendAttributes(Operation * op,int32_t index,StringRef key,Type type,bool device_to_host)231 void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
232 Type type, bool device_to_host) {
233 MLIRContext* context = op->getContext();
234
235 std::string formatted_key =
236 device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
237 : llvm::formatv("{0}_htod_{1}", key, index).str();
238
239 auto rendezvous_name = StringAttr::get(context, formatted_key);
240 auto rendezvous_name_attr = NamedAttribute(
241 Identifier::get(kXlaHostTransferRendezvousNameAttr, context),
242 rendezvous_name);
243
244 auto element_type = getElementTypeOrSelf(type);
245 auto xla_element_type = ::xla::TypeToPrimitiveType(element_type);
246 const std::string& xla_element_type_str =
247 ::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type);
248 auto original_type = StringAttr::get(context, xla_element_type_str);
249 auto original_type_attr =
250 NamedAttribute(Identifier::get(kXlaHostTransferOriginalTypeAttr, context),
251 original_type);
252
253 auto frontend_attributes = DictionaryAttr::get(
254 context,
255 ArrayRef<NamedAttribute>{rendezvous_name_attr, original_type_attr});
256 op->setAttr(kFrontendAttributesAttr, frontend_attributes);
257 }
258
259 // Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set,
260 // op sharding for the respective device will be set.
CreateSendOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value operand,StringRef key,size_t index,const Optional<int64_t> & tpu_core,Value token)261 Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
262 Value operand, StringRef key, size_t index,
263 const Optional<int64_t>& tpu_core, Value token) {
264 // type 2 == DEVICE_TO_HOST
265 auto channel_handle = ChannelHandle::get(
266 /*handle=*/builder.getI64IntegerAttr(channel_id++),
267 /*type=*/builder.getI64IntegerAttr(2), builder.getContext());
268 auto send = builder.create<SendOp>(
269 loc, token.getType(), operand, token, channel_handle,
270 /*is_host_transfer=*/builder.getBoolAttr(true));
271
272 SetFrontendAttributes(send, index, key, operand.getType(),
273 /*device_to_host=*/true);
274
275 if (tpu_core) SetOpSharding(send, *tpu_core);
276
277 return send.getResult();
278 }
279
280 // Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op
281 // sharding for the respective device will be set.
CreateRecvOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value result,StringRef key,size_t index,const Optional<int64_t> & tpu_core,Value token)282 Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
283 Value result, StringRef key, size_t index,
284 const Optional<int64_t>& tpu_core, Value token) {
285 // type 3 == HOST_TO_DEVICE
286 auto channel_handle = ChannelHandle::get(
287 /*handle=*/builder.getI64IntegerAttr(channel_id++),
288 /*type=*/builder.getI64IntegerAttr(3), builder.getContext());
289 auto result_type = result.getType();
290 auto recv_result_type =
291 TupleType::get(builder.getContext(), {result_type, token.getType()});
292 auto recv =
293 builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
294 /*is_host_transfer=*/builder.getBoolAttr(true));
295
296 SetFrontendAttributes(recv, index, key, result_type,
297 /*device_to_host=*/false);
298
299 if (tpu_core) SetOpSharding(recv, *tpu_core);
300
301 auto get_tuple_element =
302 builder.create<GetTupleElementOp>(loc, recv.getResult(), /*index=*/0);
303 if (tpu_core) SetOpSharding(get_tuple_element, *tpu_core);
304
305 result.replaceAllUsesWith(get_tuple_element);
306
307 auto new_token = builder.create<GetTupleElementOp>(loc, recv.getResult(),
308 /*index=*/1);
309 if (tpu_core) SetOpSharding(new_token, *tpu_core);
310
311 return new_token.getResult();
312 }
313
314 // Creates a new token if necessary, acting as a sink to previous tokens. If
315 // there is only one token in `tokens`, the only token is returned. If `tokens`
316 // is empty, `original_token` is returned instead.
CreateSinkToken(OpBuilder & builder,Location loc,ArrayRef<Value> tokens,Value original_token)317 Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef<Value> tokens,
318 Value original_token) {
319 if (tokens.empty()) {
320 return original_token;
321 } else if (llvm::hasSingleElement(tokens)) {
322 return tokens[0];
323 } else {
324 return builder.create<AfterAllOp>(loc, original_token.getType(), tokens)
325 .getResult();
326 }
327 }
328
329 // Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and `mhlo.recv`
330 // ops per operand and result. Unique Channel Id's are assigned per transfer.
331 // Sink tokens are created across all `mhlo.send` ops first and then by
332 // all `mhlo.recv` ops.
RewriteHostComputeOp(OpBuilder & builder,int64_t & channel_id,TF::_XlaHostComputeMlirOp host_compute,Value token)333 Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id,
334 TF::_XlaHostComputeMlirOp host_compute,
335 Value token) {
336 builder.setInsertionPoint(host_compute);
337 Location loc = host_compute.getLoc();
338 int64_t tpu_core = host_compute.tpu_coreAttr().getInt();
339
340 SmallVector<Value, 4> send_tokens;
341 for (auto operand : llvm::enumerate(host_compute.inputs())) {
342 auto send_token =
343 CreateSendOp(builder, channel_id, loc, operand.value(),
344 host_compute.send_key(), operand.index(), tpu_core, token);
345 send_tokens.push_back(send_token);
346 }
347 token = CreateSinkToken(builder, loc, send_tokens, token);
348
349 SmallVector<Value, 4> recv_tokens;
350 for (auto result : llvm::enumerate(host_compute.outputs())) {
351 auto recv_token =
352 CreateRecvOp(builder, channel_id, loc, result.value(),
353 host_compute.recv_key(), result.index(), tpu_core, token);
354 recv_tokens.push_back(recv_token);
355 }
356 token = CreateSinkToken(builder, loc, recv_tokens, token);
357
358 host_compute.erase();
359 return token;
360 }
361
362 // Replaces `tf.XlaSendToHost` with a `mhlo.send`.
RewriteSendToHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaSendToHostOp send_to_host,Value token)363 Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
364 TF::XlaSendToHostOp send_to_host, Value token) {
365 builder.setInsertionPoint(send_to_host);
366 token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
367 send_to_host.input(), send_to_host.key(),
368 /*index=*/0, /*tpu_core=*/llvm::None, token);
369
370 send_to_host.erase();
371 return token;
372 }
373
374 // Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`.
RewriteRecvFromHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaRecvFromHostOp recv_from_host,Value token)375 Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
376 TF::XlaRecvFromHostOp recv_from_host, Value token) {
377 builder.setInsertionPoint(recv_from_host);
378 token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
379 recv_from_host.output(), recv_from_host.key(),
380 /*index=*/0, /*tpu_core=*/llvm::None, token);
381
382 recv_from_host.erase();
383 return token;
384 }
385
386 // Replaces a `mlir::CallOp` with one that has an extra `!mhlo.token` operand
387 // and `!mhlo.token` result. If `new_symbol` is set, the new call will be
388 // updated to call the `new_symbol` instead.
RewriteCallOp(OpBuilder & builder,CallOp call,const Optional<StringRef> & new_symbol,Value token)389 Value RewriteCallOp(OpBuilder& builder, CallOp call,
390 const Optional<StringRef>& new_symbol, Value token) {
391 builder.setInsertionPoint(call);
392 auto new_operands = llvm::to_vector<4>(call.getArgOperands());
393 new_operands.push_back(token);
394 auto new_result_types = llvm::to_vector<4>(call.getResultTypes());
395 new_result_types.push_back(token.getType());
396 auto new_call = builder.create<CallOp>(
397 call.getLoc(), new_result_types, new_symbol ? *new_symbol : call.callee(),
398 new_operands);
399
400 for (auto results : llvm::zip(call.getResults(), new_call.getResults()))
401 std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
402 call.erase();
403 return new_call.getResults().back();
404 }
405
// Helper struct holding the state of which op to visit next. If `op` is a
// control flow op whose regions are being visited, `region_idx` is set to the
// respective region index. `token` is the current token from the last
// communication op or from a control flow op with (transitive) communication
// ops.
struct OpVisitorState {
  // Index of the region of `op` to process next; unset when `op` itself is to
  // be visited.
  Optional<unsigned> region_idx;
  // Token to thread through the next communication op.
  Value token;
  // The op to visit.
  Operation* op;
};
415
416 // Creates a tuple from a sequence of values.
CreateTuple(OpBuilder & builder,Location loc,ArrayRef<Value> operands)417 Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef<Value> operands) {
418 return builder.create<TupleOp>(loc, operands).getResult();
419 }
420
421 // Replaces a value `value` with a new value but the token attached. If `value`
422 // is not a tuple, a new tuple is formed with `token`. If `value` is a tuple,
423 // `value` is extended instead. New tuple values created are cached.
GetValueWithToken(OpBuilder & builder,Value value,Value token,llvm::SmallDenseMap<Value,Value> & rewritten_values)424 Value GetValueWithToken(OpBuilder& builder, Value value, Value token,
425 llvm::SmallDenseMap<Value, Value>& rewritten_values) {
426 // If value with token already exists, reuse it.
427 auto it = rewritten_values.find(value);
428 if (it != rewritten_values.end()) return it->getSecond();
429
430 auto create_tuple = [&](ArrayRef<Value> operands) {
431 auto new_result = CreateTuple(builder, value.getLoc(), operands);
432 rewritten_values.insert({value, new_result});
433 return new_result;
434 };
435
436 auto tuple_type = value.getType().dyn_cast<TupleType>();
437 // `value` is not a tuple, create a new tuple.
438 if (!tuple_type) return create_tuple({value, token});
439
440 // Extend tuple if `value` is a tuple.
441 // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack
442 // the tuple.
443 if (auto tuple_op = value.getDefiningOp<TupleOp>()) {
444 auto tuple_operands = llvm::to_vector<4>(tuple_op.getOperands());
445 tuple_operands.push_back(token);
446 return create_tuple(tuple_operands);
447 }
448
449 // `value` is not created via a `mhlo.tuple` directly, unpack individual
450 // elements directly with `mhlo.get_tuple_element`.
451 SmallVector<Value, 4> tuple_operands;
452 for (auto idx : llvm::seq<int32_t>(0, tuple_type.getTypes().size()))
453 tuple_operands.push_back(
454 builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
455 .getResult());
456
457 tuple_operands.push_back(token);
458 return create_tuple(tuple_operands);
459 }
460
461 // Extends a type to include a `mhlo.token` type. If `type` is not a tuple type,
462 // a new tuple type with `type` and `mhlo.token` type is created instead.
GetTypeWithToken(OpBuilder & builder,Type type)463 TupleType GetTypeWithToken(OpBuilder& builder, Type type) {
464 auto token_type = TokenType::get(builder.getContext());
465 if (auto tuple_type = type.dyn_cast<TupleType>()) {
466 auto result_types = llvm::to_vector<4>(tuple_type.getTypes());
467 result_types.push_back(token_type);
468 return builder.getTupleType(result_types);
469 }
470
471 return builder.getTupleType({type, token_type});
472 }
473
474 // Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0
475 // to `end`, exclusive.
CreateSubTuple(OpBuilder & builder,Value value,size_t end)476 Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) {
477 SmallVector<Value, 4> tuple_operands;
478 for (auto idx : llvm::seq<int32_t>(0, end))
479 tuple_operands.push_back(
480 builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
481 .getResult());
482
483 return CreateTuple(builder, value.getLoc(), tuple_operands);
484 }
485
486 // Replaces uses of `value` with `replacement`. If `value` is not a tuple type,
487 // an explicit `mhlo.get_tuple_element` is created to unpack the tuple and
488 // return the first element. Otherwise, `mhlo.get_tuple_element` users are
489 // simply updated with `replacement`, and all other users are updated with a
490 // slice of `replacement`.
ReplaceWithTupleResult(OpBuilder & builder,Value value,Value replacement)491 void ReplaceWithTupleResult(OpBuilder& builder, Value value,
492 Value replacement) {
493 auto tuple_type = value.getType().dyn_cast<TupleType>();
494 if (!tuple_type) {
495 if (!value.use_empty()) {
496 auto new_element = builder.create<GetTupleElementOp>(replacement.getLoc(),
497 replacement, 0);
498 value.replaceAllUsesWith(new_element.getResult());
499 }
500 return;
501 }
502
503 Value sub_tuple;
504 for (auto& use : llvm::make_early_inc_range(value.getUses())) {
505 if (isa<GetTupleElementOp>(use.getOwner())) {
506 use.set(replacement);
507 continue;
508 }
509
510 if (!sub_tuple)
511 sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size());
512
513 use.set(sub_tuple);
514 }
515 }
516
517 // Replaces control flow op block single block argument with new block argument
518 // of type `new_type` (tuple type). The last element of the new block argument
519 // (token) is returned.
UpdateControlFlowBlockArgWithToken(OpBuilder & builder,Block & block,Type token_type)520 Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block,
521 Type token_type) {
522 assert(block.getNumArguments() == 1);
523 builder.setInsertionPointToStart(&block);
524 auto new_arg = block.addArgument(token_type);
525 ReplaceWithTupleResult(builder, block.getArgument(0), new_arg);
526 block.eraseArgument(0);
527 return builder
528 .create<GetTupleElementOp>(new_arg.getLoc(), new_arg,
529 token_type.cast<TupleType>().size() - 1)
530 .getResult();
531 }
532
533 // Updates control flow op terminator with an extra element `token`. If the
534 // original return value is not a tuple, a new tuple is formed. Otherwise the
535 // tuple is extended.
RewriteControlFlowTerminator(OpBuilder & builder,Operation * terminator,Value token)536 void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator,
537 Value token) {
538 assert(terminator->getNumOperands() == 1);
539 assert(terminator->getBlock()->getNumArguments() == 1);
540 // `mhlo.while` cond terminator does not need to be rewritten as it always
541 // returns a tensor<i1> predicate value.
542 if (auto while_parent = dyn_cast_or_null<WhileOp>(terminator->getParentOp()))
543 if (terminator->getParentRegion() == &while_parent.cond()) return;
544
545 builder.setInsertionPoint(terminator);
546 llvm::SmallDenseMap<Value, Value> rewritten_operands;
547 Value new_result = GetValueWithToken(builder, terminator->getOperand(0),
548 token, rewritten_operands);
549 terminator->setOperand(0, new_result);
550 }
551
552 // Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. Operands to
553 // the op for all of its regions are extended to have an extra operand `token`.
RewriteRegionIfOp(OpBuilder & builder,IfOp region_if,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)554 void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if,
555 SmallVectorImpl<OpVisitorState>& ops_to_visit,
556 Value token) {
557 llvm::SmallDenseMap<Value, Value> rewritten_operands;
558
559 // Rewrite all region operands to have an extra operand `token`.
560 Value new_true_operand = GetValueWithToken(builder, region_if.true_arg(),
561 token, rewritten_operands);
562 Value new_false_operand = GetValueWithToken(builder, region_if.false_arg(),
563 token, rewritten_operands);
564
565 auto new_result_type = GetTypeWithToken(builder, region_if.getType());
566
567 // Create new `mhlo.if` op with extra token operands and result.
568 auto new_if = builder.create<IfOp>(region_if.getLoc(), new_result_type,
569 region_if.pred(), new_true_operand,
570 new_false_operand);
571
572 // Move all regions from the old `mhlo.if` op to its replacement.
573 new_if.true_branch().takeBody(region_if.true_branch());
574 new_if.false_branch().takeBody(region_if.false_branch());
575
576 // Forward result from old `mhlo.if` with replacement, and unpack result when
577 // necessary.
578 ReplaceWithTupleResult(builder, region_if.getResult(), new_if.getResult());
579
580 auto new_token = builder.create<GetTupleElementOp>(
581 new_if.getLoc(), new_if.getResult(),
582 new_if.getResult().getType().cast<TupleType>().size() - 1);
583
584 region_if.erase();
585
586 // Remove leftover operands to old `mhlo.if` if they have no uses.
587 for (auto& rewritten_operand : rewritten_operands)
588 if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
589 if (tuple_op.use_empty()) tuple_op.erase();
590
591 // Next op to visit. The replacement is visited but at its first region. The
592 // token result of the new region if is propagated.
593 ops_to_visit.push_back({/*region_idx=*/0, new_token, new_if});
594 }
595
596 // Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a
597 // `mhlo.token`. The block argument is updated to have an extra `mhlo.token`
598 // element. If the region block is to be rewritten, the next op to visit is set
599 // to the first op in the block. Otherwise the terminator is updated to forward
600 // `token`.
RewriteControlFlowOpRegion(OpBuilder & builder,Operation * region_op,unsigned region_idx,Type block_arg_type,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)601 void RewriteControlFlowOpRegion(
602 OpBuilder& builder, Operation* region_op, unsigned region_idx,
603 Type block_arg_type, SmallVectorImpl<OpVisitorState>& ops_to_visit,
604 const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
605 ops_to_visit.push_back({region_idx + 1, token, region_op});
606
607 Region& region = region_op->getRegion(region_idx);
608 assert(llvm::hasSingleElement(region));
609
610 auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(),
611 block_arg_type);
612
613 if (control_flow_blocks.contains(®ion.front())) {
614 ops_to_visit.push_back({/*region_idx=*/llvm::None, block_token,
615 block_token.getDefiningOp()->getNextNode()});
616 return;
617 }
618
619 RewriteControlFlowTerminator(builder, region.front().getTerminator(),
620 block_token);
621 }
622
623 // Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op
624 // operands and results are rewritten. If `region_idx` is set, region
625 // `region_idx` is rewritten to take in and return an additional token. Returns
626 // true if the op or its region was rewritten.
ProcessRegionIfOp(OpBuilder & builder,IfOp region_if,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)627 bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if,
628 Optional<unsigned> region_idx,
629 SmallVectorImpl<OpVisitorState>& ops_to_visit,
630 const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
631 Value token) {
632 builder.setInsertionPoint(region_if);
633
634 if (!region_idx) {
635 RewriteRegionIfOp(builder, region_if, ops_to_visit, token);
636 return true;
637 }
638
639 if (*region_idx < region_if.getNumRegions()) {
640 RewriteControlFlowOpRegion(builder, region_if, *region_idx,
641 region_if.getOperand(*region_idx + 1).getType(),
642 ops_to_visit, control_flow_blocks, token);
643 return true;
644 }
645
646 return false;
647 }
648
649 // Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to
650 // the op for all of its regions are extended to have an extra operand `token`.
RewriteRegionWhileOp(OpBuilder & builder,WhileOp region_while,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)651 void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while,
652 SmallVectorImpl<OpVisitorState>& ops_to_visit,
653 Value token) {
654 llvm::SmallDenseMap<Value, Value> rewritten_operands;
655
656 // Rewrite region operand to have an extra operand `token`.
657 // TODO(jpienaar): Support multi-operand while op.
658 Value new_val_operand = GetValueWithToken(builder, region_while.arg()[0],
659 token, rewritten_operands);
660
661 auto new_result_type = GetTypeWithToken(builder, region_while.getType(0));
662
663 // Create new `mhlo.while` op with extra token operand and result.
664 auto new_while = builder.create<WhileOp>(region_while.getLoc(),
665 new_result_type, new_val_operand);
666
667 // Move all regions from the old `mhlo.while` op to its replacement.
668 new_while.cond().takeBody(region_while.cond());
669 new_while.body().takeBody(region_while.body());
670
671 // Forward result from old `mhlo.while` with replacement, and unpack result
672 // when necessary.
673 // TODO(jpienaar): Support multi-operand while op.
674 ReplaceWithTupleResult(builder, region_while.getResult(0),
675 new_while.getResult(0));
676
677 auto new_token = builder.create<GetTupleElementOp>(
678 new_while.getLoc(), new_while.getResult(0),
679 new_while.getType(0).cast<TupleType>().size() - 1);
680
681 region_while.erase();
682
683 // Remove leftover operands to old `mhlo.while` if they have no uses.
684 for (auto& rewritten_operand : rewritten_operands)
685 if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
686 if (tuple_op.use_empty()) tuple_op.erase();
687
688 // Next op to visit. The replacement is visited but at its first region. The
689 // token result of the new region if is propagated.
690 ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while});
691 }
692
693 // Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op
694 // operands and results are rewritten. If `region_idx` is set, region
695 // `region_idx` is rewritten to take in and return an additional token. Returns
696 // true if the op or its region was rewritten.
ProcessRegionWhileOp(OpBuilder & builder,WhileOp region_while,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)697 bool ProcessRegionWhileOp(
698 OpBuilder& builder, WhileOp region_while, Optional<unsigned> region_idx,
699 SmallVectorImpl<OpVisitorState>& ops_to_visit,
700 const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
701 builder.setInsertionPoint(region_while);
702
703 if (!region_idx) {
704 RewriteRegionWhileOp(builder, region_while, ops_to_visit, token);
705 return true;
706 }
707
708 if (*region_idx < region_while.getNumRegions()) {
709 // TODO(jpienaar): Support multi-operand while op.
710 RewriteControlFlowOpRegion(builder, region_while, *region_idx,
711 region_while.arg()[0].getType(), ops_to_visit,
712 control_flow_blocks, token);
713 return true;
714 }
715
716 return false;
717 }
718
719 // Updates function type based on current function body block arguments and
720 // terminator operand types.
UpdateFunctionType(OpBuilder & builder,FuncOp func,Block & func_body)721 void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) {
722 auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
723 auto new_result_types =
724 llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
725 func.setType(FunctionType::get(builder.getContext(), new_argument_types,
726 new_result_types));
727 }
728
729 // Replaces a function terminator `return` with another `return` that has an
730 // extra `mhlo.token` operand.
RewriteFunctionTerminator(OpBuilder & builder,mlir::ReturnOp terminator,Value token)731 void RewriteFunctionTerminator(OpBuilder& builder, mlir::ReturnOp terminator,
732 Value token) {
733 auto new_results = llvm::to_vector<4>(terminator.getOperands());
734 new_results.push_back(token);
735 builder.setInsertionPoint(terminator);
736 builder.create<mlir::ReturnOp>(terminator.getLoc(), new_results);
737 terminator.erase();
738 }
739
740 // Rewrites a function body and communication ops inside. Region control flow
741 // are updated when necessary, to propagate tokens. The function may either be
742 // rewritten to create a token or take in and return a token, depending on its
743 // visibility and if there are any callers.
RewriteFunction(OpBuilder & builder,int64_t & channel_id,ModuleOp module,FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs,const llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool is_clone)744 LogicalResult RewriteFunction(
745 OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func,
746 const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs,
747 const llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
748 const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, bool is_clone) {
749 MLIRContext* context = module.getContext();
750 if (!llvm::hasSingleElement(func.getBody()))
751 return func.emitError()
752 << "'" << FuncOp::getOperationName()
753 << "' ops with more than one block are not supported";
754
755 bool rewrite_block =
756 is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module));
757 Block& func_body = func.front();
758
759 builder.setInsertionPointToStart(&func_body);
760 auto token_type = TokenType::get(context);
761 // If a function is public, it's signature should not be modified, and instead
762 // a token will be created. Otherwise a token block argument is inserted.
763 Value init_token =
764 rewrite_block ? func_body.addArgument(token_type)
765 : builder.create<CreateTokenOp>(func.getLoc(), token_type)
766 .getResult();
767
768 // Stack to keep track of region based control flow op nesting and current
769 // op to visit.
770 SmallVector<OpVisitorState, 4> ops_to_visit{
771 {/*region_idx=*/llvm::None, init_token, &func_body.front()}};
772
773 while (!ops_to_visit.empty()) {
774 OpVisitorState op_to_visit = ops_to_visit.pop_back_val();
775 Operation* curr_op = op_to_visit.op;
776
777 Value token = op_to_visit.token;
778 // Ops may be removed, so the next op is kept track of beforehand.
779 Operation* next_op = curr_op->getNextNode();
780
781 if (auto host_compute = dyn_cast<TF::_XlaHostComputeMlirOp>(curr_op)) {
782 token = RewriteHostComputeOp(builder, channel_id, host_compute, token);
783 } else if (auto send_to_host = dyn_cast<TF::XlaSendToHostOp>(curr_op)) {
784 token = RewriteSendToHostOp(builder, channel_id, send_to_host, token);
785 } else if (auto recv_from_host = dyn_cast<TF::XlaRecvFromHostOp>(curr_op)) {
786 token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token);
787 } else if (auto call = dyn_cast<mlir::CallOp>(curr_op)) {
788 // Only `mlir::CallOp` is supported as this requires knowing how to
789 // rewrite arguments and results to a function.
790 auto it = funcs.find(call.getCallee());
791 if (it != funcs.end()) {
792 FuncOp clone = it->getSecond().clone;
793 Optional<StringRef> symbol_name =
794 clone ? Optional<StringRef>(clone.getName()) : llvm::None;
795 // If the function being called is to be cloned, update the call to also
796 // point to the cloned function.
797 token = RewriteCallOp(builder, call, symbol_name, token);
798 }
799 } else if (auto region_if = dyn_cast<IfOp>(curr_op)) {
800 if (op_to_visit.region_idx || control_flow_ops.contains(region_if))
801 if (ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx,
802 ops_to_visit, control_flow_blocks, token))
803 continue;
804 } else if (auto region_while = dyn_cast<WhileOp>(curr_op)) {
805 if (op_to_visit.region_idx || control_flow_ops.contains(region_while))
806 if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx,
807 ops_to_visit, control_flow_blocks, token))
808 continue;
809 } else if (auto region_terminator = dyn_cast<mhlo::ReturnOp>(curr_op)) {
810 RewriteControlFlowTerminator(builder, region_terminator, token);
811 // There is no next op afer the control flow op terminator, simply let
812 // stack have one less element.
813 continue;
814 } else if (auto func_terminator = dyn_cast<mlir::ReturnOp>(curr_op)) {
815 if (rewrite_block)
816 RewriteFunctionTerminator(builder, func_terminator, token);
817
818 // There is no next op afer the function terminator, simply let stack have
819 // one less element/be empty.
820 continue;
821 }
822
823 // Visit next op.
824 ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op});
825 }
826
827 if (rewrite_block) UpdateFunctionType(builder, func, func_body);
828
829 return success();
830 }
831
832 // Checks if a function call is pointing to a function with communication ops.
IsFunctionCallWithCommunication(Operation * op,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)833 bool IsFunctionCallWithCommunication(
834 Operation* op,
835 const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
836 if (auto call = dyn_cast<mlir::CallOp>(op))
837 return funcs_to_rewrite.count(call.callee());
838
839 return false;
840 }
841
842 // Collects all control flow op ancestors of communication ops or function calls
843 // with communication ops (transitively).
GetCommunicationControlFlowOps(FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)844 void GetCommunicationControlFlowOps(
845 FuncOp func,
846 const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite,
847 llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
848 llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
849 func.walk([&](Operation* op) {
850 if (IsCommunicationOp(op) ||
851 IsFunctionCallWithCommunication(op, funcs_to_rewrite))
852 if (failed(GetControlFlowAncestors(op, control_flow_ops,
853 control_flow_blocks)))
854 llvm_unreachable(
855 "checking original function for control flow ancestors should have "
856 "errored first");
857 });
858 }
859
runOnOperation()860 void LegalizeTFCommunication::runOnOperation() {
861 auto module = getOperation();
862 llvm::SmallDenseMap<StringRef, FuncToRewrite> funcs_to_rewrite;
863 if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite)))
864 return signalPassFailure();
865
866 // Module level counter to make sure Channel Id's are unique.
867 int64_t channel_id = 1;
868 OpBuilder builder(&getContext());
869 for (const auto& func_and_name : funcs_to_rewrite) {
870 const auto& func_to_rewrite = func_and_name.getSecond();
871 FuncOp func = func_to_rewrite.original;
872 if (failed(RewriteFunction(builder, channel_id, module, func,
873 funcs_to_rewrite,
874 func_to_rewrite.control_flow_ops,
875 func_to_rewrite.control_flow_blocks,
876 /*is_clone=*/false)))
877 return signalPassFailure();
878
879 FuncOp clone = func_and_name.getSecond().clone;
880 if (!clone) continue;
881 llvm::SmallPtrSet<Operation*, 4> clone_control_flow_ops;
882 llvm::SmallPtrSet<Block*, 4> clone_control_flow_blocks;
883 GetCommunicationControlFlowOps(clone, funcs_to_rewrite,
884 clone_control_flow_ops,
885 clone_control_flow_blocks);
886 if (failed(RewriteFunction(builder, channel_id, module, clone,
887 funcs_to_rewrite, clone_control_flow_ops,
888 clone_control_flow_blocks,
889 /*is_clone=*/true)))
890 llvm_unreachable(
891 "rewriting of original function should have errored first");
892 }
893 }
894
// File-local registration object for `LegalizeTFCommunication`.
// NOTE(review): presumably constructing this registers the pass with MLIR's
// pass registry during static initialization — confirm against
// `PassRegistration`.
static PassRegistration<LegalizeTFCommunication> pass;
896 } // namespace
897
CreateLegalizeTFCommunicationPass()898 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCommunicationPass() {
899 return std::make_unique<LegalizeTFCommunication>();
900 }
901
902 } // namespace mhlo
903 } // namespace mlir
904