• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/Block.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/IR/Value.h"  // from @llvm-project
25 #include "mlir/IR/Visitors.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
28 #include "mlir/Support/LLVM.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 
34 namespace mlir {
35 namespace TFTPU {
36 namespace {
37 
// Attribute marking an op for extraction to the host via outside compilation;
// propagated to newly created ops so they stay on the same side as their user.
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
// Attribute whose string value identifies which TPUEmbedding layer an
// embedding-related op belongs to; used to pair up related ops.
constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";
40 
// A function pass that rewrites the trailing `mode` operand of TPU embedding
// enqueue ops to "train" or "inference", depending on whether a matching
// SendTPUEmbeddingGradients op (i.e. training mode) is present in the graph.
struct TPUUpdateEmbeddingEnqueueOpInputs
    : public PassWrapper<TPUUpdateEmbeddingEnqueueOpInputs, FunctionPass> {
  void runOnFunction() override;
};
45 
46 // Extracts `_tpu_embedding_layer` attribute from TPU embedding ops and
47 // clear the attribute from the operation. This ensures that future optimization
48 // passes does not trigger additional logic due to presence of this attribute.
ExtractEmbeddingAttribute(Operation * op,llvm::StringMap<Operation * > * embedding_op_map)49 LogicalResult ExtractEmbeddingAttribute(
50     Operation* op, llvm::StringMap<Operation*>* embedding_op_map) {
51   auto embedding_attr = op->getAttrOfType<StringAttr>(kTPUEmbeddingAttr);
52   if (!embedding_attr) return mlir::success();
53 
54   if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second)
55     return op->emitOpError(
56         "found duplicate TPU embedding ops potentially from multiple "
57         "TPUEmbedding layers");
58 
59   op->removeAttr(kTPUEmbeddingAttr);
60   return success();
61 }
62 
FindTPUEmbeddingOps(FuncOp func_op,llvm::StringMap<Operation * > * enqueue_op_map,llvm::StringMap<Operation * > * recv_activation_op_map,llvm::StringMap<Operation * > * send_gradient_op_map)63 LogicalResult FindTPUEmbeddingOps(
64     FuncOp func_op, llvm::StringMap<Operation*>* enqueue_op_map,
65     llvm::StringMap<Operation*>* recv_activation_op_map,
66     llvm::StringMap<Operation*>* send_gradient_op_map) {
67   auto walk_result = func_op.walk([&](Operation* op) {
68     if (llvm::isa<TF::RecvTPUEmbeddingActivationsOp>(op))
69       if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map)))
70         return WalkResult::interrupt();
71 
72     if (llvm::isa<TF::SendTPUEmbeddingGradientsOp>(op))
73       if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map)))
74         return WalkResult::interrupt();
75 
76     if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
77                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp>(op))
78       if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map)))
79         return WalkResult::interrupt();
80 
81     return WalkResult::advance();
82   });
83   return failure(walk_result.wasInterrupted());
84 }
85 
86 // Updates the operand of TPU embedding enqueue ops depending on whether
87 // the graph is in training mode or in non-training mode.
88 // If SendTPUEmbeddingGradients op is present, this means that graph is in
89 // training mode. As so, correctly feed in `then` branch value of SelectV2
90 // operand as inputs to the TPU embedding enqueue ops.
UpdateEmbeddingEnqueueOpInput(const llvm::StringMap<Operation * > & enqueue_op_map,const llvm::StringMap<Operation * > & recv_activation_op_map,const llvm::StringMap<Operation * > & send_gradient_op_map,OpBuilder * builder)91 LogicalResult UpdateEmbeddingEnqueueOpInput(
92     const llvm::StringMap<Operation*>& enqueue_op_map,
93     const llvm::StringMap<Operation*>& recv_activation_op_map,
94     const llvm::StringMap<Operation*>& send_gradient_op_map,
95     OpBuilder* builder) {
96   for (const auto& it : enqueue_op_map) {
97     const auto& embedding_attr = it.getKey();
98     Operation* embedding_op = it.second;
99     if (!recv_activation_op_map.count(embedding_attr))
100       return embedding_op->emitOpError()
101              << "must have a corresponding '"
102              << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
103 
104     // TPU Embedding enqueue ops take different inputs depending on whether
105     // graph is in training mode or in eval/prediction mode. During training,
106     // the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
107     // evaluation or prediction, mode must be set to `inference`.
108     // If SendTPUEmbeddingGradients op exists in the graph, then graph is
109     // in training mode, so create a const op with value `train` use the
110     // output value of the constant as an operand to the TPU embedding
111     // enqueue op.
112     bool is_training = send_gradient_op_map.count(embedding_attr);
113 
114     // The last operand of TPUEmbeddingEnqueue ops is the mode which
115     // represents whether graph is in training mode or in evaluation mode.
116     auto& mode_enqueue_operand =
117         embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
118 
119     llvm::SmallVector<StringRef, 1> mode_string_value;
120     mode_string_value.emplace_back(is_training ? "train" : "inference");
121     builder->setInsertionPoint(embedding_op);
122     auto enqueue_mode = builder->create<TF::ConstOp>(
123         embedding_op->getLoc(),
124         DenseStringElementsAttr::get(
125             RankedTensorType::get({}, builder->getType<TF::StringType>()),
126             mode_string_value));
127 
128     auto outside_compilation_attr =
129         embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
130     if (outside_compilation_attr)
131       enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
132                             outside_compilation_attr);
133 
134     mode_enqueue_operand.set(enqueue_mode);
135   }
136 
137   return success();
138 }
139 
runOnFunction()140 void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() {
141   OpBuilder builder(&getContext());
142   auto func_op = getFunction();
143 
144   // All TPU embedding layer related ops are annotated with
145   // `_tpu_embedding_layer` attribute along with corresponding string attribute.
146   // Store all tpu embedding layer related ops with value of
147   // `_tpu_embedding_layer` attribute as map key.
148   llvm::StringMap<Operation*> enqueue_op_map;
149   llvm::StringMap<Operation*> recv_activation_op_map;
150   llvm::StringMap<Operation*> send_gradient_op_map;
151   if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map,
152                                  &recv_activation_op_map,
153                                  &send_gradient_op_map)))
154     return signalPassFailure();
155 
156   if (enqueue_op_map.size() != recv_activation_op_map.size()) {
157     func_op.emitError() << "expects the number of embedding enqueue ops to "
158                            "match the number of '"
159                         << TF::RecvTPUEmbeddingActivationsOp::getOperationName()
160                         << "' ops";
161     return signalPassFailure();
162   }
163 
164   if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
165                                            recv_activation_op_map,
166                                            send_gradient_op_map, &builder)))
167     return signalPassFailure();
168 }
169 
170 }  // anonymous namespace
171 
// Creates a new instance of the pass that updates TPU embedding enqueue op
// inputs based on training vs. inference mode.
std::unique_ptr<OperationPass<FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass() {
  return std::make_unique<TPUUpdateEmbeddingEnqueueOpInputs>();
}
176 
// Registers the pass with the global pass registry under its command-line
// flag so it can be invoked from tools such as tf-opt.
static PassRegistration<TPUUpdateEmbeddingEnqueueOpInputs> pass(
    "tf-tpu-update-embedding-enqueue-op-inputs",
    "Updates inputs to TPU embedding enqueue ops depending on whether graph "
    "is in training mode or in evaluation mode.");
181 
182 }  // namespace TFTPU
183 }  // namespace mlir
184