/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <queue>
#include <string>
#include <utility>

#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Rewrite/PatternApplicator.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/lib/monitoring/gauge.h"

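// This pass marks ops inside `tf_device.cluster` regions that cannot be
// compiled for the TPU device with the `_xla_outside_compilation` attribute,
// so that later passes can extract them to run on the host. Ops are only
// marked automatically when the cluster sets `allow_soft_placement = true`;
// otherwise, any unsupported op is reported as a compilation error.
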
namespace mlir {
namespace TFDevice {

namespace {

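// `_xla_outside_compilation` marks an op for extraction to the host;
// `allow_soft_placement` on a cluster enables marking ops automatically.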
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement";

auto* auto_outside_compilation_gauge =
    tensorflow::monitoring::Gauge<bool, 0>::New(
        "/tensorflow/core/use_auto_outside_compilation",
        "Tracks if auto outside compilation is enabled");

struct MarkOpsForOutsideCompilation
    : public TF::MarkOpsForOutsideCompilationPassBase<
          MarkOpsForOutsideCompilation> {
  void runOnOperation() override;
};

// Adds any canonicalization patterns to the list of supported `patterns`.
// TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and
// remove this.
void AddCanonicalizationPatterns(MLIRContext* context,
                                 OwningRewritePatternList* patterns) {
  for (auto* op : context->getRegisteredOperations())
    op->getCanonicalizationPatterns(*patterns, context);
}

// Adds ops that are supported on TPU through constant folding, which may
// depend on input shapes that are not known at this point. Such ops may not
// have any legalization or canonicalization patterns but shouldn't be marked
// for outside compilation.
//
// TODO(b/177523289): Remove manual handling once we support constant folding
// and shape inference through the computation on the host side.
void AddSupportedOpsUsingFolding(MLIRContext* context,
                                 llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::BroadcastArgsOp::getOperationName(), context),
      OperationName(TF::BroadcastGradientArgsOp::getOperationName(), context),
      OperationName(TF::ConcatOffsetOp::getOperationName(), context),
      OperationName(TF::EmptyOp::getOperationName(), context),
      OperationName(TF::ListDiffOp::getOperationName(), context),
      OperationName(TF::RankOp::getOperationName(), context),
      OperationName(TF::RangeOp::getOperationName(), context),
      OperationName(TF::ShapeOp::getOperationName(), context),
      OperationName(TF::ShapeNOp::getOperationName(), context),
      OperationName(TF::SizeOp::getOperationName(), context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// Adds the list of ops that are supported through the dynamic padder, using
// op-by-op fallback to the TF2XLA bridge.
// TODO(b/168036682): Remove this once ops are supported using dynamic padder
// on MLIR bridge.
void AddSupportedOpsUsingDynamicPadder(
    MLIRContext* context, llvm::DenseSet<OperationName>* supported_ops) {
  llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
      OperationName(TF::WhereOp::getOperationName(), context),
      OperationName(TF::UniqueOp::getOperationName(), context),
      OperationName(TF::XlaSetDynamicDimensionSizeOp::getOperationName(),
                    context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

// TODO(b/159128666): Check the control flow legalization passes instead once
// added.
void AddSupportedFunctionalOps(MLIRContext* context,
                               llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(
      OperationName(TF::CaseRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::IfRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::InplaceAddOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::WhileRegionOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReduceWindowOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaRngBitGeneratorOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaSelectAndScatterOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::SymbolicGradientOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicReduceV2Op::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaVariadicSortOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::XlaReplicaIdOp::getOperationName(), context));
  supported_ops->insert(
      OperationName(TF::YieldOp::getOperationName(), context));
}

// These embedding ops are rewritten when running TPUCompileOp.
void AddRewrittenEmbeddingOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
  supported_ops->insert(OperationName(
      TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context));
  supported_ops->insert(OperationName(
      TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
}

// Stack, TensorList and TensorArray ops are rewritten during the second phase
// of the bridge (compilation of TPUCompile op). They would not match any
// legalization/canonicalization pattern and have to be manually added to the
// list of supported ops.
void AddRewrittenCompositeOps(MLIRContext* context,
                              llvm::DenseSet<OperationName>* supported_ops) {
#define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
  llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
      // Stack ops.
      GET_OPERATION_NAME(TF::StackV2Op),
      GET_OPERATION_NAME(TF::StackPushV2Op),
      GET_OPERATION_NAME(TF::StackPopV2Op),
      // Tensor Array ops.
      GET_OPERATION_NAME(TF::TensorArrayV3Op),
      GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
      GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
      GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
      GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
      GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
      GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
      GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
      // Tensor List ops.
      GET_OPERATION_NAME(TF::EmptyTensorListOp),
      GET_OPERATION_NAME(TF::TensorListReserveOp),
      GET_OPERATION_NAME(TF::TensorListFromTensorOp),
      GET_OPERATION_NAME(TF::TensorListPushBackOp),
      GET_OPERATION_NAME(TF::TensorListPopBackOp),
      GET_OPERATION_NAME(TF::TensorListGetItemOp),
      GET_OPERATION_NAME(TF::TensorListSetItemOp),
      GET_OPERATION_NAME(TF::TensorListLengthOp),
      GET_OPERATION_NAME(TF::TensorListElementShapeOp),
      GET_OPERATION_NAME(TF::TensorListGatherOp),
      GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
      GET_OPERATION_NAME(TF::TensorListStackOp),
  };
#undef GET_OPERATION_NAME

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
}

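// Returns true if `type` is a TF string type, or a TF type with subtypes
// (e.g. resource or variant) whose element type is a string.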
bool IsStringType(Type type) {
  if (type.isa<TF::StringType>()) return true;

  auto sub_type = type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
  if (!sub_type) return false;

  bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) {
    return type.getElementType().isa<TF::StringType>();
  });
  return has_string;
}

bool HasStringOperand(Operation& op) {
  for (auto operand : op.getOperands()) {
    auto operand_type = getElementTypeOrSelf(operand);
    if (IsStringType(operand_type)) return true;
  }
  return false;
}

bool HasStringResult(Operation& op) {
  for (auto result : op.getResults()) {
    auto result_type = getElementTypeOrSelf(result);
    if (IsStringType(result_type)) return true;
  }
  return false;
}

bool MatchesPattern(Operation& op,
                    const llvm::DenseSet<OperationName>& supported_ops) {
  return supported_ops.contains(op.getName());
}

// Checks if the op is supported inside of a device cluster. Ops not
// in `tf_dialect` are considered supported.
bool IsSupportedOp(Operation& op,
                   const llvm::DenseSet<OperationName>& supported_ops,
                   const Dialect* tf_dialect) {
  if (op.getDialect() != tf_dialect)
    return true;
  // tf.Assert has a legalization that later removes it, so we never want to
  // outside compile it, for performance reasons.
  if (llvm::isa<TF::AssertOp>(op)) return true;
  return !HasStringOperand(op) && !HasStringResult(op) &&
         (MatchesPattern(op, supported_ops) ||
          mhlo::IsOpAllowedTf2XlaFallback(&op));
}

// Checks all regions of `op` for captured string operands.
bool HasCapturedStringOperand(Operation* op) {
  bool string_operand = false;
  for (auto& region : op->getRegions()) {
    mlir::visitUsedValuesDefinedAbove(
        region, region, [&](mlir::OpOperand* operand) {
          if (getElementTypeOrSelf(operand->get()).isa<TF::StringType>())
            string_operand = true;
        });
    if (string_operand) return string_operand;
  }
  return string_operand;
}

bool IsVariant(Value value) {
  return getElementTypeOrSelf(value.getType()).isa<TF::VariantType>();
}

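// Returns true if any ancestor of `op` is already marked for outside
// compilation.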
bool HasOutsideCompiledAncestor(Operation* op) {
  Operation* parent = op->getParentOp();
  while (parent) {
    if (parent->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      return true;
    parent = parent->getParentOp();
  }
  return false;
}

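// tf.variant values (e.g. TensorLists) cannot be transferred between host and
// device, so producers and consumers of a variant used by an outside compiled
// op must run on the host as well.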
// If any tf.variants are inputs/outputs to another outside compiled
// Operation, `op`, mark them for outside compilation unless they are already
// marked with the outside compilation attribute.
void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) {
  std::queue<Operation*> outside_compiled_ops;
  tpu_cluster.walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      outside_compiled_ops.push(op);
  });

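  // Process the worklist to a fixed point: each newly marked op may itself
  // feed or consume additional variants that then also need marking.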
  while (!outside_compiled_ops.empty()) {
    Operation* host_op = outside_compiled_ops.front();
    outside_compiled_ops.pop();
    host_op->walk([&](Operation* op) {
      // Add any operations that provide variant inputs to the cluster.
      for (auto value : op->getOperands()) {
        Operation* input_defining_op = value.getDefiningOp();
        if (IsVariant(value) && input_defining_op &&
            !HasOutsideCompiledAncestor(input_defining_op) &&
            !input_defining_op->hasAttrOfType<StringAttr>(
                kXlaOutsideCompilationAttr)) {
          input_defining_op->setAttr(
              kXlaOutsideCompilationAttr,
              StringAttr::get(input_defining_op->getContext(), "auto"));
          outside_compiled_ops.push(input_defining_op);
        }
      }
      // Mark for outside compilation any operations that consume variant
      // outputs from an outside compiled operation.
      for (auto value : op->getResults()) {
        if (IsVariant(value)) {
          for (auto user : value.getUsers()) {
            if (!user->hasTrait<OpTrait::IsTerminator>() &&
                !HasOutsideCompiledAncestor(user) &&
                !user->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
              user->setAttr(kXlaOutsideCompilationAttr,
                            StringAttr::get(user->getContext(), "auto"));
              outside_compiled_ops.push(user);
            }
          }
        }
      }
    });
  }
}

// Marks uncompilable ops that are in `tf_dialect` for outside compilation.
LogicalResult MarkUncompilableOps(
    const Dialect* tf_dialect, Block* block,
    llvm::DenseSet<OperationName>& supported_ops) {
  // Automatically marked ops for outside compilation have an
  // `_xla_outside_compilation` attribute value of "auto" plus an increasing
  // counter. Manually marked ops only have an increasing counter for the
  // attribute value, so the `_xla_outside_compilation` values of automatically
  // and manually marked ops never collide.
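  // For example, the first two automatically marked ops in a cluster receive
  // _xla_outside_compilation = "auto0" and "auto1".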
  int outside_compiled_cluster_counter = 0;
  block->walk([&](Operation* op) {
    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
              << " isn't compilable, adding outside_compilation attr. "
                 "This op will automatically be placed on CPU.";
      op->setAttr(kXlaOutsideCompilationAttr,
                  StringAttr::get(
                      op->getContext(),
                      llvm::formatv("auto{0}", outside_compiled_cluster_counter)
                          .str()));
      outside_compiled_cluster_counter++;
    }
  });
  if (outside_compiled_cluster_counter > 0) {
    auto_outside_compilation_gauge->GetCell()->Set(true);
  }
  return success();
}

// Check for uncompilable ops that are in `tf_dialect` and are not already
// marked for outside compilation.
bool ContainsUncompilableOps(const Dialect* tf_dialect, Block* block,
                             llvm::DenseSet<OperationName>& supported_ops) {
  int uncompilable_op_count = 0;
  // Check if op or any parent is already marked for outside compilation.
  block->walk([&](Operation* op) {
    Operation* iter_op = op;
    while (iter_op && !llvm::isa<tf_device::ClusterOp>(iter_op)) {
      if (iter_op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        return;
      }
      iter_op = iter_op->getParentOp();
    }

    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      op->emitOpError() << "isn't compilable for TPU device. Enable the "
                           "soft_device_placement option to run on CPU";
      ++uncompilable_op_count;
    }
  });
  return uncompilable_op_count > 0;
}

// Unmarks outside compilation for any op whose parents are already marked for
// outside compilation, since the child will be extracted anyway.
void UnmarkChildren(Block* block) {
  block->walk([&](Operation* op) {
    if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
    Operation* iter_op = op;
    bool remove_attr = false;
    while (auto* parent_op = iter_op->getParentOp()) {
      if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
        remove_attr = true;
        break;
      }
      iter_op = parent_op;
    }
    if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
  });
}

void MarkOpsForOutsideCompilation::runOnOperation() {
  auto module = getOperation();
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }
  OwningRewritePatternList patterns(&getContext());
  mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
  TF::PopulateTFLoweringBeforeHLOPatterns(module.getContext(), &patterns);
  TF::PopulateLoweringQuantizedPatterns(module.getContext(), &patterns);
  AddCanonicalizationPatterns(module.getContext(), &patterns);

  // `supported_ops` contains the names of all ops that can potentially be
  // lowered into HLO on the device. This does not guarantee that an op will
  // be lowered by a later pass, but if an op is not in this set, it cannot be
  // lowered in a subsequent pass.
  llvm::DenseSet<OperationName> supported_ops;
  PatternApplicator(std::move(patterns))
      .walkAllPatterns([&](const Pattern& pattern) {
        Optional<OperationName> root_kind = pattern.getRootKind();
        if (root_kind.hasValue()) supported_ops.insert(root_kind.getValue());
      });
  AddSupportedFunctionalOps(module.getContext(), &supported_ops);
  AddSupportedOpsUsingFolding(module.getContext(), &supported_ops);
  AddSupportedOpsUsingDynamicPadder(module.getContext(), &supported_ops);
  AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
  AddRewrittenCompositeOps(module.getContext(), &supported_ops);

  auto result = module.walk([&](tf_device::ClusterOp cluster) {
    // Only if `allow_soft_placement` attribute is true should we mark ops
    // for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (soft_placement_attr && soft_placement_attr.getValue()) {
      if (failed(MarkUncompilableOps(tf_dialect, &cluster.GetBody(),
                                     supported_ops)))
        return WalkResult::interrupt();
    } else {
      if (ContainsUncompilableOps(tf_dialect, &cluster.GetBody(),
                                  supported_ops))
        return WalkResult::interrupt();
    }
    MarkVariantInputsOutputs(cluster);

    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  module.walk([&](tf_device::ClusterOp cluster) {
    // Only if `allow_soft_placement` attribute is true should we unmark ops
    // for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (!(soft_placement_attr && soft_placement_attr.getValue())) {
      return;
    }
    UnmarkChildren(&cluster.GetBody());
  });
}

}  // namespace

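// Typical usage (a sketch; the pass is normally added by bridge pipelines):
//   pm.addPass(mlir::TFDevice::CreateMarkOpsForOutsideCompilationPass());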
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass() {
  return std::make_unique<MarkOpsForOutsideCompilation>();
}

}  // namespace TFDevice
}  // namespace mlir