1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17 #include <queue>
18 #include <string>
19 #include <utility>
20
21 #include "llvm/Support/FormatVariadic.h"
22 #include "mlir/IR/OperationSupport.h" // from @llvm-project
23 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
24 #include "mlir/Pass/Pass.h" // from @llvm-project
25 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
26 #include "mlir/Rewrite/PatternApplicator.h" // from @llvm-project
27 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
28 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
34 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
35 #include "tensorflow/compiler/mlir/xla/transforms/passes.h"
36 #include "tensorflow/core/lib/monitoring/gauge.h"
37
38 namespace mlir {
39 namespace TFDevice {
40
41 namespace {
42
// Attribute used to mark an op for extraction to the host ("outside
// compilation"). Set with value "auto<N>" by this pass when marking
// automatically (see MarkUncompilableOps below).
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
// Boolean attribute on `tf_device.cluster`; marking/unmarking in this pass
// only happens when it is present and true.
constexpr char kAllowSoftPlacementAttr[] = "allow_soft_placement";

// Metrics gauge set to true when this pass automatically marks at least one
// op for outside compilation.
auto* auto_outside_compilation_gauge =
    tensorflow::monitoring::Gauge<bool, 0>::New(
        "/tensorflow/core/use_auto_outside_compilation",
        "Tracks if auto outside compilation is enabled");
50
// Module pass that marks ops inside `tf_device.cluster` ops which cannot be
// lowered on-device with the `_xla_outside_compilation` attribute, so a later
// pass can extract them to run on the host.
struct MarkOpsForOutsideCompilation
    : public TF::MarkOpsForOutsideCompilationPassBase<
          MarkOpsForOutsideCompilation> {
  void runOnOperation() override;
};
56
57 // Adds any canonicalization patterns to list of supported `patterns`.
58 // TODO(b/161726307): Move or import the relevant patterns to LowerTF pass and
59 // remove this.
AddCanonicalizationPatterns(MLIRContext * context,OwningRewritePatternList * patterns)60 void AddCanonicalizationPatterns(MLIRContext* context,
61 OwningRewritePatternList* patterns) {
62 for (auto* op : context->getRegisteredOperations())
63 op->getCanonicalizationPatterns(*patterns, context);
64 }
65
66 // Adds the list of ops that are supported on TPU through constant folding which
67 // may depend on the inputs shapes not known at this point. Such ops may not
68 // have any legalization or canonicalization patterns but shouldn't be marked
69 // for outside compilation.
70 //
71 // TODO(b/177523289): Remove manual handling once we support constant folding
72 // and shape inference through the computation on the host side.
AddSupportedOpsUsingFolding(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)73 void AddSupportedOpsUsingFolding(MLIRContext* context,
74 llvm::DenseSet<OperationName>* supported_ops) {
75 llvm::SmallDenseSet<OperationName, 8> allowlist_ops = {
76 OperationName(TF::BroadcastArgsOp::getOperationName(), context),
77 OperationName(TF::BroadcastGradientArgsOp::getOperationName(), context),
78 OperationName(TF::ConcatOffsetOp::getOperationName(), context),
79 OperationName(TF::EmptyOp::getOperationName(), context),
80 OperationName(TF::ListDiffOp::getOperationName(), context),
81 OperationName(TF::RangeOp::getOperationName(), context),
82 };
83
84 supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
85 }
86
87 // TODO(b/159128666): Check the control flow legalization passes instead once
88 // added.
AddSupportedControlFlowOps(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)89 void AddSupportedControlFlowOps(MLIRContext* context,
90 llvm::DenseSet<OperationName>* supported_ops) {
91 supported_ops->insert(
92 OperationName(TF::IfRegionOp::getOperationName(), context));
93 supported_ops->insert(
94 OperationName(TF::WhileRegionOp::getOperationName(), context));
95 supported_ops->insert(
96 OperationName(TF::YieldOp::getOperationName(), context));
97 }
98
99 // These embedding ops are rewritten when running TPUCompileOp.
AddRewrittenEmbeddingOps(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)100 void AddRewrittenEmbeddingOps(MLIRContext* context,
101 llvm::DenseSet<OperationName>* supported_ops) {
102 supported_ops->insert(OperationName(
103 TF::RecvTPUEmbeddingActivationsOp::getOperationName(), context));
104 supported_ops->insert(OperationName(
105 TF::SendTPUEmbeddingGradientsOp::getOperationName(), context));
106 }
107
108 // Stack, TensorList and TensorArray ops are rewritten during the second phase
109 // of the bridge (compilation of TPUCompile op). They would not match any
110 // legalization/canonicalization pattern and have to be manually added to the
111 // list of supported ops.
AddRewrittenCompositeOps(MLIRContext * context,llvm::DenseSet<OperationName> * supported_ops)112 void AddRewrittenCompositeOps(MLIRContext* context,
113 llvm::DenseSet<OperationName>* supported_ops) {
114 #define GET_OPERATION_NAME(op) OperationName(op::getOperationName(), context)
115 llvm::SmallDenseSet<OperationName, 32> allowlist_ops = {
116 // Stack ops.
117 GET_OPERATION_NAME(TF::StackV2Op),
118 GET_OPERATION_NAME(TF::StackPushV2Op),
119 GET_OPERATION_NAME(TF::StackPopV2Op),
120 // Tensor Array ops.
121 GET_OPERATION_NAME(TF::TensorArrayV3Op),
122 GET_OPERATION_NAME(TF::TensorArrayReadV3Op),
123 GET_OPERATION_NAME(TF::TensorArrayWriteV3Op),
124 GET_OPERATION_NAME(TF::TensorArrayConcatV3Op),
125 GET_OPERATION_NAME(TF::TensorArraySplitV3Op),
126 GET_OPERATION_NAME(TF::TensorArraySizeV3Op),
127 GET_OPERATION_NAME(TF::TensorArrayGradV3Op),
128 GET_OPERATION_NAME(TF::TensorArrayGatherV3Op),
129 GET_OPERATION_NAME(TF::TensorArrayScatterV3Op),
130 // Tensor List Ops.
131 GET_OPERATION_NAME(TF::EmptyTensorListOp),
132 GET_OPERATION_NAME(TF::TensorListReserveOp),
133 GET_OPERATION_NAME(TF::TensorListFromTensorOp),
134 GET_OPERATION_NAME(TF::TensorListPushBackOp),
135 GET_OPERATION_NAME(TF::TensorListPopBackOp),
136 GET_OPERATION_NAME(TF::TensorListGetItemOp),
137 GET_OPERATION_NAME(TF::TensorListSetItemOp),
138 GET_OPERATION_NAME(TF::TensorListLengthOp),
139 GET_OPERATION_NAME(TF::TensorListElementShapeOp),
140 GET_OPERATION_NAME(TF::TensorListGatherOp),
141 GET_OPERATION_NAME(TF::TensorListScatterIntoExistingListOp),
142 GET_OPERATION_NAME(TF::TensorListStackOp),
143 };
144 #undef GET_OPERATION_NAME
145
146 supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
147 }
148
IsStringType(Type type)149 bool IsStringType(Type type) {
150 if (type.isa<TF::StringType>()) return true;
151
152 auto sub_type = type.dyn_cast<TF::TensorFlowTypeWithSubtype>();
153 if (!sub_type) return false;
154
155 bool has_string = llvm::any_of(sub_type.GetSubtypes(), [](TensorType type) {
156 return type.getElementType().isa<TF::StringType>();
157 });
158 return has_string;
159 }
160
HasStringOperand(Operation & op)161 bool HasStringOperand(Operation& op) {
162 for (auto operand : op.getOperands()) {
163 auto operand_type = getElementTypeOrSelf(operand);
164 if (IsStringType(operand_type)) return true;
165 }
166 return false;
167 }
168
HasStringResult(Operation & op)169 bool HasStringResult(Operation& op) {
170 for (auto result : op.getResults()) {
171 auto result_type = getElementTypeOrSelf(result);
172 if (IsStringType(result_type)) return true;
173 }
174 return false;
175 }
176
MatchesPattern(Operation & op,const llvm::DenseSet<OperationName> & supported_ops)177 bool MatchesPattern(Operation& op,
178 const llvm::DenseSet<OperationName>& supported_ops) {
179 return (supported_ops.contains(op.getName()));
180 }
181
182 // Checks if the op is supported inside of a device cluster. Ops not
183 // in `tf_dialect` are considered supported.
IsSupportedOp(Operation & op,const llvm::DenseSet<OperationName> & supported_ops,const Dialect * tf_dialect)184 bool IsSupportedOp(Operation& op,
185 const llvm::DenseSet<OperationName>& supported_ops,
186 const Dialect* tf_dialect) {
187 if (op.getDialect() != tf_dialect)
188 return true;
189 // Assert has a legalization that later removes it so we don't want to outside
190 // compile it ever for performance reasons.
191 if (llvm::isa<TF::AssertOp>(op)) return true;
192 return !HasStringOperand(op) && !HasStringResult(op) &&
193 (MatchesPattern(op, supported_ops) ||
194 mhlo::IsOpAllowedTf2XlaFallback(&op));
195 }
196
197 // Checks all regions of `op` for captured string operands.
HasCapturedStringOperand(Operation * op)198 bool HasCapturedStringOperand(Operation* op) {
199 bool string_operand = false;
200 for (auto& region : op->getRegions()) {
201 mlir::visitUsedValuesDefinedAbove(
202 region, region, [&](mlir::OpOperand* operand) {
203 if (getElementTypeOrSelf(operand->get()).isa<TF::StringType>())
204 string_operand = true;
205 });
206 if (string_operand) return string_operand;
207 }
208 return string_operand;
209 }
210
IsVariant(Value value)211 bool IsVariant(Value value) {
212 return getElementTypeOrSelf(value.getType()).isa<TF::VariantType>();
213 }
214
HasOutsideCompiledAncestor(Operation * op)215 bool HasOutsideCompiledAncestor(Operation* op) {
216 Operation* parent = op->getParentOp();
217 while (parent) {
218 if (parent->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
219 return true;
220 parent = parent->getParentOp();
221 }
222 return false;
223 }
224
225 // If any tf.variants are inputs/outputs to the another outside compiled
226 // Operation, `op`, mark them for outside compilation unless they are already
227 // marks with outside compilation attribute.
MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster)228 void MarkVariantInputsOutputs(tf_device::ClusterOp tpu_cluster) {
229 std::queue<Operation*> outside_compiled_ops;
230 tpu_cluster.walk([&](Operation* op) {
231 if (op->hasAttrOfType<StringAttr>(kXlaOutsideCompilationAttr))
232 outside_compiled_ops.push(op);
233 });
234
235 while (!outside_compiled_ops.empty()) {
236 Operation* host_op = outside_compiled_ops.front();
237 outside_compiled_ops.pop();
238 host_op->walk([&](Operation* op) {
239 // Add any operations that provide variant inputs to the cluster.
240 for (auto value : op->getOperands()) {
241 Operation* input_defining_op = value.getDefiningOp();
242 if (IsVariant(value) && input_defining_op &&
243 !HasOutsideCompiledAncestor(input_defining_op) &&
244 !input_defining_op->hasAttrOfType<StringAttr>(
245 kXlaOutsideCompilationAttr)) {
246 input_defining_op->setAttr(
247 kXlaOutsideCompilationAttr,
248 StringAttr::get(input_defining_op->getContext(), "auto"));
249 outside_compiled_ops.push(input_defining_op);
250 }
251 }
252 // Mark for outside compilation any operations that consume variant
253 // outputs from an outside compiled operation.
254 for (auto value : op->getResults()) {
255 if (IsVariant(value)) {
256 for (auto user : value.getUsers()) {
257 if (!user->hasTrait<OpTrait::IsTerminator>() &&
258 !HasOutsideCompiledAncestor(user) &&
259 !user->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
260 user->setAttr(kXlaOutsideCompilationAttr,
261 StringAttr::get(user->getContext(), "auto"));
262 outside_compiled_ops.push(user);
263 }
264 }
265 }
266 }
267 });
268 }
269 }
270
271 // Marks uncompilable ops that are in `tf_dialect` for outside compilation.
MarkUncompilableOps(const Dialect * tf_dialect,Block * block,llvm::DenseSet<OperationName> & supported_ops)272 LogicalResult MarkUncompilableOps(
273 const Dialect* tf_dialect, Block* block,
274 llvm::DenseSet<OperationName>& supported_ops) {
275 // Automatically marked ops for outside compilation have
276 // `_xla_outside_compilation` attribute value of "auto" plus
277 // an increasing counter. Manually marked ops for outside compilation only
278 // have an increasing counteri for the attribute value. Therefore there is no
279 // collision in
280 // `_xla_outside_compilation` attribute between automatically and manually
281 // marking ops.
282 int outside_compiled_cluster_counter = 0;
283 block->walk([&](Operation* op) {
284 if (!IsSupportedOp(*op, supported_ops, tf_dialect)) {
285 VLOG(3) << "Cloud TPU: Op " << op->getName().getStringRef().str()
286 << " isn't compilable, adding outside_compilation attr. "
287 "This op will automatically be placed on CPU.";
288 op->setAttr(kXlaOutsideCompilationAttr,
289 StringAttr::get(
290 op->getContext(),
291 llvm::formatv("auto{0}", outside_compiled_cluster_counter)
292 .str()));
293 outside_compiled_cluster_counter++;
294 }
295 });
296 if (outside_compiled_cluster_counter > 0) {
297 auto_outside_compilation_gauge->GetCell()->Set(true);
298 }
299 return success();
300 }
301
302 // Unmarks outside compilation for any op that has parents already
303 // marked for outside compilation since the child will be extracted
304 // anyways.
UnmarkChildren(Block * block)305 void UnmarkChildren(Block* block) {
306 block->walk([&](Operation* op) {
307 if (!op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) return;
308 Operation* iter_op = op;
309 bool remove_attr = false;
310 while (auto* parent_op = iter_op->getParentOp()) {
311 if (parent_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr)) {
312 remove_attr = true;
313 break;
314 }
315 iter_op = parent_op;
316 }
317 if (remove_attr) op->removeAttr(kXlaOutsideCompilationAttr);
318 });
319 }
320
runOnOperation()321 void MarkOpsForOutsideCompilation::runOnOperation() {
322 auto module = getOperation();
323 const Dialect* tf_dialect = getContext().getLoadedDialect("tf");
324 if (!tf_dialect) {
325 getOperation().emitError() << "'tf' dialect is not registered";
326 return signalPassFailure();
327 }
328 OwningRewritePatternList patterns;
329 mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns);
330 TF::PopulateLoweringTFPatterns(module.getContext(), &patterns);
331 AddCanonicalizationPatterns(module.getContext(), &patterns);
332
333 // `supported_ops` contains the name of all of the ops that can potentially be
334 // lowered into HLO on the device. This doesn't always mean that the op can
335 // be lowered in the future passes but if the op is not in this set, it can't
336 // be lowered in a subsequent pass.
337 llvm::DenseSet<OperationName> supported_ops;
338 PatternApplicator(std::move(patterns))
339 .walkAllPatterns([&](const Pattern& pattern) {
340 Optional<OperationName> root_kind = pattern.getRootKind();
341 if (root_kind.hasValue()) supported_ops.insert(root_kind.getValue());
342 });
343 AddSupportedControlFlowOps(module.getContext(), &supported_ops);
344 AddSupportedOpsUsingFolding(module.getContext(), &supported_ops);
345 AddRewrittenEmbeddingOps(module.getContext(), &supported_ops);
346 AddRewrittenCompositeOps(module.getContext(), &supported_ops);
347
348 auto result = module.walk([&](tf_device::ClusterOp cluster) {
349 // Only if `allow_soft_placement` attribute is true should we mark ops
350 // for outside compilation.
351 auto soft_placement_attr =
352 cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
353 if ((soft_placement_attr && soft_placement_attr.getValue())) {
354 if (failed(MarkUncompilableOps(tf_dialect, &cluster.GetBody(),
355 supported_ops)))
356 return WalkResult::interrupt();
357 }
358 MarkVariantInputsOutputs(cluster);
359
360 return WalkResult::advance();
361 });
362
363 if (result.wasInterrupted()) return signalPassFailure();
364
365 module.walk([&](tf_device::ClusterOp cluster) {
366 // Only if `allow_soft_placement` attribute is true should we unmark ops
367 // for outside compilation.
368 auto soft_placement_attr =
369 cluster->getAttrOfType<BoolAttr>(kAllowSoftPlacementAttr);
370 if (!(soft_placement_attr && soft_placement_attr.getValue())) {
371 return;
372 }
373 UnmarkChildren(&cluster.GetBody());
374 });
375 }
376
377 } // namespace
378
379 std::unique_ptr<OperationPass<ModuleOp>>
CreateMarkOpsForOutsideCompilationPass()380 CreateMarkOpsForOutsideCompilationPass() {
381 return std::make_unique<MarkOpsForOutsideCompilation>();
382 }
383
384 } // namespace TFDevice
385 } // namespace mlir
386