1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/Attributes.h" // from @llvm-project
20 #include "mlir/IR/Block.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/IR/Value.h" // from @llvm-project
25 #include "mlir/IR/Visitors.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
28 #include "mlir/Support/LLVM.h" // from @llvm-project
29 #include "mlir/Support/LogicalResult.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33
34 namespace mlir {
35 namespace TFTPU {
36 namespace {
37
// Attribute carrying the outside-compilation cluster assignment; when present
// on an enqueue op it is copied onto the newly created mode constant.
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
// Attribute annotating TPU embedding ops with an identifier of the
// TPUEmbedding layer they belong to; used as the key of the op maps below.
constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";
40
// Function pass that rewrites the mode operand of TPU embedding enqueue ops:
// for each `_tpu_embedding_layer` id, the mode becomes a `train` constant if a
// SendTPUEmbeddingGradients op with the same id exists, else `inference`.
struct TPUUpdateEmbeddingEnqueueOpInputs
    : public PassWrapper<TPUUpdateEmbeddingEnqueueOpInputs, FunctionPass> {
  void runOnFunction() override;
};
45
46 // Extracts `_tpu_embedding_layer` attribute from TPU embedding ops and
47 // clear the attribute from the operation. This ensures that future optimization
48 // passes does not trigger additional logic due to presence of this attribute.
ExtractEmbeddingAttribute(Operation * op,llvm::StringMap<Operation * > * embedding_op_map)49 LogicalResult ExtractEmbeddingAttribute(
50 Operation* op, llvm::StringMap<Operation*>* embedding_op_map) {
51 auto embedding_attr = op->getAttrOfType<StringAttr>(kTPUEmbeddingAttr);
52 if (!embedding_attr) return mlir::success();
53
54 if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second)
55 return op->emitOpError(
56 "found duplicate TPU embedding ops potentially from multiple "
57 "TPUEmbedding layers");
58
59 op->removeAttr(kTPUEmbeddingAttr);
60 return success();
61 }
62
FindTPUEmbeddingOps(FuncOp func_op,llvm::StringMap<Operation * > * enqueue_op_map,llvm::StringMap<Operation * > * recv_activation_op_map,llvm::StringMap<Operation * > * send_gradient_op_map)63 LogicalResult FindTPUEmbeddingOps(
64 FuncOp func_op, llvm::StringMap<Operation*>* enqueue_op_map,
65 llvm::StringMap<Operation*>* recv_activation_op_map,
66 llvm::StringMap<Operation*>* send_gradient_op_map) {
67 auto walk_result = func_op.walk([&](Operation* op) {
68 if (llvm::isa<TF::RecvTPUEmbeddingActivationsOp>(op))
69 if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map)))
70 return WalkResult::interrupt();
71
72 if (llvm::isa<TF::SendTPUEmbeddingGradientsOp>(op))
73 if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map)))
74 return WalkResult::interrupt();
75
76 if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
77 TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op))
78 if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map)))
79 return WalkResult::interrupt();
80
81 return WalkResult::advance();
82 });
83 return failure(walk_result.wasInterrupted());
84 }
85
86 // Updates the operand of TPU embedding enqueue ops depending on whether
87 // the graph is in training mode or in non-training mode.
88 // If SendTPUEmbeddingGradients op is present, this means that graph is in
89 // training mode. As so, correctly feed in `then` branch value of SelectV2
90 // operand as inputs to the TPU embedding enqueue ops.
UpdateEmbeddingEnqueueOpInput(const llvm::StringMap<Operation * > & enqueue_op_map,const llvm::StringMap<Operation * > & recv_activation_op_map,const llvm::StringMap<Operation * > & send_gradient_op_map,OpBuilder * builder)91 LogicalResult UpdateEmbeddingEnqueueOpInput(
92 const llvm::StringMap<Operation*>& enqueue_op_map,
93 const llvm::StringMap<Operation*>& recv_activation_op_map,
94 const llvm::StringMap<Operation*>& send_gradient_op_map,
95 OpBuilder* builder) {
96 for (const auto& it : enqueue_op_map) {
97 const auto& embedding_attr = it.getKey();
98 Operation* embedding_op = it.second;
99 if (!recv_activation_op_map.count(embedding_attr))
100 return embedding_op->emitOpError()
101 << "must have a corresponding '"
102 << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
103
104 // TPU Embedding enqueue ops take different inputs depending on whether
105 // graph is in training mode or in eval/prediction mode. During training,
106 // the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
107 // evaluation or prediction, mode must be set to `inference`.
108 // If SendTPUEmbeddingGradients op exists in the graph, then graph is
109 // in training mode, so create a const op with value `train` use the
110 // output value of the constant as an operand to the TPU embedding
111 // enqueue op.
112 bool is_training = send_gradient_op_map.count(embedding_attr);
113
114 // The last operand of TPUEmbeddingEnqueue ops is the mode which
115 // represents whether graph is in training mode or in evaluation mode.
116 auto& mode_enqueue_operand =
117 embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
118
119 llvm::SmallVector<StringRef, 1> mode_string_value;
120 mode_string_value.emplace_back(is_training ? "train" : "inference");
121 builder->setInsertionPoint(embedding_op);
122 auto enqueue_mode = builder->create<TF::ConstOp>(
123 embedding_op->getLoc(),
124 DenseStringElementsAttr::get(
125 RankedTensorType::get({}, builder->getType<TF::StringType>()),
126 mode_string_value));
127
128 auto outside_compilation_attr =
129 embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
130 if (outside_compilation_attr)
131 enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
132 outside_compilation_attr);
133
134 mode_enqueue_operand.set(enqueue_mode);
135 }
136
137 return success();
138 }
139
runOnFunction()140 void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
141 OpBuilder builder(&getContext());
142 auto func_op = getFunction();
143
144 // All TPU embedding layer related ops are annotated with
145 // `_tpu_embedding_layer` attribute along with corresponding string attribute.
146 // Store all tpu embedding layer related ops with value of
147 // `_tpu_embedding_layer` attribute as map key.
148 llvm::StringMap<Operation*> enqueue_op_map;
149 llvm::StringMap<Operation*> recv_activation_op_map;
150 llvm::StringMap<Operation*> send_gradient_op_map;
151 if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map,
152 &recv_activation_op_map,
153 &send_gradient_op_map)))
154 return signalPassFailure();
155
156 if (enqueue_op_map.size() != recv_activation_op_map.size()) {
157 func_op.emitError() << "expects the number of embedding enqueue ops to "
158 "match the number of '"
159 << TF::RecvTPUEmbeddingActivationsOp::getOperationName()
160 << "' ops";
161 return signalPassFailure();
162 }
163
164 if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
165 recv_activation_op_map,
166 send_gradient_op_map, &builder)))
167 return signalPassFailure();
168 }
169
170 } // anonymous namespace
171
// Factory used by pass pipelines to instantiate this pass.
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass() {
  return std::make_unique<TPUUpdateEmbeddingEnqueueOpInputs>();
}

// Registers the pass under the `tf-tpu-update-embedding-enqueue-op-inputs`
// command-line flag so it can be invoked from pass-pipeline text.
static PassRegistration<TPUUpdateEmbeddingEnqueueOpInputs> pass(
    "tf-tpu-update-embedding-enqueue-op-inputs",
    "Updates inputs to TPU embedding enqueue ops depending on whether graph "
    "is in training mode or in evaluation mode.");
181
182 } // namespace TFTPU
183 } // namespace mlir
184