/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";
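
// A pass that sets the mode operand of TPU embedding enqueue ops based on
// whether the graph contains a SendTPUEmbeddingGradients op. Illustrative
// sketch of the rewrite (operands abbreviated, not verbatim IR):
//
//   // Before:
//   "tf.EnqueueTPUEmbeddingSparseTensorBatch"(..., %old_mode)
//       {_tpu_embedding_layer = "v0"}
//
//   // After (a matching SendTPUEmbeddingGradients op exists):
//   %mode = "tf.Const"() {value = dense<"train">} : () -> tensor<!tf.string>
//   "tf.EnqueueTPUEmbeddingSparseTensorBatch"(..., %mode)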
struct TPUUpdateEmbeddingEnqueueOpInputsPass
    : public TF::TPUUpdateEmbeddingEnqueueOpInputsPassBase<
          TPUUpdateEmbeddingEnqueueOpInputsPass> {
  void runOnFunction() override;
};

// Extracts the `_tpu_embedding_layer` attribute from TPU embedding ops and
// clears the attribute from the operation. This ensures that future
// optimization passes do not trigger additional logic due to the presence of
// this attribute.
LogicalResult ExtractEmbeddingAttribute(
    Operation* op, llvm::StringMap<Operation*>* embedding_op_map) {
  auto embedding_attr = op->getAttrOfType<StringAttr>(kTPUEmbeddingAttr);
  if (!embedding_attr) return mlir::success();

  if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second)
    return op->emitOpError(
        "found duplicate TPU embedding ops potentially from multiple "
        "TPUEmbedding layers");

  op->removeAttr(kTPUEmbeddingAttr);
  return success();
}

LogicalResult FindTPUEmbeddingOps(
    FuncOp func_op, llvm::StringMap<Operation*>* enqueue_op_map,
    llvm::StringMap<Operation*>* recv_activation_op_map,
    llvm::StringMap<Operation*>* send_gradient_op_map) {
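  // Walk all ops in the function; interrupt the walk (and fail) as soon as
  // extracting the embedding attribute fails for any embedding-related op.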
  auto walk_result = func_op.walk([&](Operation* op) {
    if (llvm::isa<TF::RecvTPUEmbeddingActivationsOp>(op))
      if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map)))
        return WalkResult::interrupt();

    if (llvm::isa<TF::SendTPUEmbeddingGradientsOp>(op))
      if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map)))
        return WalkResult::interrupt();

    if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
                  TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op))
      if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map)))
        return WalkResult::interrupt();

    return WalkResult::advance();
  });
  return failure(walk_result.wasInterrupted());
}

// Updates the mode operand of TPU embedding enqueue ops depending on whether
// the graph is in training mode or in non-training mode. If a
// SendTPUEmbeddingGradients op with a matching `_tpu_embedding_layer`
// attribute is present, the graph is in training mode, so the enqueue op must
// receive the mode value `train`; otherwise it receives `inference`.
LogicalResult UpdateEmbeddingEnqueueOpInput(
    const llvm::StringMap<Operation*>& enqueue_op_map,
    const llvm::StringMap<Operation*>& recv_activation_op_map,
    const llvm::StringMap<Operation*>& send_gradient_op_map,
    OpBuilder* builder) {
  for (const auto& it : enqueue_op_map) {
    const auto& embedding_attr = it.getKey();
    Operation* embedding_op = it.second;
    if (!recv_activation_op_map.count(embedding_attr))
      return embedding_op->emitOpError()
             << "must have a corresponding '"
             << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";

    // TPU embedding enqueue ops take different inputs depending on whether the
    // graph is in training mode or in eval/prediction mode. During training,
    // the mode parameter of the enqueue op must be `train`; for evaluation or
    // prediction, the mode must be set to `inference`. If a
    // SendTPUEmbeddingGradients op exists in the graph, then the graph is in
    // training mode, so create a const op with value `train` and use the
    // output of the constant as the mode operand of the TPU embedding enqueue
    // op.
    bool is_training = send_gradient_op_map.count(embedding_attr);

    // The last operand of TPUEmbeddingEnqueue ops is the mode which
    // represents whether the graph is in training mode or in evaluation mode.
    auto& mode_enqueue_operand =
        embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);

    llvm::SmallVector<StringRef, 1> mode_string_value;
    mode_string_value.emplace_back(is_training ? "train" : "inference");
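    // Materialize the mode as a 0-d string constant placed immediately before
    // the enqueue op so it dominates its use.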
    builder->setInsertionPoint(embedding_op);
    auto enqueue_mode = builder->create<TF::ConstOp>(
        embedding_op->getLoc(),
        DenseStringElementsAttr::get(
            RankedTensorType::get({}, builder->getType<TF::StringType>()),
            mode_string_value));

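    // If the enqueue op is marked for outside compilation, propagate the
    // attribute to the new constant so both ops stay in the same
    // outside-compilation cluster.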
    auto outside_compilation_attr =
        embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
    if (outside_compilation_attr)
      enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
                            outside_compilation_attr);

    mode_enqueue_operand.set(enqueue_mode);
  }

  return success();
}

void TPUUpdateEmbeddingEnqueueOpInputsPass::runOnFunction() {
  OpBuilder builder(&getContext());
  auto func_op = getFunction();

  // All TPU-embedding-layer-related ops are annotated with a string
  // `_tpu_embedding_layer` attribute. Collect these ops into maps keyed by the
  // value of that attribute.
  llvm::StringMap<Operation*> enqueue_op_map;
  llvm::StringMap<Operation*> recv_activation_op_map;
  llvm::StringMap<Operation*> send_gradient_op_map;
  if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map,
                                 &recv_activation_op_map,
                                 &send_gradient_op_map)))
    return signalPassFailure();

  if (enqueue_op_map.size() != recv_activation_op_map.size()) {
    func_op.emitError() << "expects the number of embedding enqueue ops to "
                           "match the number of '"
                        << TF::RecvTPUEmbeddingActivationsOp::getOperationName()
                        << "' ops";
    return signalPassFailure();
  }

  if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
                                           recv_activation_op_map,
                                           send_gradient_op_map, &builder)))
    return signalPassFailure();
}

}  // anonymous namespace

std::unique_ptr<OperationPass<FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass() {
  return std::make_unique<TPUUpdateEmbeddingEnqueueOpInputsPass>();
}
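
// Illustrative usage (assuming a PassManager `pm` that runs on a module):
//   pm.addNestedPass<FuncOp>(CreateTPUUpdateEmbeddingEnqueueOpInputsPass());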

}  // namespace TFTPU
}  // namespace mlir