• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // This file implements logic for lowering TensorFlow dialect's communication
17 // ops (TF/XLA) to the HLO dialect.
18 
19 #include <memory>
20 #include <string>
21 
22 #include "llvm/ADT/DenseMap.h"
23 #include "llvm/ADT/None.h"
24 #include "llvm/ADT/STLExtras.h"
25 #include "llvm/ADT/Sequence.h"
26 #include "llvm/Support/ErrorHandling.h"
27 #include "llvm/Support/FormatVariadic.h"
28 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
34 #include "mlir/IR/Value.h"  // from @llvm-project
35 #include "mlir/IR/Visitors.h"  // from @llvm-project
36 #include "mlir/Pass/Pass.h"  // from @llvm-project
37 #include "mlir/Support/LLVM.h"  // from @llvm-project
38 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
39 #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
40 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
41 #include "tensorflow/compiler/mlir/xla/type_to_shape.h"
42 #include "tensorflow/compiler/xla/client/sharding_builder.h"
43 #include "tensorflow/compiler/xla/primitive_util.h"
44 
45 namespace mlir {
46 namespace mhlo {
47 
48 namespace {
// Name of the op attribute holding the serialized XLA op sharding.
constexpr char kShardingAttr[] = "mhlo.sharding";
// Name of the op attribute holding the dictionary of frontend attributes.
constexpr char kFrontendAttributesAttr[] = "mhlo.frontend_attributes";
// Keys of the entries stored inside the frontend attributes dictionary.
constexpr char kXlaHostTransferRendezvousNameAttr[] =
    "_xla_host_transfer_rendezvous";
constexpr char kXlaHostTransferOriginalTypeAttr[] =
    "_xla_host_transfer_original_type";
55 
56 // A pass that legalizes TF/XLA communication ops, propagate their respective
57 // tokens (for ordering), and rewrite their respective functions and control
58 // flow ops when necessary.
59 // Note, this currently does not handle nested modules/functions or region based
60 // ops other than certain control flow ops (`mhlo.if`, `mhlo.while`).
61 class LegalizeTFCommunication
62     : public PassWrapper<LegalizeTFCommunication, OperationPass<ModuleOp>> {
getDependentDialects(DialectRegistry & registry) const63   void getDependentDialects(DialectRegistry& registry) const override {
64     registry.insert<mhlo::MhloDialect>();
65   }
66 
67  public:
getArgument() const68   StringRef getArgument() const final {
69     return "xla-legalize-tf-communication";
70   }
getDescription() const71   StringRef getDescription() const final {
72     return "Legalize TF/XLA communication ops (TensorFlow dialect) to the HLO "
73            "dialect";
74   }
75   void runOnOperation() override;
76 };
77 
78 // Checks if an op is a TF/XLA communication op.
IsCommunicationOp(Operation * op)79 bool IsCommunicationOp(Operation* op) {
80   return isa<TF::_XlaHostComputeMlirOp, TF::XlaSendToHostOp,
81              TF::XlaRecvFromHostOp>(op);
82 }
83 
84 // Checks if an op is a supported HLO control flow op.
IsControlFlowOp(Operation * op)85 bool IsControlFlowOp(Operation* op) { return isa<IfOp, WhileOp>(op); }
86 
87 // Collects control flow op ancestors of a given op, up until FuncOp. If any
88 // ancestor is not a control flow op or a FuncOp, or of a single block region,
89 // an error will be returned.
GetControlFlowAncestors(Operation * op,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)90 LogicalResult GetControlFlowAncestors(
91     Operation* op, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
92     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
93   Block* block = op->getBlock();
94   Operation* parent = block->getParentOp();
95   while (block && parent && !isa<FuncOp>(parent)) {
96     if (!IsControlFlowOp(parent))
97       return op->emitOpError()
98              << "expects ancestor(s) to be of ['" << IfOp::getOperationName()
99              << "', '" << FuncOp::getOperationName() << "']";
100 
101     if (!llvm::hasSingleElement(block->getParent()->getBlocks()))
102       return op->emitOpError() << "expects single block region ancestor(s)";
103 
104     control_flow_ops.insert(parent);
105     control_flow_blocks.insert(block);
106 
107     parent = block->getParentOp();
108     block = parent->getBlock();
109   }
110   return success();
111 }
112 
113 // Finds communication ops in a function. `control_flow_ops` and
114 // `control_flow_blocks` will be populated with control flow op ancestors for
115 // every communication op.
FindCommunicationOps(FuncOp func,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool & has_communication_ops)116 LogicalResult FindCommunicationOps(
117     FuncOp func, llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
118     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
119     bool& has_communication_ops) {
120   auto result = func.walk([&](Operation* op) {
121     if (!IsCommunicationOp(op)) return WalkResult::advance();
122     has_communication_ops = true;
123     if (failed(
124             GetControlFlowAncestors(op, control_flow_ops, control_flow_blocks)))
125       return WalkResult::interrupt();
126     return WalkResult::advance();
127   });
128   return failure(result.wasInterrupted());
129 }
130 
131 // Helper struct holding a function to be rewritten, it's control flow ops that
132 // lead to a communication op or function call with a communication op
133 // (transitively), and an optional clone of itself. If `clone` is set, function
134 // calls to `original` will be replaced with `clone`.
135 struct FuncToRewrite {
136   FuncOp original;
137   llvm::SmallPtrSet<Operation*, 4> control_flow_ops;
138   llvm::SmallPtrSet<Block*, 4> control_flow_blocks;
139   FuncOp clone;
140 };
141 
142 // Finds all functions that need to be rewritten with communication ops and
143 // and associated tokens.
GetFunctionsToRewrite(ModuleOp module,llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)144 LogicalResult GetFunctionsToRewrite(
145     ModuleOp module,
146     llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
147   // Find functions containing communication ops.
148   SmallVector<FuncOp, 4> funcs_to_visit;
149   for (FuncOp func : module.getOps<FuncOp>()) {
150     FuncToRewrite func_to_rewrite{/*original=*/func, /*control_flow_ops=*/{},
151                                   /*control_flow_blocks=*/{},
152                                   /*clone=*/nullptr};
153     bool has_communication_ops = false;
154     if (failed(FindCommunicationOps(func, func_to_rewrite.control_flow_ops,
155                                     func_to_rewrite.control_flow_blocks,
156                                     has_communication_ops)))
157       return failure();
158 
159     if (!has_communication_ops) continue;
160     funcs_to_rewrite.insert({func.getName(), func_to_rewrite});
161     funcs_to_visit.push_back(func);
162   }
163 
164   // Find functions that call functions with communication ops, transitively.
165   while (!funcs_to_visit.empty()) {
166     SmallVector<FuncOp, 4> new_funcs_to_visit;
167     for (FuncOp& func : funcs_to_visit) {
168       auto uses = func.getSymbolUses(module);
169       if (!uses) continue;
170       for (auto& use : *uses) {
171         // Only `mlir::CallOp` is supported as this requires knowing how to
172         // rewrite arguments and results to a function.
173         if (!isa<mlir::CallOp>(use.getUser())) continue;
174         auto caller_parent_func = use.getUser()->getParentOfType<FuncOp>();
175         if (!caller_parent_func) continue;
176 
177         FuncToRewrite func_to_rewrite{/*original=*/caller_parent_func,
178                                       /*control_flow_ops=*/{},
179                                       /*control_flow_blocks=*/{},
180                                       /*clone=*/nullptr};
181         if (failed(GetControlFlowAncestors(
182                 use.getUser(), func_to_rewrite.control_flow_ops,
183                 func_to_rewrite.control_flow_blocks)))
184           return failure();
185 
186         auto it = funcs_to_rewrite.insert(
187             {caller_parent_func.getName(), func_to_rewrite});
188         if (it.second) {
189           new_funcs_to_visit.push_back(caller_parent_func);
190         } else {
191           it.first->getSecond().control_flow_ops.insert(
192               func_to_rewrite.control_flow_ops.begin(),
193               func_to_rewrite.control_flow_ops.end());
194           it.first->getSecond().control_flow_blocks.insert(
195               func_to_rewrite.control_flow_blocks.begin(),
196               func_to_rewrite.control_flow_blocks.end());
197         }
198       }
199     }
200 
201     funcs_to_visit.swap(new_funcs_to_visit);
202   }
203 
204   // Clone public functions that need to be rewritten. Function calls to this
205   // function will be replaced with the cloned function.
206   SymbolTable symbol_table(module);
207   for (auto& func : funcs_to_rewrite) {
208     if (func.getSecond().original.isPublic() &&
209         !func.getSecond().original.symbolKnownUseEmpty(module)) {
210       auto clone = func.getSecond().original.clone();
211       clone.setPrivate();
212       symbol_table.insert(clone);
213       func.getSecond().clone = clone;
214     }
215   }
216 
217   return success();
218 }
219 
220 // Assigns op sharding to an op for a given device core.
SetOpSharding(Operation * op,int64_t tpu_core)221 void SetOpSharding(Operation* op, int64_t tpu_core) {
222   std::string sharding_serialized =
223       ::xla::sharding_builder::AssignDevice(tpu_core).SerializeAsString();
224   op->setAttr(kShardingAttr,
225               StringAttr::get(op->getContext(), sharding_serialized));
226 }
227 
228 // Assigns frontend attributes holding information about data type and
229 // TensorFlow rendezvous channel name. The TensorFlow rendezvous channel name is
230 // handled differently as individual names are used per data send and receive.
SetFrontendAttributes(Operation * op,int32_t index,StringRef key,Type type,bool device_to_host)231 void SetFrontendAttributes(Operation* op, int32_t index, StringRef key,
232                            Type type, bool device_to_host) {
233   MLIRContext* context = op->getContext();
234 
235   std::string formatted_key =
236       device_to_host ? llvm::formatv("{0}_dtoh_{1}", key, index).str()
237                      : llvm::formatv("{0}_htod_{1}", key, index).str();
238 
239   auto rendezvous_name = StringAttr::get(context, formatted_key);
240   auto rendezvous_name_attr = NamedAttribute(
241       Identifier::get(kXlaHostTransferRendezvousNameAttr, context),
242       rendezvous_name);
243 
244   auto element_type = getElementTypeOrSelf(type);
245   auto xla_element_type = ::xla::TypeToPrimitiveType(element_type);
246   const std::string& xla_element_type_str =
247       ::xla::primitive_util::LowercasePrimitiveTypeName(xla_element_type);
248   auto original_type = StringAttr::get(context, xla_element_type_str);
249   auto original_type_attr =
250       NamedAttribute(Identifier::get(kXlaHostTransferOriginalTypeAttr, context),
251                      original_type);
252 
253   auto frontend_attributes = DictionaryAttr::get(
254       context,
255       ArrayRef<NamedAttribute>{rendezvous_name_attr, original_type_attr});
256   op->setAttr(kFrontendAttributesAttr, frontend_attributes);
257 }
258 
259 // Creates a `mhlo.send` op for sending value `operand`. If `tpu_core` is set,
260 // op sharding for the respective device will be set.
CreateSendOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value operand,StringRef key,size_t index,const Optional<int64_t> & tpu_core,Value token)261 Value CreateSendOp(OpBuilder& builder, int64_t& channel_id, Location loc,
262                    Value operand, StringRef key, size_t index,
263                    const Optional<int64_t>& tpu_core, Value token) {
264   // type 2 == DEVICE_TO_HOST
265   auto channel_handle = ChannelHandle::get(
266       /*handle=*/builder.getI64IntegerAttr(channel_id++),
267       /*type=*/builder.getI64IntegerAttr(2), builder.getContext());
268   auto send = builder.create<SendOp>(
269       loc, token.getType(), operand, token, channel_handle,
270       /*is_host_transfer=*/builder.getBoolAttr(true));
271 
272   SetFrontendAttributes(send, index, key, operand.getType(),
273                         /*device_to_host=*/true);
274 
275   if (tpu_core) SetOpSharding(send, *tpu_core);
276 
277   return send.getResult();
278 }
279 
280 // Creates a `mhlo.recv` op for receiving a value. If `tpu_core` is set, op
281 // sharding for the respective device will be set.
CreateRecvOp(OpBuilder & builder,int64_t & channel_id,Location loc,Value result,StringRef key,size_t index,const Optional<int64_t> & tpu_core,Value token)282 Value CreateRecvOp(OpBuilder& builder, int64_t& channel_id, Location loc,
283                    Value result, StringRef key, size_t index,
284                    const Optional<int64_t>& tpu_core, Value token) {
285   // type 3 == HOST_TO_DEVICE
286   auto channel_handle = ChannelHandle::get(
287       /*handle=*/builder.getI64IntegerAttr(channel_id++),
288       /*type=*/builder.getI64IntegerAttr(3), builder.getContext());
289   auto result_type = result.getType();
290   auto recv_result_type =
291       TupleType::get(builder.getContext(), {result_type, token.getType()});
292   auto recv =
293       builder.create<RecvOp>(loc, recv_result_type, token, channel_handle,
294                              /*is_host_transfer=*/builder.getBoolAttr(true));
295 
296   SetFrontendAttributes(recv, index, key, result_type,
297                         /*device_to_host=*/false);
298 
299   if (tpu_core) SetOpSharding(recv, *tpu_core);
300 
301   auto get_tuple_element =
302       builder.create<GetTupleElementOp>(loc, recv.getResult(), /*index=*/0);
303   if (tpu_core) SetOpSharding(get_tuple_element, *tpu_core);
304 
305   result.replaceAllUsesWith(get_tuple_element);
306 
307   auto new_token = builder.create<GetTupleElementOp>(loc, recv.getResult(),
308                                                      /*index=*/1);
309   if (tpu_core) SetOpSharding(new_token, *tpu_core);
310 
311   return new_token.getResult();
312 }
313 
314 // Creates a new token if necessary, acting as a sink to previous tokens. If
315 // there is only one token in `tokens`, the only token is returned. If `tokens`
316 // is empty, `original_token` is returned instead.
CreateSinkToken(OpBuilder & builder,Location loc,ArrayRef<Value> tokens,Value original_token)317 Value CreateSinkToken(OpBuilder& builder, Location loc, ArrayRef<Value> tokens,
318                       Value original_token) {
319   if (tokens.empty()) {
320     return original_token;
321   } else if (llvm::hasSingleElement(tokens)) {
322     return tokens[0];
323   } else {
324     return builder.create<AfterAllOp>(loc, original_token.getType(), tokens)
325         .getResult();
326   }
327 }
328 
329 // Replaces `tf._XlaHostComputeMlir` with individual `mhlo.send` and `mhlo.recv`
330 // ops per operand and result. Unique Channel Id's are assigned per transfer.
331 // Sink tokens are created across all `mhlo.send` ops first and then by
332 // all `mhlo.recv` ops.
RewriteHostComputeOp(OpBuilder & builder,int64_t & channel_id,TF::_XlaHostComputeMlirOp host_compute,Value token)333 Value RewriteHostComputeOp(OpBuilder& builder, int64_t& channel_id,
334                            TF::_XlaHostComputeMlirOp host_compute,
335                            Value token) {
336   builder.setInsertionPoint(host_compute);
337   Location loc = host_compute.getLoc();
338   int64_t tpu_core = host_compute.tpu_coreAttr().getInt();
339 
340   SmallVector<Value, 4> send_tokens;
341   for (auto operand : llvm::enumerate(host_compute.inputs())) {
342     auto send_token =
343         CreateSendOp(builder, channel_id, loc, operand.value(),
344                      host_compute.send_key(), operand.index(), tpu_core, token);
345     send_tokens.push_back(send_token);
346   }
347   token = CreateSinkToken(builder, loc, send_tokens, token);
348 
349   SmallVector<Value, 4> recv_tokens;
350   for (auto result : llvm::enumerate(host_compute.outputs())) {
351     auto recv_token =
352         CreateRecvOp(builder, channel_id, loc, result.value(),
353                      host_compute.recv_key(), result.index(), tpu_core, token);
354     recv_tokens.push_back(recv_token);
355   }
356   token = CreateSinkToken(builder, loc, recv_tokens, token);
357 
358   host_compute.erase();
359   return token;
360 }
361 
362 // Replaces `tf.XlaSendToHost` with a `mhlo.send`.
RewriteSendToHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaSendToHostOp send_to_host,Value token)363 Value RewriteSendToHostOp(OpBuilder& builder, int64_t& channel_id,
364                           TF::XlaSendToHostOp send_to_host, Value token) {
365   builder.setInsertionPoint(send_to_host);
366   token = CreateSendOp(builder, channel_id, send_to_host.getLoc(),
367                        send_to_host.input(), send_to_host.key(),
368                        /*index=*/0, /*tpu_core=*/llvm::None, token);
369 
370   send_to_host.erase();
371   return token;
372 }
373 
374 // Replaces `tf.XlaRecvFromHost` with a `mhlo.recv`.
RewriteRecvFromHostOp(OpBuilder & builder,int64_t & channel_id,TF::XlaRecvFromHostOp recv_from_host,Value token)375 Value RewriteRecvFromHostOp(OpBuilder& builder, int64_t& channel_id,
376                             TF::XlaRecvFromHostOp recv_from_host, Value token) {
377   builder.setInsertionPoint(recv_from_host);
378   token = CreateRecvOp(builder, channel_id, recv_from_host.getLoc(),
379                        recv_from_host.output(), recv_from_host.key(),
380                        /*index=*/0, /*tpu_core=*/llvm::None, token);
381 
382   recv_from_host.erase();
383   return token;
384 }
385 
386 // Replaces a `mlir::CallOp` with one that has an extra `!mhlo.token` operand
387 // and `!mhlo.token` result. If `new_symbol` is set, the new call will be
388 // updated to call the `new_symbol` instead.
RewriteCallOp(OpBuilder & builder,CallOp call,const Optional<StringRef> & new_symbol,Value token)389 Value RewriteCallOp(OpBuilder& builder, CallOp call,
390                     const Optional<StringRef>& new_symbol, Value token) {
391   builder.setInsertionPoint(call);
392   auto new_operands = llvm::to_vector<4>(call.getArgOperands());
393   new_operands.push_back(token);
394   auto new_result_types = llvm::to_vector<4>(call.getResultTypes());
395   new_result_types.push_back(token.getType());
396   auto new_call = builder.create<CallOp>(
397       call.getLoc(), new_result_types, new_symbol ? *new_symbol : call.callee(),
398       new_operands);
399 
400   for (auto results : llvm::zip(call.getResults(), new_call.getResults()))
401     std::get<0>(results).replaceAllUsesWith(std::get<1>(results));
402   call.erase();
403   return new_call.getResults().back();
404 }
405 
406 // Helper struct holding state of which op to visit to next. If `op` is in a
407 // control flow op region, `region_idx` will be set with the respective region
408 // index. `token` will be current token from the last communication op/control
409 // flow op transitive communication ops.
410 struct OpVisitorState {
411   Optional<unsigned> region_idx;
412   Value token;
413   Operation* op;
414 };
415 
416 // Creates a tuple from a sequence of values.
CreateTuple(OpBuilder & builder,Location loc,ArrayRef<Value> operands)417 Value CreateTuple(OpBuilder& builder, Location loc, ArrayRef<Value> operands) {
418   return builder.create<TupleOp>(loc, operands).getResult();
419 }
420 
421 // Replaces a value `value` with a new value but the token attached. If `value`
422 // is not a tuple, a new tuple is formed with `token`. If `value` is a tuple,
423 // `value` is extended instead. New tuple values created are cached.
GetValueWithToken(OpBuilder & builder,Value value,Value token,llvm::SmallDenseMap<Value,Value> & rewritten_values)424 Value GetValueWithToken(OpBuilder& builder, Value value, Value token,
425                         llvm::SmallDenseMap<Value, Value>& rewritten_values) {
426   // If value with token already exists, reuse it.
427   auto it = rewritten_values.find(value);
428   if (it != rewritten_values.end()) return it->getSecond();
429 
430   auto create_tuple = [&](ArrayRef<Value> operands) {
431     auto new_result = CreateTuple(builder, value.getLoc(), operands);
432     rewritten_values.insert({value, new_result});
433     return new_result;
434   };
435 
436   auto tuple_type = value.getType().dyn_cast<TupleType>();
437   // `value` is not a tuple, create a new tuple.
438   if (!tuple_type) return create_tuple({value, token});
439 
440   // Extend tuple if `value` is a tuple.
441   // If `value` is an op result and the owner is a `mhlo.tuple`, simply unpack
442   // the tuple.
443   if (auto tuple_op = value.getDefiningOp<TupleOp>()) {
444     auto tuple_operands = llvm::to_vector<4>(tuple_op.getOperands());
445     tuple_operands.push_back(token);
446     return create_tuple(tuple_operands);
447   }
448 
449   // `value` is not created via a `mhlo.tuple` directly, unpack individual
450   // elements directly with `mhlo.get_tuple_element`.
451   SmallVector<Value, 4> tuple_operands;
452   for (auto idx : llvm::seq<int32_t>(0, tuple_type.getTypes().size()))
453     tuple_operands.push_back(
454         builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
455             .getResult());
456 
457   tuple_operands.push_back(token);
458   return create_tuple(tuple_operands);
459 }
460 
461 // Extends a type to include a `mhlo.token` type. If `type` is not a tuple type,
462 // a new tuple type with `type` and `mhlo.token` type is created instead.
GetTypeWithToken(OpBuilder & builder,Type type)463 TupleType GetTypeWithToken(OpBuilder& builder, Type type) {
464   auto token_type = TokenType::get(builder.getContext());
465   if (auto tuple_type = type.dyn_cast<TupleType>()) {
466     auto result_types = llvm::to_vector<4>(tuple_type.getTypes());
467     result_types.push_back(token_type);
468     return builder.getTupleType(result_types);
469   }
470 
471   return builder.getTupleType({type, token_type});
472 }
473 
474 // Creates a slice of a tuple `value` with `mhlo.get_tuple_element` from index 0
475 // to `end`, exclusive.
CreateSubTuple(OpBuilder & builder,Value value,size_t end)476 Value CreateSubTuple(OpBuilder& builder, Value value, size_t end) {
477   SmallVector<Value, 4> tuple_operands;
478   for (auto idx : llvm::seq<int32_t>(0, end))
479     tuple_operands.push_back(
480         builder.create<GetTupleElementOp>(value.getLoc(), value, idx)
481             .getResult());
482 
483   return CreateTuple(builder, value.getLoc(), tuple_operands);
484 }
485 
486 // Replaces uses of `value` with `replacement`. If `value` is not a tuple type,
487 // an explicit `mhlo.get_tuple_element` is created to unpack the tuple and
488 // return the first element. Otherwise, `mhlo.get_tuple_element` users are
489 // simply updated with `replacement`, and all other users are updated with a
490 // slice of `replacement`.
ReplaceWithTupleResult(OpBuilder & builder,Value value,Value replacement)491 void ReplaceWithTupleResult(OpBuilder& builder, Value value,
492                             Value replacement) {
493   auto tuple_type = value.getType().dyn_cast<TupleType>();
494   if (!tuple_type) {
495     if (!value.use_empty()) {
496       auto new_element = builder.create<GetTupleElementOp>(replacement.getLoc(),
497                                                            replacement, 0);
498       value.replaceAllUsesWith(new_element.getResult());
499     }
500     return;
501   }
502 
503   Value sub_tuple;
504   for (auto& use : llvm::make_early_inc_range(value.getUses())) {
505     if (isa<GetTupleElementOp>(use.getOwner())) {
506       use.set(replacement);
507       continue;
508     }
509 
510     if (!sub_tuple)
511       sub_tuple = CreateSubTuple(builder, replacement, tuple_type.size());
512 
513     use.set(sub_tuple);
514   }
515 }
516 
517 // Replaces control flow op block single block argument with new block argument
518 // of type `new_type` (tuple type). The last element of the new block argument
519 // (token) is returned.
UpdateControlFlowBlockArgWithToken(OpBuilder & builder,Block & block,Type token_type)520 Value UpdateControlFlowBlockArgWithToken(OpBuilder& builder, Block& block,
521                                          Type token_type) {
522   assert(block.getNumArguments() == 1);
523   builder.setInsertionPointToStart(&block);
524   auto new_arg = block.addArgument(token_type);
525   ReplaceWithTupleResult(builder, block.getArgument(0), new_arg);
526   block.eraseArgument(0);
527   return builder
528       .create<GetTupleElementOp>(new_arg.getLoc(), new_arg,
529                                  token_type.cast<TupleType>().size() - 1)
530       .getResult();
531 }
532 
533 // Updates control flow op terminator with an extra element `token`. If the
534 // original return value is not a tuple, a new tuple is formed. Otherwise the
535 // tuple is extended.
RewriteControlFlowTerminator(OpBuilder & builder,Operation * terminator,Value token)536 void RewriteControlFlowTerminator(OpBuilder& builder, Operation* terminator,
537                                   Value token) {
538   assert(terminator->getNumOperands() == 1);
539   assert(terminator->getBlock()->getNumArguments() == 1);
540   // `mhlo.while` cond terminator does not need to be rewritten as it always
541   // returns a tensor<i1> predicate value.
542   if (auto while_parent = dyn_cast_or_null<WhileOp>(terminator->getParentOp()))
543     if (terminator->getParentRegion() == &while_parent.cond()) return;
544 
545   builder.setInsertionPoint(terminator);
546   llvm::SmallDenseMap<Value, Value> rewritten_operands;
547   Value new_result = GetValueWithToken(builder, terminator->getOperand(0),
548                                        token, rewritten_operands);
549   terminator->setOperand(0, new_result);
550 }
551 
552 // Rewrites a `mhlo.if` op to receive and forward a `mhlo.token`. Operands to
553 // the op for all of its regions are extended to have an extra operand `token`.
RewriteRegionIfOp(OpBuilder & builder,IfOp region_if,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)554 void RewriteRegionIfOp(OpBuilder& builder, IfOp region_if,
555                        SmallVectorImpl<OpVisitorState>& ops_to_visit,
556                        Value token) {
557   llvm::SmallDenseMap<Value, Value> rewritten_operands;
558 
559   // Rewrite all region operands to have an extra operand `token`.
560   Value new_true_operand = GetValueWithToken(builder, region_if.true_arg(),
561                                              token, rewritten_operands);
562   Value new_false_operand = GetValueWithToken(builder, region_if.false_arg(),
563                                               token, rewritten_operands);
564 
565   auto new_result_type = GetTypeWithToken(builder, region_if.getType());
566 
567   // Create new `mhlo.if` op with extra token operands and result.
568   auto new_if = builder.create<IfOp>(region_if.getLoc(), new_result_type,
569                                      region_if.pred(), new_true_operand,
570                                      new_false_operand);
571 
572   // Move all regions from the old `mhlo.if` op to its replacement.
573   new_if.true_branch().takeBody(region_if.true_branch());
574   new_if.false_branch().takeBody(region_if.false_branch());
575 
576   // Forward result from old `mhlo.if` with replacement, and unpack result when
577   // necessary.
578   ReplaceWithTupleResult(builder, region_if.getResult(), new_if.getResult());
579 
580   auto new_token = builder.create<GetTupleElementOp>(
581       new_if.getLoc(), new_if.getResult(),
582       new_if.getResult().getType().cast<TupleType>().size() - 1);
583 
584   region_if.erase();
585 
586   // Remove leftover operands to old `mhlo.if` if they have no uses.
587   for (auto& rewritten_operand : rewritten_operands)
588     if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
589       if (tuple_op.use_empty()) tuple_op.erase();
590 
591   // Next op to visit. The replacement is visited but at its first region. The
592   // token result of the new region if is propagated.
593   ops_to_visit.push_back({/*region_idx=*/0, new_token, new_if});
594 }
595 
596 // Rewrites a `mhlo.if`/`mhlo.while` region to receive and forward a
597 // `mhlo.token`. The block argument is updated to have an extra `mhlo.token`
598 // element. If the region block is to be rewritten, the next op to visit is set
599 // to the first op in the block. Otherwise the terminator is updated to forward
600 // `token`.
RewriteControlFlowOpRegion(OpBuilder & builder,Operation * region_op,unsigned region_idx,Type block_arg_type,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)601 void RewriteControlFlowOpRegion(
602     OpBuilder& builder, Operation* region_op, unsigned region_idx,
603     Type block_arg_type, SmallVectorImpl<OpVisitorState>& ops_to_visit,
604     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
605   ops_to_visit.push_back({region_idx + 1, token, region_op});
606 
607   Region& region = region_op->getRegion(region_idx);
608   assert(llvm::hasSingleElement(region));
609 
610   auto block_token = UpdateControlFlowBlockArgWithToken(builder, region.front(),
611                                                         block_arg_type);
612 
613   if (control_flow_blocks.contains(&region.front())) {
614     ops_to_visit.push_back({/*region_idx=*/llvm::None, block_token,
615                             block_token.getDefiningOp()->getNextNode()});
616     return;
617   }
618 
619   RewriteControlFlowTerminator(builder, region.front().getTerminator(),
620                                block_token);
621 }
622 
623 // Rewrites an `mhlo.if` op or its region. If `region_idx` is not set, the op
624 // operands and results are rewritten. If `region_idx` is set, region
625 // `region_idx` is rewritten to take in and return an additional token. Returns
626 // true if the op or its region was rewritten.
ProcessRegionIfOp(OpBuilder & builder,IfOp region_if,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)627 bool ProcessRegionIfOp(OpBuilder& builder, IfOp region_if,
628                        Optional<unsigned> region_idx,
629                        SmallVectorImpl<OpVisitorState>& ops_to_visit,
630                        const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks,
631                        Value token) {
632   builder.setInsertionPoint(region_if);
633 
634   if (!region_idx) {
635     RewriteRegionIfOp(builder, region_if, ops_to_visit, token);
636     return true;
637   }
638 
639   if (*region_idx < region_if.getNumRegions()) {
640     RewriteControlFlowOpRegion(builder, region_if, *region_idx,
641                                region_if.getOperand(*region_idx + 1).getType(),
642                                ops_to_visit, control_flow_blocks, token);
643     return true;
644   }
645 
646   return false;
647 }
648 
649 // Rewrites a `mhlo.while` op to receive and forward a `mhlo.token`. Operands to
650 // the op for all of its regions are extended to have an extra operand `token`.
RewriteRegionWhileOp(OpBuilder & builder,WhileOp region_while,SmallVectorImpl<OpVisitorState> & ops_to_visit,Value token)651 void RewriteRegionWhileOp(OpBuilder& builder, WhileOp region_while,
652                           SmallVectorImpl<OpVisitorState>& ops_to_visit,
653                           Value token) {
654   llvm::SmallDenseMap<Value, Value> rewritten_operands;
655 
656   // Rewrite region operand to have an extra operand `token`.
657   // TODO(jpienaar): Support multi-operand while op.
658   Value new_val_operand = GetValueWithToken(builder, region_while.arg()[0],
659                                             token, rewritten_operands);
660 
661   auto new_result_type = GetTypeWithToken(builder, region_while.getType(0));
662 
663   // Create new `mhlo.while` op with extra token operand and result.
664   auto new_while = builder.create<WhileOp>(region_while.getLoc(),
665                                            new_result_type, new_val_operand);
666 
667   // Move all regions from the old `mhlo.while` op to its replacement.
668   new_while.cond().takeBody(region_while.cond());
669   new_while.body().takeBody(region_while.body());
670 
671   // Forward result from old `mhlo.while` with replacement, and unpack result
672   // when necessary.
673   // TODO(jpienaar): Support multi-operand while op.
674   ReplaceWithTupleResult(builder, region_while.getResult(0),
675                          new_while.getResult(0));
676 
677   auto new_token = builder.create<GetTupleElementOp>(
678       new_while.getLoc(), new_while.getResult(0),
679       new_while.getType(0).cast<TupleType>().size() - 1);
680 
681   region_while.erase();
682 
683   // Remove leftover operands to old `mhlo.while` if they have no uses.
684   for (auto& rewritten_operand : rewritten_operands)
685     if (auto tuple_op = rewritten_operand.getFirst().getDefiningOp<TupleOp>())
686       if (tuple_op.use_empty()) tuple_op.erase();
687 
688   // Next op to visit. The replacement is visited but at its first region. The
689   // token result of the new region if is propagated.
690   ops_to_visit.push_back({/*region_idx=*/0, new_token, new_while});
691 }
692 
693 // Rewrites an `mhlo.while` op or its region. If `region_idx` is not set, the op
694 // operands and results are rewritten. If `region_idx` is set, region
695 // `region_idx` is rewritten to take in and return an additional token. Returns
696 // true if the op or its region was rewritten.
ProcessRegionWhileOp(OpBuilder & builder,WhileOp region_while,Optional<unsigned> region_idx,SmallVectorImpl<OpVisitorState> & ops_to_visit,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,Value token)697 bool ProcessRegionWhileOp(
698     OpBuilder& builder, WhileOp region_while, Optional<unsigned> region_idx,
699     SmallVectorImpl<OpVisitorState>& ops_to_visit,
700     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, Value token) {
701   builder.setInsertionPoint(region_while);
702 
703   if (!region_idx) {
704     RewriteRegionWhileOp(builder, region_while, ops_to_visit, token);
705     return true;
706   }
707 
708   if (*region_idx < region_while.getNumRegions()) {
709     // TODO(jpienaar): Support multi-operand while op.
710     RewriteControlFlowOpRegion(builder, region_while, *region_idx,
711                                region_while.arg()[0].getType(), ops_to_visit,
712                                control_flow_blocks, token);
713     return true;
714   }
715 
716   return false;
717 }
718 
719 // Updates function type based on current function body block arguments and
720 // terminator operand types.
UpdateFunctionType(OpBuilder & builder,FuncOp func,Block & func_body)721 void UpdateFunctionType(OpBuilder& builder, FuncOp func, Block& func_body) {
722   auto new_argument_types = llvm::to_vector<4>(func_body.getArgumentTypes());
723   auto new_result_types =
724       llvm::to_vector<4>(func_body.getTerminator()->getOperandTypes());
725   func.setType(FunctionType::get(builder.getContext(), new_argument_types,
726                                  new_result_types));
727 }
728 
729 // Replaces a function terminator `return` with another `return` that has an
730 // extra `mhlo.token` operand.
RewriteFunctionTerminator(OpBuilder & builder,mlir::ReturnOp terminator,Value token)731 void RewriteFunctionTerminator(OpBuilder& builder, mlir::ReturnOp terminator,
732                                Value token) {
733   auto new_results = llvm::to_vector<4>(terminator.getOperands());
734   new_results.push_back(token);
735   builder.setInsertionPoint(terminator);
736   builder.create<mlir::ReturnOp>(terminator.getLoc(), new_results);
737   terminator.erase();
738 }
739 
740 // Rewrites a function body and communication ops inside. Region control flow
741 // are updated when necessary, to propagate tokens. The function may either be
742 // rewritten to create a token or take in and return a token, depending on its
743 // visibility and if there are any callers.
RewriteFunction(OpBuilder & builder,int64_t & channel_id,ModuleOp module,FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs,const llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,const llvm::SmallPtrSetImpl<Block * > & control_flow_blocks,bool is_clone)744 LogicalResult RewriteFunction(
745     OpBuilder& builder, int64_t& channel_id, ModuleOp module, FuncOp func,
746     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs,
747     const llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
748     const llvm::SmallPtrSetImpl<Block*>& control_flow_blocks, bool is_clone) {
749   MLIRContext* context = module.getContext();
750   if (!llvm::hasSingleElement(func.getBody()))
751     return func.emitError()
752            << "'" << FuncOp::getOperationName()
753            << "' ops with more than one block are not supported";
754 
755   bool rewrite_block =
756       is_clone || (!func.isPublic() && !func.symbolKnownUseEmpty(module));
757   Block& func_body = func.front();
758 
759   builder.setInsertionPointToStart(&func_body);
760   auto token_type = TokenType::get(context);
761   // If a function is public, it's signature should not be modified, and instead
762   // a token will be created. Otherwise a token block argument is inserted.
763   Value init_token =
764       rewrite_block ? func_body.addArgument(token_type)
765                     : builder.create<CreateTokenOp>(func.getLoc(), token_type)
766                           .getResult();
767 
768   // Stack to keep track of region based control flow op nesting and current
769   // op to visit.
770   SmallVector<OpVisitorState, 4> ops_to_visit{
771       {/*region_idx=*/llvm::None, init_token, &func_body.front()}};
772 
773   while (!ops_to_visit.empty()) {
774     OpVisitorState op_to_visit = ops_to_visit.pop_back_val();
775     Operation* curr_op = op_to_visit.op;
776 
777     Value token = op_to_visit.token;
778     // Ops may be removed, so the next op is kept track of beforehand.
779     Operation* next_op = curr_op->getNextNode();
780 
781     if (auto host_compute = dyn_cast<TF::_XlaHostComputeMlirOp>(curr_op)) {
782       token = RewriteHostComputeOp(builder, channel_id, host_compute, token);
783     } else if (auto send_to_host = dyn_cast<TF::XlaSendToHostOp>(curr_op)) {
784       token = RewriteSendToHostOp(builder, channel_id, send_to_host, token);
785     } else if (auto recv_from_host = dyn_cast<TF::XlaRecvFromHostOp>(curr_op)) {
786       token = RewriteRecvFromHostOp(builder, channel_id, recv_from_host, token);
787     } else if (auto call = dyn_cast<mlir::CallOp>(curr_op)) {
788       // Only `mlir::CallOp` is supported as this requires knowing how to
789       // rewrite arguments and results to a function.
790       auto it = funcs.find(call.getCallee());
791       if (it != funcs.end()) {
792         FuncOp clone = it->getSecond().clone;
793         Optional<StringRef> symbol_name =
794             clone ? Optional<StringRef>(clone.getName()) : llvm::None;
795         // If the function being called is to be cloned, update the call to also
796         // point to the cloned function.
797         token = RewriteCallOp(builder, call, symbol_name, token);
798       }
799     } else if (auto region_if = dyn_cast<IfOp>(curr_op)) {
800       if (op_to_visit.region_idx || control_flow_ops.contains(region_if))
801         if (ProcessRegionIfOp(builder, region_if, op_to_visit.region_idx,
802                               ops_to_visit, control_flow_blocks, token))
803           continue;
804     } else if (auto region_while = dyn_cast<WhileOp>(curr_op)) {
805       if (op_to_visit.region_idx || control_flow_ops.contains(region_while))
806         if (ProcessRegionWhileOp(builder, region_while, op_to_visit.region_idx,
807                                  ops_to_visit, control_flow_blocks, token))
808           continue;
809     } else if (auto region_terminator = dyn_cast<mhlo::ReturnOp>(curr_op)) {
810       RewriteControlFlowTerminator(builder, region_terminator, token);
811       // There is no next op afer the control flow op terminator, simply let
812       // stack have one less element.
813       continue;
814     } else if (auto func_terminator = dyn_cast<mlir::ReturnOp>(curr_op)) {
815       if (rewrite_block)
816         RewriteFunctionTerminator(builder, func_terminator, token);
817 
818       // There is no next op afer the function terminator, simply let stack have
819       // one less element/be empty.
820       continue;
821     }
822 
823     // Visit next op.
824     ops_to_visit.push_back({/*region_idx=*/llvm::None, token, next_op});
825   }
826 
827   if (rewrite_block) UpdateFunctionType(builder, func, func_body);
828 
829   return success();
830 }
831 
832 // Checks if a function call is pointing to a function with communication ops.
IsFunctionCallWithCommunication(Operation * op,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite)833 bool IsFunctionCallWithCommunication(
834     Operation* op,
835     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite) {
836   if (auto call = dyn_cast<mlir::CallOp>(op))
837     return funcs_to_rewrite.count(call.callee());
838 
839   return false;
840 }
841 
842 // Collects all control flow op ancestors of communication ops or function calls
843 // with communication ops (transitively).
GetCommunicationControlFlowOps(FuncOp func,const llvm::SmallDenseMap<StringRef,FuncToRewrite> & funcs_to_rewrite,llvm::SmallPtrSetImpl<Operation * > & control_flow_ops,llvm::SmallPtrSetImpl<Block * > & control_flow_blocks)844 void GetCommunicationControlFlowOps(
845     FuncOp func,
846     const llvm::SmallDenseMap<StringRef, FuncToRewrite>& funcs_to_rewrite,
847     llvm::SmallPtrSetImpl<Operation*>& control_flow_ops,
848     llvm::SmallPtrSetImpl<Block*>& control_flow_blocks) {
849   func.walk([&](Operation* op) {
850     if (IsCommunicationOp(op) ||
851         IsFunctionCallWithCommunication(op, funcs_to_rewrite))
852       if (failed(GetControlFlowAncestors(op, control_flow_ops,
853                                          control_flow_blocks)))
854         llvm_unreachable(
855             "checking original function for control flow ancestors should have "
856             "errored first");
857   });
858 }
859 
runOnOperation()860 void LegalizeTFCommunication::runOnOperation() {
861   auto module = getOperation();
862   llvm::SmallDenseMap<StringRef, FuncToRewrite> funcs_to_rewrite;
863   if (failed(GetFunctionsToRewrite(module, funcs_to_rewrite)))
864     return signalPassFailure();
865 
866   // Module level counter to make sure Channel Id's are unique.
867   int64_t channel_id = 1;
868   OpBuilder builder(&getContext());
869   for (const auto& func_and_name : funcs_to_rewrite) {
870     const auto& func_to_rewrite = func_and_name.getSecond();
871     FuncOp func = func_to_rewrite.original;
872     if (failed(RewriteFunction(builder, channel_id, module, func,
873                                funcs_to_rewrite,
874                                func_to_rewrite.control_flow_ops,
875                                func_to_rewrite.control_flow_blocks,
876                                /*is_clone=*/false)))
877       return signalPassFailure();
878 
879     FuncOp clone = func_and_name.getSecond().clone;
880     if (!clone) continue;
881     llvm::SmallPtrSet<Operation*, 4> clone_control_flow_ops;
882     llvm::SmallPtrSet<Block*, 4> clone_control_flow_blocks;
883     GetCommunicationControlFlowOps(clone, funcs_to_rewrite,
884                                    clone_control_flow_ops,
885                                    clone_control_flow_blocks);
886     if (failed(RewriteFunction(builder, channel_id, module, clone,
887                                funcs_to_rewrite, clone_control_flow_ops,
888                                clone_control_flow_blocks,
889                                /*is_clone=*/true)))
890       llvm_unreachable(
891           "rewriting of original function should have errored first");
892   }
893 }
894 
// File-local static registration object for the `LegalizeTFCommunication`
// pass; constructed during static initialization.
static PassRegistration<LegalizeTFCommunication> pass;
896 }  // namespace
897 
CreateLegalizeTFCommunicationPass()898 std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTFCommunicationPass() {
899   return std::make_unique<LegalizeTFCommunication>();
900 }
901 
902 }  // namespace mhlo
903 }  // namespace mlir
904