1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/PointerUnion.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
26 #include "mlir/IR/Attributes.h" // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/Types.h" // from @llvm-project
31 #include "mlir/IR/Value.h" // from @llvm-project
32 #include "mlir/Pass/Pass.h" // from @llvm-project
33 #include "mlir/Support/LogicalResult.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
38
39 namespace mlir {
40 namespace TF {
41 namespace {
42
// Name of the argument/result attribute that records the shared name of the
// variable a promoted resource came from.
constexpr char kResourceNameArgAttr[] = "tf.resource_name";

// Diagnostic for a resource read/write whose resource operand is a block
// argument that does not belong to the function's entry block.
constexpr char kResourceFunctionMsg[] =
    "expects function level resource argument";

// Diagnostic for a resource read/write whose resource operand is not a block
// argument at all.
constexpr char kInvalidResourceMsg[] =
    "expects resource to be a VarHandleOp or function argument";
48
49 // Checks if a function has only one block.
CheckSingleBlockFunction(FuncOp function)50 mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) {
51 if (!llvm::hasSingleElement(function)) {
52 return function.emitError()
53 << "expects function '" << function.getName()
54 << "' to have 1 block, got " << function.getBlocks().size();
55 }
56 return success();
57 }
58
59 // Collects names of users of a resource that are not `tf.ReadVariableOp` and
60 // not `tf.AssignVariableOp`.
GetCompositeResourceUserNames(Value resource)61 llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
62 Value resource) {
63 // SmallSet will use a vector when there is only one element and use std::set
64 // when there are more than one elements. This ensures that the operations in
65 // the error message are ordered.
66 llvm::SmallSet<llvm::StringRef, 1> composite_users;
67 for (Operation* user : resource.getUsers())
68 if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
69 composite_users.insert(user->getName().getStringRef());
70
71 return composite_users;
72 }
73
74 // Checks that the only users of `tf.VarHandleOp` are
75 // `tf.ReadVariableOp` and `tf.AssignVariableOp`.
ValidateVarHandle(TF::VarHandleOp var_handle_op)76 mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
77 auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
78 if (!composite_ops.empty())
79 return var_handle_op.emitOpError()
80 << "expects users to be 'tf.ReadVariableOp' or "
81 "'tf.AssignVariableOp', got ["
82 << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
83 << "]";
84
85 return success();
86 }
87
88 // Checks if resource argument has a valid resource subtype and its users are of
89 // `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
ValidateResourceArgument(FuncOp function,BlockArgument resource_arg,TF::ResourceType resource_type)90 mlir::LogicalResult ValidateResourceArgument(FuncOp function,
91 BlockArgument resource_arg,
92 TF::ResourceType resource_type) {
93 if (resource_type.getSubtypes().size() != 1)
94 return function.emitError()
95 << "expects resource type of argument "
96 << resource_arg.getArgNumber() << " to have one subtype, got "
97 << resource_type;
98
99 auto composite_ops = GetCompositeResourceUserNames(resource_arg);
100 if (!composite_ops.empty())
101 return function.emitError()
102 << "expects users of resource argument "
103 << resource_arg.getArgNumber()
104 << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got ["
105 << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
106 << "]";
107
108 return success();
109 }
110
VariableIsInitialized(TF::VarHandleOp var_handle_op)111 bool VariableIsInitialized(TF::VarHandleOp var_handle_op) {
112 auto is_variable_initialized =
113 var_handle_op->getAttrOfType<BoolAttr>("_is_initialized");
114 // Assume variable is initialized if attribute is not set.
115 // There are paths that doesn't mark the variables. All variables
116 // that doesn't have the attribute will be promoted.
117 return !is_variable_initialized || is_variable_initialized.getValue();
118 }
119
120 // Adds resource arguments for every unique (name) variable handle. Associated
121 // `tf.VarHandleOp` are removed from the function. Variable shared names are
122 // returned in `var_handle_shared_names` based on the ordering of added resource
123 // arguments.
PromoteVarHandlesToArguments(FuncOp function,bool add_validation,llvm::SmallVectorImpl<std::string> * var_handle_shared_names)124 mlir::LogicalResult PromoteVarHandlesToArguments(
125 FuncOp function, bool add_validation,
126 llvm::SmallVectorImpl<std::string>* var_handle_shared_names) {
127 Block& block = function.front();
128 auto func_type = function.getType();
129
130 auto func_arg_types = llvm::to_vector<4>(func_type.getInputs());
131 llvm::SmallDenseMap<llvm::StringRef, int> var_arg_index_by_name;
132 for (auto var_handle_op :
133 llvm::make_early_inc_range(block.getOps<TF::VarHandleOp>())) {
134 if (add_validation && failed(ValidateVarHandle(var_handle_op)))
135 return failure();
136 // In the case of variables that are not initialized at graph creation
137 // then we keep them as VarHandleOps.
138 if (!VariableIsInitialized(var_handle_op)) continue;
139
140 llvm::StringRef name = var_handle_op.shared_nameAttr().getValue();
141 auto it = var_arg_index_by_name.insert({name, func_arg_types.size()});
142 if (it.second) {
143 var_handle_shared_names->emplace_back(name);
144 auto resource_type = var_handle_op.resource().getType();
145 func_arg_types.push_back(resource_type);
146 var_handle_op.resource().replaceAllUsesWith(
147 block.addArgument(resource_type));
148 } else {
149 var_handle_op.resource().replaceAllUsesWith(
150 block.getArgument(it.first->getSecond()));
151 }
152 var_handle_op.erase();
153 }
154
155 if (!var_handle_shared_names->empty())
156 function.setType(FunctionType::get(function.getContext(), func_arg_types,
157 func_type.getResults()));
158
159 return success();
160 }
161
162 // Records the current live value for a resource variable and whether a read or
163 // write on the variable occurred.
164 struct ResourceInfo {
165 Value live_value = nullptr;
166 bool read = false;
167 bool write = false;
168 };
169
PromoteResourcesToArguments(FuncOp function,llvm::ArrayRef<std::string> var_handle_shared_names)170 LogicalResult PromoteResourcesToArguments(
171 FuncOp function, llvm::ArrayRef<std::string> var_handle_shared_names) {
172 Block& block = function.front();
173
174 auto return_op = llvm::dyn_cast_or_null<ReturnOp>(block.getTerminator());
175 if (!return_op)
176 return function.emitError() << "expects function '" << function.getName()
177 << "' to have a MLIR ReturnOp";
178
179 llvm::SmallVector<ResourceInfo, 4> resources(function.getNumArguments());
180 auto argument_types = llvm::to_vector<4>(function.getType().getInputs());
181 bool has_resources = false;
182 auto add_resource_argument = [&](BlockArgument arg,
183 TF::ResourceType resource_type) {
184 Type arg_type = resource_type.getSubtypes().front();
185 arg.setType(arg_type);
186 resources[arg.getArgNumber()].live_value = arg;
187 argument_types[arg.getArgNumber()] = arg_type;
188 has_resources = true;
189 };
190
191 // Loop through the non `tf.VarHandleOp` resource arguments in the function,
192 // validate its uses and subtype, and store a mapping from that argument to
193 // itself as the current live value.
194 auto func_args = function.getArguments().take_front(
195 function.getNumArguments() - var_handle_shared_names.size());
196 for (BlockArgument& func_arg : func_args) {
197 auto resource_type =
198 getElementTypeOrSelf(func_arg.getType()).dyn_cast<TF::ResourceType>();
199 if (!resource_type) continue;
200 if (failed(ValidateResourceArgument(function, func_arg, resource_type)))
201 return failure();
202
203 add_resource_argument(func_arg, resource_type);
204 }
205
206 // Loop through `tf.VarHandleOp` resource arguments in the function and store
207 // a mapping from that argument to itself as the current live value. No
208 // validations are necessary here as these arguments were validated prior to
209 // being added.
210 auto var_handle_args =
211 function.getArguments().take_back(var_handle_shared_names.size());
212 for (BlockArgument& var_handle_arg : var_handle_args) {
213 auto resource_type =
214 getElementTypeOrSelf(var_handle_arg.getType()).cast<TF::ResourceType>();
215 add_resource_argument(var_handle_arg, resource_type);
216 }
217
218 if (!has_resources) return success();
219
220 // We initially assign the argument for a resource as the live value for the
221 // resource. We then walk through the operations in the function in their
222 // lexical order, to update the live value for the resource when we see a
223 // store to the resource and replace reads of the resource with uses of its
224 // live value.
225 for (Operation& op : llvm::make_early_inc_range(block)) {
226 if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
227 if (auto func_arg = read_op.resource().dyn_cast<BlockArgument>()) {
228 if (func_arg.getOwner() != &block)
229 return read_op.emitOpError(kResourceFunctionMsg);
230
231 ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
232 resource_info.read = true;
233 read_op.value().replaceAllUsesWith(resource_info.live_value);
234 } else {
235 return read_op.emitOpError(kInvalidResourceMsg);
236 }
237
238 read_op.erase();
239 } else if (auto write_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
240 if (auto func_arg = write_op.resource().dyn_cast<BlockArgument>()) {
241 if (func_arg.getOwner() != &block)
242 return write_op.emitOpError(kResourceFunctionMsg);
243
244 ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
245 resource_info.write = true;
246 resource_info.live_value = write_op.value();
247 } else {
248 return read_op.emitOpError(kInvalidResourceMsg);
249 }
250
251 write_op.erase();
252 }
253 }
254
255 const int64_t num_results_before = function.getNumResults();
256 auto return_operands = llvm::to_vector<4>(return_op.getOperands());
257 auto result_types = llvm::to_vector<4>(return_op.getOperandTypes());
258 llvm::SmallVector<std::pair<int64_t, llvm::StringRef>, 4>
259 output_only_resources;
260 llvm::SmallVector<std::pair<int64_t, int64_t>, 4> input_output_alias;
261
262 // Collect new return values for variable writes and either (a) output-only
263 // resource attributes (if the resource is not promoted to an argument) or (b)
264 // mapping from resource input index to output alias (if the resource has been
265 // promoted to an argument). Resource arguments that were originally
266 // `tf.VarHandleOp` but not read are collected and then removed.
267 OpBuilder builder(return_op);
268 const int var_handles_start_idx =
269 function.getNumArguments() - var_handle_shared_names.size();
270 int new_argument_index = 0;
271 llvm::SmallVector<int, 4> argument_indices_to_remove;
272 for (auto resource_and_index : llvm::enumerate(resources)) {
273 const auto& resource = resource_and_index.value();
274 if (!resource.live_value) {
275 // Ignore non resource arguments.
276 ++new_argument_index;
277 continue;
278 }
279
280 const int64_t index = resource_and_index.index();
281 const bool is_var_handle = index >= var_handles_start_idx;
282 if (resource.write) {
283 if (!is_var_handle || resource.read) {
284 input_output_alias.push_back(
285 {new_argument_index, return_operands.size()});
286 } else if (is_var_handle) {
287 output_only_resources.push_back(
288 {return_operands.size(),
289 var_handle_shared_names[index - var_handles_start_idx]});
290 }
291 return_operands.push_back(resource.live_value);
292 result_types.push_back(resource.live_value.getType());
293 }
294
295 if (is_var_handle && !resource.read) {
296 assert(block.getArgument(index).getUses().empty());
297 argument_indices_to_remove.push_back(index);
298 } else {
299 if (is_var_handle) {
300 // Add resource_name attribute to VarHandleOp read.
301 function.setArgAttr(
302 new_argument_index, kResourceNameArgAttr,
303 builder.getStringAttr(
304 var_handle_shared_names[index - var_handles_start_idx]));
305 }
306 ++new_argument_index;
307 }
308 }
309
310 // Remove unread var handle arguments.
311 for (int argument_index_to_remove :
312 llvm::reverse(argument_indices_to_remove)) {
313 block.eraseArgument(argument_index_to_remove);
314 argument_types.erase(argument_types.begin() + argument_index_to_remove);
315 }
316
317 // Rewrite return if there are variable writes.
318 const int return_operands_size = return_operands.size();
319 if (return_operands_size > num_results_before) {
320 builder.create<ReturnOp>(return_op.getLoc(), return_operands);
321 return_op.erase();
322 }
323
324 // Update function argument and result types with new resource subtypes.
325 function.setType(builder.getFunctionType(argument_types, result_types));
326
327 // Add resource_name attribute to the output for the resources.
328 for (auto& resource : output_only_resources)
329 function.setResultAttr(resource.first, kResourceNameArgAttr,
330 builder.getStringAttr(resource.second));
331
332 // Add aliasing_output attribute to the input argument for the resources that
333 // are updated by the function.
334 for (auto& input_output : input_output_alias)
335 function.setArgAttr(input_output.first, "tf.aliasing_output",
336 builder.getI64IntegerAttr(input_output.second));
337
338 return success();
339 }
340
341 class PromoteResourcesToArgsPass
342 : public PromoteResourcesToArgsPassBase<PromoteResourcesToArgsPass> {
343 public:
344 void runOnOperation() override;
345 };
346
runOnOperation()347 void PromoteResourcesToArgsPass::runOnOperation() {
348 ModuleOp module = getOperation();
349 FuncOp main_func = module.lookupSymbol<FuncOp>("main");
350 if (!main_func) return;
351
352 // This routine should only be called when control flow operations are still
353 // represented with TF IfOp and WhileOp operations. In this case, there should
354 // be only one basic blocks in the MLIR representation.
355 if (failed(CheckSingleBlockFunction(main_func))) return signalPassFailure();
356
357 llvm::SmallVector<std::string, 4> var_handle_shared_names;
358 if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) ||
359 failed(PromoteVarHandlesToArguments(main_func, /*add_validation=*/true,
360 &var_handle_shared_names)) ||
361 failed(PromoteResourcesToArguments(main_func, var_handle_shared_names)))
362 return signalPassFailure();
363 }
364
365 class PromoteVarHandlesToArgsPass
366 : public PromoteVarHandlesToArgsPassBase<PromoteVarHandlesToArgsPass> {
367 public:
368 void runOnOperation() override;
369 };
370
runOnOperation()371 void PromoteVarHandlesToArgsPass::runOnOperation() {
372 ModuleOp module = getOperation();
373 MLIRContext* context = module.getContext();
374 for (auto function : module.getOps<FuncOp>()) {
375 if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
376
377 llvm::SmallVector<std::string, 4> var_handle_shared_names;
378 (void)PromoteVarHandlesToArguments(function, /*add_validation=*/false,
379 &var_handle_shared_names);
380
381 // Add resource names for each `tf.VarHandleOp` that were promoted to
382 // resource arguments.
383 const int var_handle_args_offset =
384 function.getNumArguments() - var_handle_shared_names.size();
385 for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
386 function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
387 kResourceNameArgAttr,
388 StringAttr::get(context, var_name_and_index.value()));
389 }
390 }
391
392 } // namespace
393
CreatePromoteResourcesToArgsPass()394 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
395 return std::make_unique<PromoteResourcesToArgsPass>();
396 }
397
CreatePromoteVarHandlesToArgsPass()398 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
399 return std::make_unique<PromoteVarHandlesToArgsPass>();
400 }
401
402 } // namespace TF
403 } // namespace mlir
404