• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <memory>
17 #include <queue>
18 #include <string>
19 #include <utility>
20 
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/OperationSupport.h"  // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
24 #include "mlir/Pass/Pass.h"  // from @llvm-project
25 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
26 #include "mlir/Rewrite/PatternApplicator.h"  // from @llvm-project
27 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
35 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
36 #include "tensorflow/core/lib/monitoring/gauge.h"
37 
38 namespace mlir {
39 namespace TFDevice {
40 
41 namespace {
42 
// Attribute used to mark an op for outside compilation (extraction to run on
// the host instead of the device).
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
// Cluster-level attribute; only clusters with this set to true have their
// unsupported ops auto-marked for outside compilation.
constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement";

// Process-wide metric recording whether this pass automatically marked any op
// for outside compilation (set in MarkUncompilableOps).
auto* auto_outside_compilation_gauge =
    tensorflow::monitoring::Gauge<bool, 0>::New(
        "/tensorflow/core/use_auto_outside_compilation",
        "Tracks if auto outside compilation is enabled");
50 
// Module pass that walks device clusters and annotates ops that cannot be
// compiled for the device with `_xla_outside_compilation` so that a later
// pass can extract them to run on the host.
struct MarkOpsForOutsideCompilation
    : public TF::MarkOpsForOutsideCompilationPassBase<
          MarkOpsForOutsideCompilation> {
  void runOnOperation() override;
};
56 
57 // Adds any canonicalization patterns to list of supported `patterns`.
58 // TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and
59 // remove this.
AddCanonicalizationPatterns(MLIRContext * context,OwningRewritePatternList * patterns)60 void AddCanonicalizationPatterns(MLIRContext* context,
61                                  OwningRewritePatternList* patterns) {
62   for (auto* op : context->getRegisteredOperations())
63     op->getCanonicalizationPatterns(*patterns, context);
64 }
65 
66 // Adds the list of ops that are supported on TPU through constant folding which
67 // may depend on the inputs shapes not known at this point. Such ops may not
68 // have any legalization or canonicalization patterns but shouldn't be marked
69 // for outside compilation.
70 //
71 // TODO(b/177523289): Remove manual handling once we support constant folding
72 // and shape inference through the computation on the host side.
AddSupportedOpsUsingFolding(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)73 void AddSupportedOpsUsingFolding(MLIRContext* context,
74                                  llvm::DenseSet<OperationName>* supported_ops) {
75   llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
76       OperationName(TF::BroadcastArgsOp::getOperationName(), context),
77       OperationName(TF::BroadcastGradientArgsOp::getOperationName(), context),
78       OperationName(TF::ConcatOffsetOp::getOperationName(), context),
79       OperationName(TF::EmptyOp::getOperationName(), context),
80       OperationName(TF::ListDiffOp::getOperationName(), context),
81       OperationName(TF::RangeOp::getOperationName(), context),
82   };
83 
84   supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
85 }
86 
87 // TODO(b/159128666): Check the control flow legalization passes instead once
88 // added.
AddSupportedControlFlowOps(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)89 void AddSupportedControlFlowOps(MLIRContext* context,
90                                 llvm::DenseSet<OperationName>* supported_ops) {
91   supported_ops->insert(
92       OperationName(TF::IfRegionOp::getOperationName(), context));
93   supported_ops->insert(
94       OperationName(TF::WhileRegionOp::getOperationName(), context));
95   supported_ops->insert(
96       OperationName(TF::YieldOp::getOperationName(), context));
97 }
98 
99 // These embedding ops are rewritten when running TPUCompileOp.
AddRewrittenEmbeddingOps(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)100 void AddRewrittenEmbeddingOps(MLIRContext* context,
101                               llvm::DenseSet<OperationName>* supported_ops) {
102   supported_ops->insert(OperationName(
103       TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context));
104   supported_ops->insert(OperationName(
105       TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
106 }
107 
108 // Stack, TensorList and TensorArray ops are rewritten during the second phase
109 // of the bridge (compilation of TPUCompile op). They would not match any
110 // legalization/canonicalization pattern and have to be manually added to the
111 // list of supported ops.
AddRewrittenCompositeOps(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)112 void AddRewrittenCompositeOps(MLIRContext* context,
113                               llvm::DenseSet<OperationName>* supported_ops) {
114 #define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
115   llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
116       // Stack ops.
117       GET_OPERATION_NAME(TF::StackV2Op),
118       GET_OPERATION_NAME(TF::StackPushV2Op),
119       GET_OPERATION_NAME(TF::StackPopV2Op),
120       // Tensor Array ops.
121       GET_OPERATION_NAME(TF::TensorArrayV3Op),
122       GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
123       GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
124       GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
125       GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
126       GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
127       GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
128       GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
129       GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
130       // Tensor List Ops.
131       GET_OPERATION_NAME(TF::EmptyTensorListOp),
132       GET_OPERATION_NAME(TF::TensorListReserveOp),
133       GET_OPERATION_NAME(TF::TensorListFromTensorOp),
134       GET_OPERATION_NAME(TF::TensorListPushBackOp),
135       GET_OPERATION_NAME(TF::TensorListPopBackOp),
136       GET_OPERATION_NAME(TF::TensorListGetItemOp),
137       GET_OPERATION_NAME(TF::TensorListSetItemOp),
138       GET_OPERATION_NAME(TF::TensorListLengthOp),
139       GET_OPERATION_NAME(TF::TensorListElementShapeOp),
140       GET_OPERATION_NAME(TF::TensorListGatherOp),
141       GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
142       GET_OPERATION_NAME(TF::TensorListStackOp),
143   };
144 #undef GET_OPERATION_NAME
145 
146   supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
147 }
148 
IsStringType(Type type)149 bool IsStringType(Type type) {
150   if (type.isa<TF::StringType>()) return true;
151 
152   auto sub_type = type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
153   if (!sub_type) return false;
154 
155   bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) {
156     return type.getElementType().isa<TF::StringType>();
157   });
158   return has_string;
159 }
160 
HasStringOperand(Operation & op)161 bool HasStringOperand(Operation& op) {
162   for (auto operand : op.getOperands()) {
163     auto operand_type = getElementTypeOrSelf(operand);
164     if (IsStringType(operand_type)) return true;
165   }
166   return false;
167 }
168 
HasStringResult(Operation & op)169 bool HasStringResult(Operation& op) {
170   for (auto result : op.getResults()) {
171     auto result_type = getElementTypeOrSelf(result);
172     if (IsStringType(result_type)) return true;
173   }
174   return false;
175 }
176 
MatchesPattern(Operation & op,const llvm::DenseSet<OperationName> & supported_ops)177 bool MatchesPattern(Operation& op,
178                     const llvm::DenseSet<OperationName>& supported_ops) {
179   return (supported_ops.contains(op.getName()));
180 }
181 
182 // Checks if the op is supported inside of a device cluster.  Ops not
183 // in `tf_dialect` are considered supported.
IsSupportedOp(Operation & op,const llvm::DenseSet<OperationName> & supported_ops,const Dialect * tf_dialect)184 bool IsSupportedOp(Operation& op,
185                    const llvm::DenseSet<OperationName>& supported_ops,
186                    const Dialect* tf_dialect) {
187   if (op.getDialect() != tf_dialect)
188     return true;
189   // Assert has a legalization that later removes it so we don't want to outside
190   // compile it ever for performance reasons.
191   if (llvm::isa<TF::AssertOp>(op)) return true;
192   return !HasStringOperand(op) && !HasStringResult(op) &&
193          (MatchesPattern(op, supported_ops) ||
194           mhlo::IsOpAllowedTf2XlaFallback(&op));
195 }
196 
197 // Checks all regions of `op` for captured string operands.
HasCapturedStringOperand(Operation * op)198 bool HasCapturedStringOperand(Operation* op) {
199   bool string_operand = false;
200   for (auto& region : op->getRegions()) {
201     mlir::visitUsedValuesDefinedAbove(
202         region, region, [&](mlir::OpOperand* operand) {
203           if (getElementTypeOrSelf(operand->get()).isa<TF::StringType>())
204             string_operand = true;
205         });
206     if (string_operand) return string_operand;
207   }
208   return string_operand;
209 }
210 
IsVariant(Value value)211 bool IsVariant(Value value) {
212   return getElementTypeOrSelf(value.getType()).isa<TF::VariantType>();
213 }
214 
HasOutsideCompiledAncestor(Operation * op)215 bool HasOutsideCompiledAncestor(Operation* op) {
216   Operation* parent = op->getParentOp();
217   while (parent) {
218     if (parent->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
219       return true;
220     parent = parent->getParentOp();
221   }
222   return false;
223 }
224 
// If any tf.variants are inputs/outputs to another outside compiled
// operation, `op`, mark them for outside compilation unless they are already
// marked with the outside compilation attribute.
void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) {
  // Worklist of ops already marked for outside compilation whose variant
  // inputs/outputs still need to be examined.
  std::queue<Operation*> outside_compiled_ops;
  tpu_cluster.walk([&](Operation* op) {
    if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
      outside_compiled_ops.push(op);
  });

  // Propagate the marker transitively: newly marked producers/consumers are
  // pushed back onto the worklist so their own variant edges get processed.
  while (!outside_compiled_ops.empty()) {
    Operation* host_op = outside_compiled_ops.front();
    outside_compiled_ops.pop();
    host_op->walk([&](Operation* op) {
      // Add any operations that provide variant inputs to the cluster.
      for (auto value : op->getOperands()) {
        Operation* input_defining_op = value.getDefiningOp();
        // Skip block arguments (no defining op), values already under an
        // outside compiled ancestor, and already-marked ops.
        if (IsVariant(value) && input_defining_op &&
            !HasOutsideCompiledAncestor(input_defining_op) &&
            !input_defining_op->hasAttrOfType<StringAttr>(
                kXlaOutsideCompilationAttr)) {
          input_defining_op->setAttr(
              kXlaOutsideCompilationAttr,
              StringAttr::get(input_defining_op->getContext(), "auto"));
          outside_compiled_ops.push(input_defining_op);
        }
      }
      // Mark for outside compilation any operations that consume variant
      // outputs from an outside compiled operation.
      for (auto value : op->getResults()) {
        if (IsVariant(value)) {
          for (auto user : value.getUsers()) {
            // Terminators forward values rather than compute on them, so they
            // are never marked.
            if (!user->hasTrait<OpTrait::IsTerminator>() &&
                !HasOutsideCompiledAncestor(user) &&
                !user->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
              user->setAttr(kXlaOutsideCompilationAttr,
                            StringAttr::get(user->getContext(), "auto"));
              outside_compiled_ops.push(user);
            }
          }
        }
      }
    });
  }
}
270 
271 // Marks uncompilable ops that are in `tf_dialect` for outside compilation.
LogicalResult MarkUncompilableOps(
    const Dialect* tf_dialect, Block* block,
    llvm::DenseSet<OperationName>& supported_ops) {
  // Automatically marked ops for outside compilation have
  // `_xla_outside_compilation` attribute value of "auto" plus
  // an increasing counter.  Manually marked ops for outside compilation only
  // have an increasing counter for the attribute value.  Therefore there is
  // no collision in the `_xla_outside_compilation` attribute between
  // automatically and manually marked ops.
  int outside_compiled_cluster_counter = 0;
  block->walk([&](Operation* op) {
    if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
      VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
              << " isn't compilable, adding outside_compilation attr. "
                 "This op will automatically be placed on CPU.";
      op->setAttr(kXlaOutsideCompilationAttr,
                  StringAttr::get(
                      op->getContext(),
                      llvm::formatv("auto{0}", outside_compiled_cluster_counter)
                          .str()));
      outside_compiled_cluster_counter++;
    }
  });
  // Record in the metric gauge that at least one op was auto-placed on host.
  if (outside_compiled_cluster_counter > 0) {
    auto_outside_compilation_gauge->GetCell()->Set(true);
  }
  return success();
}
301 
302 // Unmarks outside compilation for any op that has parents already
303 // marked for outside compilation since the child will be extracted
304 // anyways.
UnmarkChildren(Block * block)305 void UnmarkChildren(Block* block) {
306   block->walk([&](Operation* op) {
307     if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
308     Operation* iter_op = op;
309     bool remove_attr = false;
310     while (auto* parent_op = iter_op->getParentOp()) {
311       if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
312         remove_attr = true;
313         break;
314       }
315       iter_op = parent_op;
316     }
317     if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
318   });
319 }
320 
void MarkOpsForOutsideCompilation::runOnOperation() {
  auto module = getOperation();
  // The pass needs the TF dialect to distinguish TF ops from others; fail
  // cleanly if it is not loaded.
  const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
  if (!tf_dialect) {
    getOperation().emitError() << "'tf' dialect is not registered";
    return signalPassFailure();
  }
  // Gather every pattern that could lower a TF op on the device: TF->HLO
  // legalizations, TF-internal lowerings, and canonicalizations.
  OwningRewritePatternList patterns;
  mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
  TF::PopulateLoweringTFPatterns(module.getContext(), &patterns);
  AddCanonicalizationPatterns(module.getContext(), &patterns);

  // `supported_ops` contains the name of all of the ops that can potentially be
  // lowered into HLO on the device. This doesn't always mean that the op can
  // be lowered in the future passes but if the op is not in this set, it can't
  // be lowered in a subsequent pass.
  llvm::DenseSet<OperationName> supported_ops;
  PatternApplicator(std::move(patterns))
      .walkAllPatterns([&](const Pattern& pattern) {
        // Patterns without a specific root kind (match-any) contribute nothing
        // to the name set.
        Optional<OperationName> root_kind = pattern.getRootKind();
        if (root_kind.hasValue()) supported_ops.insert(root_kind.getValue());
      });
  // Ops handled by dedicated passes/rewrites rather than patterns.
  AddSupportedControlFlowOps(module.getContext(), &supported_ops);
  AddSupportedOpsUsingFolding(module.getContext(), &supported_ops);
  AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
  AddRewrittenCompositeOps(module.getContext(), &supported_ops);

  auto result = module.walk([&](tf_device::ClusterOp cluster) {
    // Only if `allow_soft_placement` attribute is true should we mark ops
    // for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if ((soft_placement_attr && soft_placement_attr.getValue())) {
      if (failed(MarkUncompilableOps(tf_dialect, &cluster.GetBody(),
                                     supported_ops)))
        return WalkResult::interrupt();
    }
    // Variant propagation runs regardless of soft placement: it only extends
    // markers that already exist on ops in the cluster.
    MarkVariantInputsOutputs(cluster);

    return WalkResult::advance();
  });

  if (result.wasInterrupted()) return signalPassFailure();

  // Second sweep: drop markers on ops nested under an already-marked parent,
  // since the parent is extracted as a whole.
  module.walk([&](tf_device::ClusterOp cluster) {
    // Only if `allow_soft_placement` attribute is true should we unmark ops
    // for outside compilation.
    auto soft_placement_attr =
        cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
    if (!(soft_placement_attr && soft_placement_attr.getValue())) {
      return;
    }
    UnmarkChildren(&cluster.GetBody());
  });
}
376 
377 }  // namespace
378 
// Factory used by pass pipelines to instantiate this pass.
std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass() {
  return std::make_unique<MarkOpsForOutsideCompilation>();
}
383 
384 }  // namespace TFDevice
385 }  // namespace mlir
386