1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/PointerUnion.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
26 #include "mlir/IR/Attributes.h" // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
30 #include "mlir/IR/Types.h" // from @llvm-project
31 #include "mlir/IR/Value.h" // from @llvm-project
32 #include "mlir/Pass/Pass.h" // from @llvm-project
33 #include "mlir/Support/LogicalResult.h" // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
38
39 namespace mlir {
40 namespace TF {
41 namespace {
42
// Argument/result attribute name under which the shared name of a promoted
// variable is recorded.
constexpr char kResourceNameArgAttr[] = "tf.resource_name";

// Diagnostic used when a resource read/write operates on a block argument that
// does not belong to the function's entry block.
constexpr char kResourceFunctionMsg[] =
    "expects function level resource argument";

// Diagnostic used when a resource read/write operates on a value that is not a
// block argument.
constexpr char kInvalidResourceMsg[] =
    "expects resource to be a VarHandleOp or function argument";
48
49 // Checks if a function has only one block.
CheckSingleBlockFunction(func::FuncOp function)50 mlir::LogicalResult CheckSingleBlockFunction(func::FuncOp function) {
51 if (!llvm::hasSingleElement(function)) {
52 return function.emitError()
53 << "expects function '" << function.getName()
54 << "' to have 1 block, got " << function.getBlocks().size();
55 }
56 return success();
57 }
58
59 // Collects names of users of a resource that are not `tf.ReadVariableOp` and
60 // not `tf.AssignVariableOp`.
GetCompositeResourceUserNames(Value resource)61 llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
62 Value resource) {
63 // SmallSet will use a vector when there is only one element and use std::set
64 // when there are more than one elements. This ensures that the operations in
65 // the error message are ordered.
66 llvm::SmallSet<llvm::StringRef, 1> composite_users;
67 for (Operation* user : resource.getUsers())
68 if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
69 composite_users.insert(user->getName().getStringRef());
70
71 return composite_users;
72 }
73
74 // Checks that the only users of `tf.VarHandleOp` are
75 // `tf.ReadVariableOp` and `tf.AssignVariableOp`.
ValidateVarHandle(TF::VarHandleOp var_handle_op)76 mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
77 auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
78 if (!composite_ops.empty())
79 return var_handle_op.emitOpError()
80 << "expects users to be 'tf.ReadVariableOp' or "
81 "'tf.AssignVariableOp', got ["
82 << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
83 << "]";
84
85 return success();
86 }
87
88 // Checks if resource argument has a valid resource subtype and its users are of
89 // `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
ValidateResourceArgument(func::FuncOp function,BlockArgument resource_arg,TF::ResourceType resource_type)90 mlir::LogicalResult ValidateResourceArgument(func::FuncOp function,
91 BlockArgument resource_arg,
92 TF::ResourceType resource_type) {
93 if (resource_type.getSubtypes().size() != 1)
94 return function.emitError()
95 << "expects resource type of argument "
96 << resource_arg.getArgNumber() << " to have one subtype, got "
97 << resource_type;
98
99 auto composite_ops = GetCompositeResourceUserNames(resource_arg);
100 if (!composite_ops.empty())
101 return function.emitError()
102 << "expects users of resource argument "
103 << resource_arg.getArgNumber()
104 << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got ["
105 << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
106 << "]";
107
108 return success();
109 }
110
VariableIsInitialized(TF::VarHandleOp var_handle_op)111 bool VariableIsInitialized(TF::VarHandleOp var_handle_op) {
112 auto is_variable_initialized =
113 var_handle_op->getAttrOfType<BoolAttr>("_is_initialized");
114 // Assume variable is initialized if attribute is not set.
115 // There are paths that doesn't mark the variables. All variables
116 // that doesn't have the attribute will be promoted.
117 return !is_variable_initialized || is_variable_initialized.getValue();
118 }
119
120 // Adds resource arguments for every unique (name) variable handle. Associated
121 // `tf.VarHandleOp` are removed from the function. Variable shared names are
122 // returned in `var_handle_shared_names` based on the ordering of added resource
123 // arguments.
PromoteVarHandlesToArguments(func::FuncOp function,bool add_validation,llvm::SmallVectorImpl<std::string> * var_handle_shared_names)124 mlir::LogicalResult PromoteVarHandlesToArguments(
125 func::FuncOp function, bool add_validation,
126 llvm::SmallVectorImpl<std::string>* var_handle_shared_names) {
127 Block& block = function.front();
128 auto func_type = function.getFunctionType();
129
130 auto func_arg_types = llvm::to_vector<4>(func_type.getInputs());
131 llvm::SmallDenseMap<llvm::StringRef, int> var_arg_index_by_name;
132 for (auto var_handle_op :
133 llvm::make_early_inc_range(block.getOps<TF::VarHandleOp>())) {
134 if (add_validation && failed(ValidateVarHandle(var_handle_op)))
135 return failure();
136 // In the case of variables that are not initialized at graph creation
137 // then we keep them as VarHandleOps.
138 if (!VariableIsInitialized(var_handle_op)) continue;
139
140 llvm::StringRef name = var_handle_op.shared_nameAttr().getValue();
141 auto it = var_arg_index_by_name.insert({name, func_arg_types.size()});
142 if (it.second) {
143 var_handle_shared_names->emplace_back(name);
144 auto resource_type = var_handle_op.resource().getType();
145 func_arg_types.push_back(resource_type);
146 var_handle_op.resource().replaceAllUsesWith(
147 block.addArgument(resource_type, var_handle_op.getLoc()));
148 } else {
149 var_handle_op.resource().replaceAllUsesWith(
150 block.getArgument(it.first->getSecond()));
151 }
152 var_handle_op.erase();
153 }
154
155 if (!var_handle_shared_names->empty())
156 function.setType(FunctionType::get(function.getContext(), func_arg_types,
157 func_type.getResults()));
158
159 return success();
160 }
161
162 // Records the current live value for a resource variable and whether a read or
163 // write on the variable occurred.
164 struct ResourceInfo {
165 Value live_value = nullptr;
166 bool read = false;
167 bool write = false;
168 };
169
PromoteResourcesToArguments(func::FuncOp function,llvm::ArrayRef<std::string> var_handle_shared_names)170 LogicalResult PromoteResourcesToArguments(
171 func::FuncOp function,
172 llvm::ArrayRef<std::string> var_handle_shared_names) {
173 Block& block = function.front();
174
175 auto return_op =
176 llvm::dyn_cast_or_null<func::ReturnOp>(block.getTerminator());
177 if (!return_op)
178 return function.emitError() << "expects function '" << function.getName()
179 << "' to have a MLIR ReturnOp";
180
181 llvm::SmallVector<ResourceInfo, 4> resources(function.getNumArguments());
182 auto argument_types =
183 llvm::to_vector<4>(function.getFunctionType().getInputs());
184 bool has_resources = false;
185 auto add_resource_argument = [&](BlockArgument arg,
186 TF::ResourceType resource_type) {
187 Type arg_type = resource_type.getSubtypes().front();
188 arg.setType(arg_type);
189 resources[arg.getArgNumber()].live_value = arg;
190 argument_types[arg.getArgNumber()] = arg_type;
191 has_resources = true;
192 };
193
194 // Loop through the non `tf.VarHandleOp` resource arguments in the function,
195 // validate its uses and subtype, and store a mapping from that argument to
196 // itself as the current live value.
197 auto func_args = function.getArguments().take_front(
198 function.getNumArguments() - var_handle_shared_names.size());
199 for (BlockArgument& func_arg : func_args) {
200 auto resource_type =
201 getElementTypeOrSelf(func_arg.getType()).dyn_cast<TF::ResourceType>();
202 if (!resource_type) continue;
203 if (failed(ValidateResourceArgument(function, func_arg, resource_type)))
204 return failure();
205
206 add_resource_argument(func_arg, resource_type);
207 }
208
209 // Loop through `tf.VarHandleOp` resource arguments in the function and store
210 // a mapping from that argument to itself as the current live value. No
211 // validations are necessary here as these arguments were validated prior to
212 // being added.
213 auto var_handle_args =
214 function.getArguments().take_back(var_handle_shared_names.size());
215 for (BlockArgument& var_handle_arg : var_handle_args) {
216 auto resource_type =
217 getElementTypeOrSelf(var_handle_arg.getType()).cast<TF::ResourceType>();
218 add_resource_argument(var_handle_arg, resource_type);
219 }
220
221 if (!has_resources) return success();
222
223 // We initially assign the argument for a resource as the live value for the
224 // resource. We then walk through the operations in the function in their
225 // lexical order, to update the live value for the resource when we see a
226 // store to the resource and replace reads of the resource with uses of its
227 // live value.
228 for (Operation& op : llvm::make_early_inc_range(block)) {
229 if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
230 if (auto func_arg = read_op.resource().dyn_cast<BlockArgument>()) {
231 if (func_arg.getOwner() != &block)
232 return read_op.emitOpError(kResourceFunctionMsg);
233
234 ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
235 resource_info.read = true;
236 read_op.value().replaceAllUsesWith(resource_info.live_value);
237 } else {
238 return read_op.emitOpError(kInvalidResourceMsg);
239 }
240
241 read_op.erase();
242 } else if (auto write_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
243 if (auto func_arg = write_op.resource().dyn_cast<BlockArgument>()) {
244 if (func_arg.getOwner() != &block)
245 return write_op.emitOpError(kResourceFunctionMsg);
246
247 ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
248 resource_info.write = true;
249 resource_info.live_value = write_op.value();
250 } else {
251 return read_op.emitOpError(kInvalidResourceMsg);
252 }
253
254 write_op.erase();
255 }
256 }
257
258 const int64_t num_results_before = function.getNumResults();
259 auto return_operands = llvm::to_vector<4>(return_op.getOperands());
260 auto result_types = llvm::to_vector<4>(return_op.getOperandTypes());
261 llvm::SmallVector<std::pair<int64_t, llvm::StringRef>, 4>
262 output_only_resources;
263 llvm::SmallVector<std::pair<int64_t, int64_t>, 4> input_output_alias;
264
265 // Collect new return values for variable writes and either (a) output-only
266 // resource attributes (if the resource is not promoted to an argument) or (b)
267 // mapping from resource input index to output alias (if the resource has been
268 // promoted to an argument). Resource arguments that were originally
269 // `tf.VarHandleOp` but not read are collected and then removed.
270 OpBuilder builder(return_op);
271 const int var_handles_start_idx =
272 function.getNumArguments() - var_handle_shared_names.size();
273 int new_argument_index = 0;
274 llvm::SmallVector<int, 4> argument_indices_to_remove;
275 for (auto resource_and_index : llvm::enumerate(resources)) {
276 const auto& resource = resource_and_index.value();
277 if (!resource.live_value) {
278 // Ignore non resource arguments.
279 ++new_argument_index;
280 continue;
281 }
282
283 const int64_t index = resource_and_index.index();
284 const bool is_var_handle = index >= var_handles_start_idx;
285 if (resource.write) {
286 if (!is_var_handle || resource.read) {
287 input_output_alias.push_back(
288 {new_argument_index, return_operands.size()});
289 } else if (is_var_handle) {
290 output_only_resources.push_back(
291 {return_operands.size(),
292 var_handle_shared_names[index - var_handles_start_idx]});
293 }
294 return_operands.push_back(resource.live_value);
295 result_types.push_back(resource.live_value.getType());
296 }
297
298 if (is_var_handle && !resource.read) {
299 assert(block.getArgument(index).getUses().empty());
300 argument_indices_to_remove.push_back(index);
301 } else {
302 if (is_var_handle) {
303 // Add resource_name attribute to VarHandleOp read.
304 function.setArgAttr(
305 new_argument_index, kResourceNameArgAttr,
306 builder.getStringAttr(
307 var_handle_shared_names[index - var_handles_start_idx]));
308 }
309 ++new_argument_index;
310 }
311 }
312
313 // Remove unread var handle arguments.
314 for (int argument_index_to_remove :
315 llvm::reverse(argument_indices_to_remove)) {
316 block.eraseArgument(argument_index_to_remove);
317 argument_types.erase(argument_types.begin() + argument_index_to_remove);
318 }
319
320 // Rewrite return if there are variable writes.
321 const int return_operands_size = return_operands.size();
322 if (return_operands_size > num_results_before) {
323 builder.create<func::ReturnOp>(return_op.getLoc(), return_operands);
324 return_op.erase();
325 }
326
327 // Update function argument and result types with new resource subtypes.
328 function.setType(builder.getFunctionType(argument_types, result_types));
329
330 // Add resource_name attribute to the output for the resources.
331 for (auto& resource : output_only_resources)
332 function.setResultAttr(resource.first, kResourceNameArgAttr,
333 builder.getStringAttr(resource.second));
334
335 // Add aliasing_output attribute to the input argument for the resources that
336 // are updated by the function.
337 for (auto& input_output : input_output_alias)
338 function.setArgAttr(input_output.first, "tf.aliasing_output",
339 builder.getI64IntegerAttr(input_output.second));
340
341 return success();
342 }
343
344 class PromoteResourcesToArgsPass
345 : public PromoteResourcesToArgsPassBase<PromoteResourcesToArgsPass> {
346 public:
347 PromoteResourcesToArgsPass() = default;
348 explicit PromoteResourcesToArgsPass(llvm::ArrayRef<std::string> functions);
349 void runOnOperation() override;
350 };
351
PromoteResourcesToArgsPass(llvm::ArrayRef<std::string> functions)352 PromoteResourcesToArgsPass::PromoteResourcesToArgsPass(
353 llvm::ArrayRef<std::string> functions) {
354 functions_ = functions;
355 }
356
runOnOperation()357 void PromoteResourcesToArgsPass::runOnOperation() {
358 ModuleOp module = getOperation();
359 if (llvm::size(functions_) == 0) {
360 functions_ = {"main"};
361 }
362 SymbolTable symbolTable(module);
363 for (const std::string& f : functions_) {
364 func::FuncOp func = symbolTable.lookup<func::FuncOp>(f);
365 if (!func) continue;
366
367 // This routine should only be called when control flow operations are still
368 // represented with TF IfOp and WhileOp operations. In this case, there
369 // should be only one basic blocks in the MLIR representation.
370 if (failed(CheckSingleBlockFunction(func))) return signalPassFailure();
371
372 llvm::SmallVector<std::string, 4> var_handle_shared_names;
373 if (failed(ResourceLiftingForFunctionalControlFlow(func)) ||
374 failed(PromoteVarHandlesToArguments(func, /*add_validation=*/true,
375 &var_handle_shared_names)) ||
376 failed(PromoteResourcesToArguments(func, var_handle_shared_names)))
377 return signalPassFailure();
378 }
379 }
380
381 class PromoteVarHandlesToArgsPass
382 : public PromoteVarHandlesToArgsPassBase<PromoteVarHandlesToArgsPass> {
383 public:
384 void runOnOperation() override;
385 };
386
runOnOperation()387 void PromoteVarHandlesToArgsPass::runOnOperation() {
388 ModuleOp module = getOperation();
389 MLIRContext* context = module.getContext();
390 for (auto function : module.getOps<func::FuncOp>()) {
391 if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
392
393 llvm::SmallVector<std::string, 4> var_handle_shared_names;
394 (void)PromoteVarHandlesToArguments(function, /*add_validation=*/false,
395 &var_handle_shared_names);
396
397 // Add resource names for each `tf.VarHandleOp` that were promoted to
398 // resource arguments.
399 const int var_handle_args_offset =
400 function.getNumArguments() - var_handle_shared_names.size();
401 for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
402 function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
403 kResourceNameArgAttr,
404 StringAttr::get(context, var_name_and_index.value()));
405 }
406 }
407
408 } // namespace
409
CreatePromoteResourcesToArgsPass(llvm::ArrayRef<std::string> functions)410 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass(
411 llvm::ArrayRef<std::string> functions) {
412 return std::make_unique<PromoteResourcesToArgsPass>(functions);
413 }
414
CreatePromoteVarHandlesToArgsPass()415 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
416 return std::make_unique<PromoteVarHandlesToArgsPass>();
417 }
418
419 } // namespace TF
420 } // namespace mlir
421