/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/container/flat_hash_set.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/IR/Region.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_remaining_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h"
#include "tensorflow/compiler/tf2xla/side_effect_util.h"

namespace mlir {
namespace TF {
namespace {

// Returns true if the given op is a TF/XLA communication op in the old bridge.
bool IsCommunicationOp(Operation* op) {
  return isa<TF::XlaHostComputeOp, TF::XlaSendToHostOp, TF::XlaRecvFromHostOp>(
      op);
}

// Returns true if the given op is one of the ops supported to have a
// communication subcomputation in the TF/XLA bridge.
bool SupportsCommunicationComputation(Operation* op) {
  return isa<TF::IfRegionOp, TF::WhileRegionOp, TF::CaseRegionOp,
             TF::StatefulPartitionedCallOp, TF::PartitionedCallOp,
             TF::LegacyCallOp>(op);
}

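// A pass that prepares TPU computations in a module for export to a
// TensorFlow graph by rewriting new bridge communication ops to their old
// bridge equivalents and attaching token input node name attributes.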
class PrepareTpuComputationForTfExportPass
    : public PrepareTpuComputationForTfExportPassBase<
          PrepareTpuComputationForTfExportPass> {
  void runOnOperation() override;
};

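// Rewrites the new bridge _XlaHostComputeMlir op into the old bridge
// XlaHostCompute op. If a `host_mlir_module` is attached, its host function
// is cloned into the module and used as the shape inference graph.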
class RewriteXlaHostComputeMlir
    : public OpRewritePattern<TF::_XlaHostComputeMlirOp> {
 public:
  using OpRewritePattern<TF::_XlaHostComputeMlirOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op,
                                PatternRewriter& rewriter) const override {
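    // Collect the result shapes; they are forwarded as shape attributes on
    // the replacement XlaHostCompute op created below.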
    llvm::SmallVector<Attribute> shape_attrs;
    shape_attrs.reserve(op.getNumResults());
    for (Type ty : op.getResultTypes()) {
      shape_attrs.push_back(
          TF::ShapeAttr::get(rewriter.getContext(), ty.cast<ShapedType>()));
    }

    // Clone the `host_func` in the `host_mlir_module` attribute if it exists
    // and use it for the `shape_inference_graph` attribute on XlaHostCompute.
    FuncOp cloned_func;
    SymbolTable manager(op->getParentOfType<ModuleOp>());
    StringRef host_module = op.host_mlir_module();
    if (!host_module.empty()) {
      mlir::OwningModuleRef module_for_func;

      FuncOp func = op.GetHostFunc(&module_for_func);

      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointAfter(op->getParentOfType<FuncOp>());
      cloned_func =
          llvm::dyn_cast_or_null<FuncOp>(rewriter.clone(*func.getOperation()));
      manager.insert(cloned_func);
      rewriter.setInsertionPointToStart(&cloned_func.body().front());
      auto result_type =
          RankedTensorType::get({3}, rewriter.getType<TF::StringType>());
      auto dynamic_key =
          rewriter.create<TF::_TPUCompileMlirPlaceholderProgramKeyOp>(
              func.getLoc(), /*program=*/result_type, llvm::ArrayRef<Value>{});

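      // Receive the operands on the host and redirect all uses of the cloned
      // function's arguments to the received values.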
      auto recv_at_host = rewriter.create<TF::_XlaRecvAtHostOp>(
          func.getLoc(), op.getOperandTypes(), /*dynamic_key=*/dynamic_key,
          op.send_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
      for (auto result :
           llvm::zip(cloned_func.getArguments(), recv_at_host->getResults())) {
        std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
      }

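      // Send the cloned function's results back to the device right before
      // its terminator.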
      rewriter.setInsertionPoint(cloned_func.body().front().getTerminator());
      rewriter.create<TF::_XlaSendFromHostOp>(
          func.getLoc(),
          cloned_func.body().front().getTerminator()->getOperands(),
          /*dynamic_key=*/dynamic_key, op.recv_keyAttr(),
          /*device_ordinal=*/rewriter.getI64IntegerAttr(0));
    }

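    // Replace the op with an XlaHostCompute op that forwards the send/recv
    // keys and, if a host function was cloned, references it for shape
    // inference.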
    constexpr int64_t kDefaultCostEstimate = 1000000;
    rewriter.replaceOpWithNewOp<TF::XlaHostComputeOp>(
        op, op.getResultTypes(), op.inputs(),
        /*ancestors=*/rewriter.getArrayAttr({}),
        rewriter.getArrayAttr(shape_attrs),
        /*shape_inference_graph=*/
        cloned_func ? rewriter.getSymbolRefAttr(cloned_func) : SymbolRefAttr(),
        /*key=*/rewriter.getStringAttr(""), op.send_keyAttr(),
        op.recv_keyAttr(),
        /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate),
        op.tpu_coreAttr());
    return success();
  }
};

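// Converts non-empty `mhlo.sharding` argument attributes on the given
// function into explicit XlaSharding ops on the corresponding arguments and
// then drops the attribute.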
void UpdateArgAttributes(mlir::FuncOp func) {
  OpBuilder builder(func.getBody());
  for (int i = 0; i < func.getNumArguments(); ++i) {
    constexpr char kShardingAttr[] = "mhlo.sharding";
    if (auto sharding =
            func.getArgAttrOfType<mlir::StringAttr>(i, kShardingAttr)) {
      if (!sharding.getValue().empty()) {
        BlockArgument arg = func.getArgument(i);
        // TODO(hinsu): Instead of setting both 'sharding' and '_XlaSharding'
        // attributes, only set the 'sharding' attribute. Both attributes are
        // currently required as the XlaSharding xla op kernel doesn't use the
        // 'sharding' attribute.
        auto updated_arg = builder.create<TF::XlaShardingOp>(
            func.getLoc(), arg.getType(), arg, sharding, sharding);
        func.getArgument(i).replaceAllUsesExcept(
            updated_arg, llvm::SmallPtrSet<Operation*, 1>({updated_arg}));
      }

      func.removeArgAttr(i, builder.getIdentifier(kShardingAttr));
    }
  }
}

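// Rewrites _XlaHostComputeMlir ops using the pattern above and appends
// direction suffixes to the keys of XlaSendToHost and XlaRecvFromHost ops to
// match old bridge key naming.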
LogicalResult RewriteCommunicationOps(ModuleOp module) {
  MLIRContext* ctx = module.getContext();
  mlir::OwningRewritePatternList patterns(ctx);
  patterns.insert<RewriteXlaHostComputeMlir>(ctx);
  if (failed(mlir::applyPatternsAndFoldGreedily(module, std::move(patterns)))) {
    return module.emitError("failed to apply tf export preparation patterns");
  }

  // TODO(hinsu): Investigate if the semantics of keys for these communication
  // ops between the old bridge and new bridge can be reconciled.
  module.walk([&](Operation* op) {
    if (isa<TF::XlaSendToHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_dtoh_0");
      op->setAttr("key", new_key);
    } else if (isa<TF::XlaRecvFromHostOp>(op)) {
      StringRef old_key = op->getAttrOfType<StringAttr>("key").getValue();
      auto new_key = StringAttr::get(ctx, old_key.str() + "_htod_0");
      op->setAttr("key", new_key);
    }
  });
  return success();
}

// Sets the token input node names attribute and the corresponding original
// node names for TF/XLA communication related ops. These attributes are used
// to order operations on the device. The first op in a region should have the
// special argument token, and the remaining operations should have the node
// name of the previous communication op.
LogicalResult SetTokenInputAttrs(ModuleOp module) {
  // Collect all the ops that need to have token input name attributes. These
  // ops are communication ops and all their parent ops, via nesting or
  // function calls. For example, IfRegion op and PartitionedCall op.
  std::vector<Operation*> worklist;
  absl::flat_hash_set<Operation*> ops_with_tokens;
  module.walk([&](Operation* op) {
    if (IsCommunicationOp(op)) {
      ops_with_tokens.insert(op);
      worklist.push_back(op);
    }
  });

  SymbolTableCollection table;
  SymbolUserMap symbol_map(table, module);

  // Regions that contain ops requiring token input attributes.
  absl::flat_hash_set<Region*> regions_with_token;
  while (!worklist.empty()) {
    Operation* op = worklist.back();
    worklist.pop_back();

    Region* region = op->getParentRegion();
    regions_with_token.insert(region);

    // If the parent is not a FuncOp, then add the parent op containing a
    // region to the worklist.
    Operation* parent = region->getParentOp();
    if (!isa<FuncOp>(parent)) {
      if (ops_with_tokens.insert(parent).second) {
        worklist.push_back(parent);
      }
      continue;
    }

    // For functions, get all the users and add them to the worklist.
    for (auto& user : symbol_map.getUsers(parent)) {
      if (ops_with_tokens.insert(user).second) {
        worklist.push_back(user);
      }
    }
  }

  // Use the name mapper to uniquely name all ops in the module, as export to
  // a TensorFlow graph may change node names. The op names here don't need to
  // match the actual names in the graph because this sets the original node
  // name attribute for all the relevant nodes.
  tensorflow::OpOrArgLocNameMapper name_mapper;
  MLIRContext* ctx = module.getContext();
  for (Region* region : regions_with_token) {
    // Initialize the token with the special argument token. This gets mapped
    // to the input token in the parent op or a new token for the entry
    // computation.
    auto token = StringAttr::get(ctx, tensorflow::kXlaTokenArgNodeName);
    for (Operation& op : region->getOps()) {
      // Only communication related ops that need a token should have the
      // extra attribute.
      if (!ops_with_tokens.contains(&op)) continue;

      if (!IsCommunicationOp(&op) && !SupportsCommunicationComputation(&op)) {
        return op.emitOpError(
            "does not support subcomputations with tf/xla communication ops");
      }

      op.setAttr(tensorflow::kXlaTokenInputNodesAttrName,
                 ArrayAttr::get(ctx, {token}));

      auto node_name = StringAttr::get(ctx, name_mapper.GetUniqueName(&op));
      op.setAttr(tensorflow::kXlaOriginalOutsideCompilationNodeName, node_name);
      token = node_name;
    }
  }
  return success();
}

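// Pass entry point: materializes argument sharding for every function, then
// rewrites communication ops and sets token input attributes on the module.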
void PrepareTpuComputationForTfExportPass::runOnOperation() {
  ModuleOp module = getOperation();

  for (FuncOp func : module.getOps<FuncOp>()) {
    UpdateArgAttributes(func);
  }

  // First rewrite communication ops used in the new bridge to match old
  // bridge semantics, and then set the token input node name attributes on
  // the supported ops.
  if (failed(RewriteCommunicationOps(module)) ||
      failed(SetTokenInputAttrs(module))) {
    signalPassFailure();
    return;
  }
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreatePrepareTpuComputationForTfExportPass() {
  return std::make_unique<PrepareTpuComputationForTfExportPass>();
}

}  // namespace TF
}  // namespace mlir