• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/PointerUnion.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Types.h"  // from @llvm-project
31 #include "mlir/IR/Value.h"  // from @llvm-project
32 #include "mlir/Pass/Pass.h"  // from @llvm-project
33 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
38 
39 namespace mlir {
40 namespace TF {
41 namespace {
42 
// Error reported when a resource read/assign uses a block argument that does
// not belong to the function's own block.
constexpr char kResourceFunctionMsg[] =
    "expects function level resource argument";

// Error reported when a resource read/assign uses a value that is not a block
// argument at all.
constexpr char kInvalidResourceMsg[] =
    "expects resource to be a VarHandleOp or function argument";

// Argument/result attribute carrying the shared name of a promoted variable.
constexpr char kResourceNameArgAttr[] = "tf.resource_name";
48 
49 // Checks if a function has only one block.
CheckSingleBlockFunction(func::FuncOp function)50 mlir::LogicalResult CheckSingleBlockFunction(func::FuncOp function) {
51   if (!llvm::hasSingleElement(function)) {
52     return function.emitError()
53            << "expects function '" << function.getName()
54            << "' to have 1 block, got " << function.getBlocks().size();
55   }
56   return success();
57 }
58 
59 // Collects names of users of a resource that are not `tf.ReadVariableOp` and
60 // not `tf.AssignVariableOp`.
GetCompositeResourceUserNames(Value resource)61 llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
62     Value resource) {
63   // SmallSet will use a vector when there is only one element and use std::set
64   // when there are more than one elements. This ensures that the operations in
65   // the error message are ordered.
66   llvm::SmallSet<llvm::StringRef, 1> composite_users;
67   for (Operation* user : resource.getUsers())
68     if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
69       composite_users.insert(user->getName().getStringRef());
70 
71   return composite_users;
72 }
73 
74 // Checks that the only users of `tf.VarHandleOp` are
75 // `tf.ReadVariableOp` and `tf.AssignVariableOp`.
ValidateVarHandle(TF::VarHandleOp var_handle_op)76 mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
77   auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
78   if (!composite_ops.empty())
79     return var_handle_op.emitOpError()
80            << "expects users to be 'tf.ReadVariableOp' or "
81               "'tf.AssignVariableOp', got ["
82            << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
83            << "]";
84 
85   return success();
86 }
87 
88 // Checks if resource argument has a valid resource subtype and its users are of
89 // `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
ValidateResourceArgument(func::FuncOp function,BlockArgument resource_arg,TF::ResourceType resource_type)90 mlir::LogicalResult ValidateResourceArgument(func::FuncOp function,
91                                              BlockArgument resource_arg,
92                                              TF::ResourceType resource_type) {
93   if (resource_type.getSubtypes().size() != 1)
94     return function.emitError()
95            << "expects resource type of argument "
96            << resource_arg.getArgNumber() << " to have one subtype, got "
97            << resource_type;
98 
99   auto composite_ops = GetCompositeResourceUserNames(resource_arg);
100   if (!composite_ops.empty())
101     return function.emitError()
102            << "expects users of resource argument "
103            << resource_arg.getArgNumber()
104            << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got ["
105            << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
106            << "]";
107 
108   return success();
109 }
110 
VariableIsInitialized(TF::VarHandleOp var_handle_op)111 bool VariableIsInitialized(TF::VarHandleOp var_handle_op) {
112   auto is_variable_initialized =
113       var_handle_op->getAttrOfType<BoolAttr>("_is_initialized");
114   // Assume variable is initialized if attribute is not set.
115   // There are paths that doesn't mark the variables. All variables
116   // that doesn't have the attribute will be promoted.
117   return !is_variable_initialized || is_variable_initialized.getValue();
118 }
119 
120 // Adds resource arguments for every unique (name) variable handle. Associated
121 // `tf.VarHandleOp` are removed from the function. Variable shared names are
122 // returned in `var_handle_shared_names` based on the ordering of added resource
123 // arguments.
PromoteVarHandlesToArguments(func::FuncOp function,bool add_validation,llvm::SmallVectorImpl<std::string> * var_handle_shared_names)124 mlir::LogicalResult PromoteVarHandlesToArguments(
125     func::FuncOp function, bool add_validation,
126     llvm::SmallVectorImpl<std::string>* var_handle_shared_names) {
127   Block& block = function.front();
128   auto func_type = function.getFunctionType();
129 
130   auto func_arg_types = llvm::to_vector<4>(func_type.getInputs());
131   llvm::SmallDenseMap<llvm::StringRef, int> var_arg_index_by_name;
132   for (auto var_handle_op :
133        llvm::make_early_inc_range(block.getOps<TF::VarHandleOp>())) {
134     if (add_validation && failed(ValidateVarHandle(var_handle_op)))
135       return failure();
136     // In the case of variables that are not initialized at graph creation
137     // then we keep them as VarHandleOps.
138     if (!VariableIsInitialized(var_handle_op)) continue;
139 
140     llvm::StringRef name = var_handle_op.shared_nameAttr().getValue();
141     auto it = var_arg_index_by_name.insert({name, func_arg_types.size()});
142     if (it.second) {
143       var_handle_shared_names->emplace_back(name);
144       auto resource_type = var_handle_op.resource().getType();
145       func_arg_types.push_back(resource_type);
146       var_handle_op.resource().replaceAllUsesWith(
147           block.addArgument(resource_type, var_handle_op.getLoc()));
148     } else {
149       var_handle_op.resource().replaceAllUsesWith(
150           block.getArgument(it.first->getSecond()));
151     }
152     var_handle_op.erase();
153   }
154 
155   if (!var_handle_shared_names->empty())
156     function.setType(FunctionType::get(function.getContext(), func_arg_types,
157                                        func_type.getResults()));
158 
159   return success();
160 }
161 
162 // Records the current live value for a resource variable and whether a read or
163 // write on the variable occurred.
164 struct ResourceInfo {
165   Value live_value = nullptr;
166   bool read = false;
167   bool write = false;
168 };
169 
PromoteResourcesToArguments(func::FuncOp function,llvm::ArrayRef<std::string> var_handle_shared_names)170 LogicalResult PromoteResourcesToArguments(
171     func::FuncOp function,
172     llvm::ArrayRef<std::string> var_handle_shared_names) {
173   Block& block = function.front();
174 
175   auto return_op =
176       llvm::dyn_cast_or_null<func::ReturnOp>(block.getTerminator());
177   if (!return_op)
178     return function.emitError() << "expects function '" << function.getName()
179                                 << "' to have a MLIR ReturnOp";
180 
181   llvm::SmallVector<ResourceInfo, 4> resources(function.getNumArguments());
182   auto argument_types =
183       llvm::to_vector<4>(function.getFunctionType().getInputs());
184   bool has_resources = false;
185   auto add_resource_argument = [&](BlockArgument arg,
186                                    TF::ResourceType resource_type) {
187     Type arg_type = resource_type.getSubtypes().front();
188     arg.setType(arg_type);
189     resources[arg.getArgNumber()].live_value = arg;
190     argument_types[arg.getArgNumber()] = arg_type;
191     has_resources = true;
192   };
193 
194   // Loop through the non `tf.VarHandleOp` resource arguments in the function,
195   // validate its uses and subtype, and store a mapping from that argument to
196   // itself as the current live value.
197   auto func_args = function.getArguments().take_front(
198       function.getNumArguments() - var_handle_shared_names.size());
199   for (BlockArgument& func_arg : func_args) {
200     auto resource_type =
201         getElementTypeOrSelf(func_arg.getType()).dyn_cast<TF::ResourceType>();
202     if (!resource_type) continue;
203     if (failed(ValidateResourceArgument(function, func_arg, resource_type)))
204       return failure();
205 
206     add_resource_argument(func_arg, resource_type);
207   }
208 
209   // Loop through `tf.VarHandleOp` resource arguments in the function and store
210   // a mapping from that argument to itself as the current live value. No
211   // validations are necessary here as these arguments were validated prior to
212   // being added.
213   auto var_handle_args =
214       function.getArguments().take_back(var_handle_shared_names.size());
215   for (BlockArgument& var_handle_arg : var_handle_args) {
216     auto resource_type =
217         getElementTypeOrSelf(var_handle_arg.getType()).cast<TF::ResourceType>();
218     add_resource_argument(var_handle_arg, resource_type);
219   }
220 
221   if (!has_resources) return success();
222 
223   // We initially assign the argument for a resource as the live value for the
224   // resource. We then walk through the operations in the function in their
225   // lexical order, to update the live value for the resource when we see a
226   // store to the resource and replace reads of the resource with uses of its
227   // live value.
228   for (Operation& op : llvm::make_early_inc_range(block)) {
229     if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
230       if (auto func_arg = read_op.resource().dyn_cast<BlockArgument>()) {
231         if (func_arg.getOwner() != &block)
232           return read_op.emitOpError(kResourceFunctionMsg);
233 
234         ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
235         resource_info.read = true;
236         read_op.value().replaceAllUsesWith(resource_info.live_value);
237       } else {
238         return read_op.emitOpError(kInvalidResourceMsg);
239       }
240 
241       read_op.erase();
242     } else if (auto write_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
243       if (auto func_arg = write_op.resource().dyn_cast<BlockArgument>()) {
244         if (func_arg.getOwner() != &block)
245           return write_op.emitOpError(kResourceFunctionMsg);
246 
247         ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
248         resource_info.write = true;
249         resource_info.live_value = write_op.value();
250       } else {
251         return read_op.emitOpError(kInvalidResourceMsg);
252       }
253 
254       write_op.erase();
255     }
256   }
257 
258   const int64_t num_results_before = function.getNumResults();
259   auto return_operands = llvm::to_vector<4>(return_op.getOperands());
260   auto result_types = llvm::to_vector<4>(return_op.getOperandTypes());
261   llvm::SmallVector<std::pair<int64_t, llvm::StringRef>, 4>
262       output_only_resources;
263   llvm::SmallVector<std::pair<int64_t, int64_t>, 4> input_output_alias;
264 
265   // Collect new return values for variable writes and either (a) output-only
266   // resource attributes (if the resource is not promoted to an argument) or (b)
267   // mapping from resource input index to output alias (if the resource has been
268   // promoted to an argument). Resource arguments that were originally
269   // `tf.VarHandleOp` but not read are collected and then removed.
270   OpBuilder builder(return_op);
271   const int var_handles_start_idx =
272       function.getNumArguments() - var_handle_shared_names.size();
273   int new_argument_index = 0;
274   llvm::SmallVector<int, 4> argument_indices_to_remove;
275   for (auto resource_and_index : llvm::enumerate(resources)) {
276     const auto& resource = resource_and_index.value();
277     if (!resource.live_value) {
278       // Ignore non resource arguments.
279       ++new_argument_index;
280       continue;
281     }
282 
283     const int64_t index = resource_and_index.index();
284     const bool is_var_handle = index >= var_handles_start_idx;
285     if (resource.write) {
286       if (!is_var_handle || resource.read) {
287         input_output_alias.push_back(
288             {new_argument_index, return_operands.size()});
289       } else if (is_var_handle) {
290         output_only_resources.push_back(
291             {return_operands.size(),
292              var_handle_shared_names[index - var_handles_start_idx]});
293       }
294       return_operands.push_back(resource.live_value);
295       result_types.push_back(resource.live_value.getType());
296     }
297 
298     if (is_var_handle && !resource.read) {
299       assert(block.getArgument(index).getUses().empty());
300       argument_indices_to_remove.push_back(index);
301     } else {
302       if (is_var_handle) {
303         // Add resource_name attribute to VarHandleOp read.
304         function.setArgAttr(
305             new_argument_index, kResourceNameArgAttr,
306             builder.getStringAttr(
307                 var_handle_shared_names[index - var_handles_start_idx]));
308       }
309       ++new_argument_index;
310     }
311   }
312 
313   // Remove unread var handle arguments.
314   for (int argument_index_to_remove :
315        llvm::reverse(argument_indices_to_remove)) {
316     block.eraseArgument(argument_index_to_remove);
317     argument_types.erase(argument_types.begin() + argument_index_to_remove);
318   }
319 
320   // Rewrite return if there are variable writes.
321   const int return_operands_size = return_operands.size();
322   if (return_operands_size > num_results_before) {
323     builder.create<func::ReturnOp>(return_op.getLoc(), return_operands);
324     return_op.erase();
325   }
326 
327   // Update function argument and result types with new resource subtypes.
328   function.setType(builder.getFunctionType(argument_types, result_types));
329 
330   // Add resource_name attribute to the output for the resources.
331   for (auto& resource : output_only_resources)
332     function.setResultAttr(resource.first, kResourceNameArgAttr,
333                            builder.getStringAttr(resource.second));
334 
335   // Add aliasing_output attribute to the input argument for the resources that
336   // are updated by the function.
337   for (auto& input_output : input_output_alias)
338     function.setArgAttr(input_output.first, "tf.aliasing_output",
339                         builder.getI64IntegerAttr(input_output.second));
340 
341   return success();
342 }
343 
344 class PromoteResourcesToArgsPass
345     : public PromoteResourcesToArgsPassBase<PromoteResourcesToArgsPass> {
346  public:
347   PromoteResourcesToArgsPass() = default;
348   explicit PromoteResourcesToArgsPass(llvm::ArrayRef<std::string> functions);
349   void runOnOperation() override;
350 };
351 
PromoteResourcesToArgsPass(llvm::ArrayRef<std::string> functions)352 PromoteResourcesToArgsPass::PromoteResourcesToArgsPass(
353     llvm::ArrayRef<std::string> functions) {
354   functions_ = functions;
355 }
356 
runOnOperation()357 void PromoteResourcesToArgsPass::runOnOperation() {
358   ModuleOp module = getOperation();
359   if (llvm::size(functions_) == 0) {
360     functions_ = {"main"};
361   }
362   SymbolTable symbolTable(module);
363   for (const std::string& f : functions_) {
364     func::FuncOp func = symbolTable.lookup<func::FuncOp>(f);
365     if (!func) continue;
366 
367     // This routine should only be called when control flow operations are still
368     // represented with TF IfOp and WhileOp operations. In this case, there
369     // should be only one basic blocks in the MLIR representation.
370     if (failed(CheckSingleBlockFunction(func))) return signalPassFailure();
371 
372     llvm::SmallVector<std::string, 4> var_handle_shared_names;
373     if (failed(ResourceLiftingForFunctionalControlFlow(func)) ||
374         failed(PromoteVarHandlesToArguments(func, /*add_validation=*/true,
375                                             &var_handle_shared_names)) ||
376         failed(PromoteResourcesToArguments(func, var_handle_shared_names)))
377       return signalPassFailure();
378   }
379 }
380 
381 class PromoteVarHandlesToArgsPass
382     : public PromoteVarHandlesToArgsPassBase<PromoteVarHandlesToArgsPass> {
383  public:
384   void runOnOperation() override;
385 };
386 
runOnOperation()387 void PromoteVarHandlesToArgsPass::runOnOperation() {
388   ModuleOp module = getOperation();
389   MLIRContext* context = module.getContext();
390   for (auto function : module.getOps<func::FuncOp>()) {
391     if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
392 
393     llvm::SmallVector<std::string, 4> var_handle_shared_names;
394     (void)PromoteVarHandlesToArguments(function, /*add_validation=*/false,
395                                        &var_handle_shared_names);
396 
397     // Add resource names for each `tf.VarHandleOp` that were promoted to
398     // resource arguments.
399     const int var_handle_args_offset =
400         function.getNumArguments() - var_handle_shared_names.size();
401     for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
402       function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
403                           kResourceNameArgAttr,
404                           StringAttr::get(context, var_name_and_index.value()));
405   }
406 }
407 
408 }  // namespace
409 
CreatePromoteResourcesToArgsPass(llvm::ArrayRef<std::string> functions)410 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass(
411     llvm::ArrayRef<std::string> functions) {
412   return std::make_unique<PromoteResourcesToArgsPass>(functions);
413 }
414 
CreatePromoteVarHandlesToArgsPass()415 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
416   return std::make_unique<PromoteVarHandlesToArgsPass>();
417 }
418 
419 }  // namespace TF
420 }  // namespace mlir
421