/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Rewrite/PatternApplicator.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/lib/monitoring/gauge.h"

namespace mlir {
namespace TFDevice {

namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement";

auto* auto_outside_compilation_gauge =
    tensorflow::monitoring::Gauge<bool, 0>::New(
        "/tensorflow/core/use_auto_outside_compilation",
        "Tracks if auto outside compilation is enabled");

struct MarkOpsForOutsideCompilation
    : public TF::MarkOpsForOutsideCompilationPassBase<
          MarkOpsForOutsideCompilation> {
  void runOnOperation() override;
};
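
// Illustrative sketch of what this pass does (hand-written, not taken from a
// test; op and type syntax abbreviated): with soft placement enabled on a
// cluster, an op that can't be lowered to HLO is annotated instead of
// rejected. For example, an op with a string result such as
//
//   "tf_device.cluster"() ({
//     %0 = "tf.Const"() {value = dense<"s"> : tensor<!tf.string>} : ...
//     ...
//   }) {allow_soft_placement = true} : ...
//
// becomes, after this pass,
//
//     %0 = "tf.Const"() {_xla_outside_compilation = "auto0",
//                        value = dense<"s"> : tensor<!tf.string>} : ...
//
// Ops annotated this way are later extracted to run on the host CPU.
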
// Adds any canonicalization patterns to the list of supported `patterns`.
// TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and
// remove this.
void AddCanonicalizationPatterns(MLIRContext* context,
                                 OwningRewritePatternList* patterns) {
  for (auto* op : context->getRegisteredOperations())
    op->getCanonicalizationPatterns(*patterns, context);
}

// Adds the list of ops that are supported on TPU through constant folding,
// which may depend on input shapes that are not known at this point. Such ops
// may not have any legalization or canonicalization patterns but shouldn't be
// marked for outside compilation.
//
// TODO(b/177523289): Remove manual handling once we support constant folding
// and shape inference through the computation on the host side.
void AddSupportedOpsUsingFolding(MLIRContext* context,
                                 llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::BroadcastArgsOp::getOperationName(), context),
      OperationName(TF::BroadcastGradientArgsOp::getOperationName(), context),
      OperationName(TF::ConcatOffsetOp::getOperationName(), context),
      OperationName(TF::EmptyOp::getOperationName(), context),
      OperationName(TF::ListDiffOp::getOperationName(), context),
      OperationName(TF::RankOp::getOperationName(), context),
      OperationName(TF::RangeOp::getOperationName(), context),
      OperationName(TF::ShapeOp::getOperationName(), context),
      OperationName(TF::ShapeNOp::getOperationName(), context),
      OperationName(TF::SizeOp::getOperationName(), context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// Adds the list of ops that are supported through dynamic padder using
// op-by-op fallback to the TF2XLA bridge.
// TODO(b/168036682): Remove this once ops are supported using dynamic padder
// on MLIR bridge.
void AddSupportedOpsUsingDynamicPadder(
    MLIRContext* context, llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::WhereOp::getOperationName(), context),
      OperationName(TF::UniqueOp::getOperationName(), context),
      OperationName(TF::XlaSetDynamicDimensionSizeOp::getOperationName(),
                    context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// TODO(b/159128666): Check the control flow legalization passes instead once
// added.
void AddSupportedFunctionalOps(MLIRContext* context,
                               llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(
      OperationName(TF::CaseRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::IfRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::InplaceAddOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::WhileRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceWindowOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaSelectAndScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::SymbolicGradientOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceV2Op::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicSortOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReplicaIdOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::YieldOp::getOperationName(), context));
}

// These embedding ops are rewritten when running TPUCompileOp.
void AddRewrittenEmbeddingOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(OperationName(
      TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context));
  supported_ops->insert(OperationName(
      TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
}

// Stack, TensorList and TensorArray ops are rewritten during the second phase
// of the bridge (compilation of TPUCompile op). They would not match any
// legalization/canonicalization pattern and have to be manually added to the
// list of supported ops.
void AddRewrittenCompositeOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
#define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
  llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
      // Stack ops.
      GET_OPERATION_NAME(TF::StackV2Op),
      GET_OPERATION_NAME(TF::StackPushV2Op),
      GET_OPERATION_NAME(TF::StackPopV2Op),
      // Tensor Array ops.
      GET_OPERATION_NAME(TF::TensorArrayV3Op),
      GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
      GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
      GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
      GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
      GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
      GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
      // Tensor List ops.
      GET_OPERATION_NAME(TF::EmptyTensorListOp),
      GET_OPERATION_NAME(TF::TensorListReserveOp),
      GET_OPERATION_NAME(TF::TensorListFromTensorOp),
      GET_OPERATION_NAME(TF::TensorListPushBackOp),
      GET_OPERATION_NAME(TF::TensorListPopBackOp),
      GET_OPERATION_NAME(TF::TensorListGetItemOp),
      GET_OPERATION_NAME(TF::TensorListSetItemOp),
      GET_OPERATION_NAME(TF::TensorListLengthOp),
      GET_OPERATION_NAME(TF::TensorListElementShapeOp),
      GET_OPERATION_NAME(TF::TensorListGatherOp),
      GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
      GET_OPERATION_NAME(TF::TensorListStackOp),
  };
#undef GET_OPERATION_NAME

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

bool IsStringType(Type type) {
  if (type.isa<TF::StringType>()) return true;

  auto sub_type = type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
  if (!sub_type) return false;

  bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) {
    return type.getElementType().isa<TF::StringType>();
  });
  return has_string;
}
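
// For example (illustrative; types shown in !tf.* syntax): IsStringType
// returns true for the string type itself and for a resource/variant type
// whose subtype carries strings, e.g. a resource with subtype
// tensor<3x!tf.string>; it returns false for numeric element types like i32.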

bool HasStringOperand(Operation& op) {
  for (auto operand : op.getOperands()) {
    auto operand_type = getElementTypeOrSelf(operand);
    if (IsStringType(operand_type)) return true;
  }
  return false;
}

bool HasStringResult(Operation& op) {
  for (auto result : op.getResults()) {
    auto result_type = getElementTypeOrSelf(result);
    if (IsStringType(result_type)) return true;
  }
  return false;
}

bool MatchesPattern(Operation& op,
                    const llvm::DenseSet<OperationName>& supported_ops) {
  return supported_ops.contains(op.getName());
}

// Checks if the op is supported inside of a device cluster. Ops not
// in `tf_dialect` are considered supported.
bool IsSupportedOp(Operation& op,
                   const llvm::DenseSet<OperationName>& supported_ops,
                   const Dialect* tf_dialect) {
  if (op.getDialect() != tf_dialect)
    return true;
  // tf.Assert has a legalization that later removes it, so for performance
  // reasons we never want to outside compile it.
  if (llvm::isa<TF::AssertOp>(op)) return true;
  return !HasStringOperand(op) && !HasStringResult(op) &&
         (MatchesPattern(op, supported_ops) ||
          mhlo::IsOpAllowedTf2XlaFallback(&op));
}
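
// For example (illustrative): a TF op with a string operand is reported
// unsupported even when a legalization pattern exists for it, while an op
// from a non-TF dialect (e.g. mhlo) is trivially treated as supported here.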

// Checks all regions of `op` for captured string operands.
bool HasCapturedStringOperand(Operation* op) {
  bool string_operand = false;
  for (auto& region : op->getRegions()) {
    mlir::visitUsedValuesDefinedAbove(
        region, region, [&](mlir::OpOperand* operand) {
          if (getElementTypeOrSelf(operand->get()).isa<TF::StringType>())
            string_operand = true;
        });
    if (string_operand) return string_operand;
  }
  return string_operand;
}

bool IsVariant(Value value) {
  return getElementTypeOrSelf(value.getType()).isa<TF::VariantType>();
}

bool HasOutsideCompiledAncestor(Operation* op) {
  Operation* parent = op->getParentOp();
  while (parent) {
    if (parent->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      return true;
    parent = parent->getParentOp();
  }
  return false;
}

// If any tf.variant values are inputs/outputs of another outside compiled
// Operation, `op`, mark them for outside compilation unless they are already
// marked with the outside compilation attribute.
void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) {
  std::queue<Operation*> outside_compiled_ops;
  tpu_cluster.walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      outside_compiled_ops.push(op);
  });

  while (!outside_compiled_ops.empty()) {
    Operation* host_op = outside_compiled_ops.front();
    outside_compiled_ops.pop();
    host_op->walk([&](Operation* op) {
      // Add any operations that provide variant inputs to the cluster.
      for (auto value : op->getOperands()) {
        Operation* input_defining_op = value.getDefiningOp();
        if (IsVariant(value) && input_defining_op &&
            !HasOutsideCompiledAncestor(input_defining_op) &&
            !input_defining_op->hasAttrOfType<StringAttr>(
                kXlaOutsideCompilationAttr)) {
          input_defining_op->setAttr(
              kXlaOutsideCompilationAttr,
              StringAttr::get(input_defining_op->getContext(), "auto"));
          outside_compiled_ops.push(input_defining_op);
        }
      }
      // Mark for outside compilation any operations that consume variant
      // outputs from an outside compiled operation.
      for (auto value : op->getResults()) {
        if (IsVariant(value)) {
          for (auto user : value.getUsers()) {
            if (!user->hasTrait<OpTrait::IsTerminator>() &&
                !HasOutsideCompiledAncestor(user) &&
                !user->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
              user->setAttr(kXlaOutsideCompilationAttr,
                            StringAttr::get(user->getContext(), "auto"));
              outside_compiled_ops.push(user);
            }
          }
        }
      }
    });
  }
}
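
// Illustrative sketch of the propagation above (hand-written; `tf.SomeHostOp`
// is a hypothetical op name): given
//
//   %list = "tf.TensorListReserve"(...)   // produces a tf.variant
//   %x = "tf.SomeHostOp"(%list) {_xla_outside_compilation = "auto0"} ...
//
// the defining op tf.TensorListReserve is also marked with
// {_xla_outside_compilation = "auto"} and pushed on the worklist, so the
// marking keeps propagating through chains of variant producers and
// consumers.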

// Marks uncompilable ops that are in `tf_dialect` for outside compilation.
LogicalResult MarkUncompilableOps(
    const Dialect* tf_dialect, Block* block,
    llvm::DenseSet<OperationName>& supported_ops) {
  // Automatically marked ops for outside compilation have an
  // `_xla_outside_compilation` attribute value of "auto" plus an increasing
  // counter. Manually marked ops for outside compilation only have an
  // increasing counter for the attribute value. Therefore there is no
  // collision in the `_xla_outside_compilation` attribute between
  // automatically and manually marked ops.
  int outside_compiled_cluster_counter = 0;
  block->walk([&](Operation* op) {
    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
              << " isn't compilable, adding outside_compilation attr. "
                 "This op will automatically be placed on CPU.";
      op->setAttr(kXlaOutsideCompilationAttr,
                  StringAttr::get(
                      op->getContext(),
                      llvm::formatv("auto{0}", outside_compiled_cluster_counter)
                          .str()));
      outside_compiled_cluster_counter++;
    }
  });
  if (outside_compiled_cluster_counter > 0) {
    auto_outside_compilation_gauge->GetCell()->Set(true);
  }
  return success();
}

// Checks for uncompilable ops that are in `tf_dialect` and are not already
// marked for outside compilation.
bool ContainsUncompilableOps(const Dialect* tf_dialect, Block* block,
                             llvm::DenseSet<OperationName>& supported_ops) {
  int uncompilable_op_count = 0;
  // Skip ops that are already marked for outside compilation, either directly
  // or via a parent.
  block->walk([&](Operation* op) {
    Operation* iter_op = op;
    while (iter_op && !llvm::isa<tf_device::ClusterOp>(iter_op)) {
      if (iter_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        return;
      }
      iter_op = iter_op->getParentOp();
    }

    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      op->emitOpError() << "isn't compilable for TPU device. Enable "
                           "soft_device_placement option to run on CPU";
      ++uncompilable_op_count;
    }
  });
  return uncompilable_op_count > 0;
}

// Unmarks outside compilation for any op that has a parent already
// marked for outside compilation, since the child will be extracted
// anyway.
void UnmarkChildren(Block* block) {
  block->walk([&](Operation* op) {
    if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
    Operation* iter_op = op;
    bool remove_attr = false;
    while (auto* parent_op = iter_op->getParentOp()) {
      if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        remove_attr = true;
        break;
      }
      iter_op = parent_op;
    }
    if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
  });
}
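
// Illustrative example (hand-written): if a tf.IfRegion op carries
// {_xla_outside_compilation = "auto0"}, a marked op nested in its then/else
// regions has its attribute removed here, since extracting the marked parent
// already moves the nested op to the host.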

void MarkOpsForOutsideCompilation::runOnOperation() {
  auto module = getOperation();
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }
  OwningRewritePatternList patterns(&getContext());
  mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
  TF::PopulateTFLoweringBeforeHLOPatterns(module.getContext(), &patterns);
  TF::PopulateLoweringQuantizedPatterns(module.getContext(), &patterns);
  AddCanonicalizationPatterns(module.getContext(), &patterns);

  // `supported_ops` contains the names of all ops that can potentially be
  // lowered into HLO on the device. Membership in this set doesn't guarantee
  // that a later pass will lower the op, but an op that is not in this set
  // can't be lowered by any subsequent pass.
  llvm::DenseSet<OperationName> supported_ops;
  PatternApplicator(std::move(patterns))
      .walkAllPatterns([&](const Pattern& pattern) {
        Optional<OperationName> root_kind = pattern.getRootKind();
        if (root_kind.hasValue()) supported_ops.insert(root_kind.getValue());
      });
  AddSupportedFunctionalOps(module.getContext(), &supported_ops);
  AddSupportedOpsUsingFolding(module.getContext(), &supported_ops);
  AddSupportedOpsUsingDynamicPadder(module.getContext(), &supported_ops);
  AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
  AddRewrittenCompositeOps(module.getContext(), &supported_ops);

  auto result = module.walk([&](tf_device::ClusterOp cluster) {
    // Only if the `allow_soft_placement` attribute is true should we mark ops
    // for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (soft_placement_attr && soft_placement_attr.getValue()) {
      if (failed(MarkUncompilableOps(tf_dialect, &cluster.GetBody(),
                                     supported_ops)))
        return WalkResult::interrupt();
    } else {
      if (ContainsUncompilableOps(tf_dialect, &cluster.GetBody(),
                                  supported_ops))
        return WalkResult::interrupt();
    }
    MarkVariantInputsOutputs(cluster);

    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  module.walk([&](tf_device::ClusterOp cluster) {
    // Only if the `allow_soft_placement` attribute is true should we unmark
    // ops for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (!(soft_placement_attr && soft_placement_attr.getValue())) {
      return;
    }
    UnmarkChildren(&cluster.GetBody());
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass() {
  return std::make_unique<MarkOpsForOutsideCompilation>();
}
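
// Illustrative usage (hand-written sketch; `pm` is an assumed
// mlir::PassManager set up by the caller): the factory above is typically
// added to a module-level pipeline, e.g.
//   pm.addPass(mlir::TFDevice::CreateMarkOpsForOutsideCompilationPass());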

}  // namespace TFDevice
}  // namespace mlir