• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/DenseMap.h"
18 #include "llvm/ADT/PointerUnion.h"
19 #include "llvm/ADT/STLExtras.h"
20 #include "llvm/ADT/SmallSet.h"
21 #include "llvm/ADT/SmallVector.h"
22 #include "llvm/ADT/StringExtras.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
26 #include "mlir/IR/Attributes.h"  // from @llvm-project
27 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
28 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
29 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
30 #include "mlir/IR/Types.h"  // from @llvm-project
31 #include "mlir/IR/Value.h"  // from @llvm-project
32 #include "mlir/Pass/Pass.h"  // from @llvm-project
33 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
34 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
36 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
37 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
38 
39 namespace mlir {
40 namespace TF {
41 namespace {
42 
// Diagnostic used when a resource read/write refers to a block argument that
// is not owned by the function's entry block.
constexpr char kResourceFunctionMsg[] =
    "expects function level resource argument";

// Diagnostic used when a resource read/write refers to a value that is not a
// block argument at all.
constexpr char kInvalidResourceMsg[] =
    "expects resource to be a VarHandleOp or function argument";

// Argument/result attribute that records the shared name of a promoted
// `tf.VarHandleOp`.
constexpr char kResourceNameArgAttr[] =
    "tf.resource_name";
48 
49 // Checks if a function has only one block.
CheckSingleBlockFunction(FuncOp function)50 mlir::LogicalResult CheckSingleBlockFunction(FuncOp function) {
51   if (!llvm::hasSingleElement(function)) {
52     return function.emitError()
53            << "expects function '" << function.getName()
54            << "' to have 1 block, got " << function.getBlocks().size();
55   }
56   return success();
57 }
58 
59 // Collects names of users of a resource that are not `tf.ReadVariableOp` and
60 // not `tf.AssignVariableOp`.
GetCompositeResourceUserNames(Value resource)61 llvm::SmallSet<llvm::StringRef, 1> GetCompositeResourceUserNames(
62     Value resource) {
63   // SmallSet will use a vector when there is only one element and use std::set
64   // when there are more than one elements. This ensures that the operations in
65   // the error message are ordered.
66   llvm::SmallSet<llvm::StringRef, 1> composite_users;
67   for (Operation* user : resource.getUsers())
68     if (!llvm::isa<TF::ReadVariableOp, TF::AssignVariableOp>(user))
69       composite_users.insert(user->getName().getStringRef());
70 
71   return composite_users;
72 }
73 
74 // Checks that the only users of `tf.VarHandleOp` are
75 // `tf.ReadVariableOp` and `tf.AssignVariableOp`.
ValidateVarHandle(TF::VarHandleOp var_handle_op)76 mlir::LogicalResult ValidateVarHandle(TF::VarHandleOp var_handle_op) {
77   auto composite_ops = GetCompositeResourceUserNames(var_handle_op);
78   if (!composite_ops.empty())
79     return var_handle_op.emitOpError()
80            << "expects users to be 'tf.ReadVariableOp' or "
81               "'tf.AssignVariableOp', got ["
82            << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
83            << "]";
84 
85   return success();
86 }
87 
88 // Checks if resource argument has a valid resource subtype and its users are of
89 // `tf.ReadVariableOp` and `tf.AssignVariableOp` only.
ValidateResourceArgument(FuncOp function,BlockArgument resource_arg,TF::ResourceType resource_type)90 mlir::LogicalResult ValidateResourceArgument(FuncOp function,
91                                              BlockArgument resource_arg,
92                                              TF::ResourceType resource_type) {
93   if (resource_type.getSubtypes().size() != 1)
94     return function.emitError()
95            << "expects resource type of argument "
96            << resource_arg.getArgNumber() << " to have one subtype, got "
97            << resource_type;
98 
99   auto composite_ops = GetCompositeResourceUserNames(resource_arg);
100   if (!composite_ops.empty())
101     return function.emitError()
102            << "expects users of resource argument "
103            << resource_arg.getArgNumber()
104            << " to be 'tf.ReadVariableOp' or 'tf.AssignVariableOp', got ["
105            << llvm::join(composite_ops.begin(), composite_ops.end(), ", ")
106            << "]";
107 
108   return success();
109 }
110 
VariableIsInitialized(TF::VarHandleOp var_handle_op)111 bool VariableIsInitialized(TF::VarHandleOp var_handle_op) {
112   auto is_variable_initialized =
113       var_handle_op->getAttrOfType<BoolAttr>("_is_initialized");
114   // Assume variable is initialized if attribute is not set.
115   // There are paths that doesn't mark the variables. All variables
116   // that doesn't have the attribute will be promoted.
117   return !is_variable_initialized || is_variable_initialized.getValue();
118 }
119 
120 // Adds resource arguments for every unique (name) variable handle. Associated
121 // `tf.VarHandleOp` are removed from the function. Variable shared names are
122 // returned in `var_handle_shared_names` based on the ordering of added resource
123 // arguments.
PromoteVarHandlesToArguments(FuncOp function,bool add_validation,llvm::SmallVectorImpl<std::string> * var_handle_shared_names)124 mlir::LogicalResult PromoteVarHandlesToArguments(
125     FuncOp function, bool add_validation,
126     llvm::SmallVectorImpl<std::string>* var_handle_shared_names) {
127   Block& block = function.front();
128   auto func_type = function.getType();
129 
130   auto func_arg_types = llvm::to_vector<4>(func_type.getInputs());
131   llvm::SmallDenseMap<llvm::StringRef, int> var_arg_index_by_name;
132   for (auto var_handle_op :
133        llvm::make_early_inc_range(block.getOps<TF::VarHandleOp>())) {
134     if (add_validation && failed(ValidateVarHandle(var_handle_op)))
135       return failure();
136     // In the case of variables that are not initialized at graph creation
137     // then we keep them as VarHandleOps.
138     if (!VariableIsInitialized(var_handle_op)) continue;
139 
140     llvm::StringRef name = var_handle_op.shared_nameAttr().getValue();
141     auto it = var_arg_index_by_name.insert({name, func_arg_types.size()});
142     if (it.second) {
143       var_handle_shared_names->emplace_back(name);
144       auto resource_type = var_handle_op.resource().getType();
145       func_arg_types.push_back(resource_type);
146       var_handle_op.resource().replaceAllUsesWith(
147           block.addArgument(resource_type));
148     } else {
149       var_handle_op.resource().replaceAllUsesWith(
150           block.getArgument(it.first->getSecond()));
151     }
152     var_handle_op.erase();
153   }
154 
155   if (!var_handle_shared_names->empty())
156     function.setType(FunctionType::get(function.getContext(), func_arg_types,
157                                        func_type.getResults()));
158 
159   return success();
160 }
161 
162 // Records the current live value for a resource variable and whether a read or
163 // write on the variable occurred.
164 struct ResourceInfo {
165   Value live_value = nullptr;
166   bool read = false;
167   bool write = false;
168 };
169 
PromoteResourcesToArguments(FuncOp function,llvm::ArrayRef<std::string> var_handle_shared_names)170 LogicalResult PromoteResourcesToArguments(
171     FuncOp function, llvm::ArrayRef<std::string> var_handle_shared_names) {
172   Block& block = function.front();
173 
174   auto return_op = llvm::dyn_cast_or_null<ReturnOp>(block.getTerminator());
175   if (!return_op)
176     return function.emitError() << "expects function '" << function.getName()
177                                 << "' to have a MLIR ReturnOp";
178 
179   llvm::SmallVector<ResourceInfo, 4> resources(function.getNumArguments());
180   auto argument_types = llvm::to_vector<4>(function.getType().getInputs());
181   bool has_resources = false;
182   auto add_resource_argument = [&](BlockArgument arg,
183                                    TF::ResourceType resource_type) {
184     Type arg_type = resource_type.getSubtypes().front();
185     arg.setType(arg_type);
186     resources[arg.getArgNumber()].live_value = arg;
187     argument_types[arg.getArgNumber()] = arg_type;
188     has_resources = true;
189   };
190 
191   // Loop through the non `tf.VarHandleOp` resource arguments in the function,
192   // validate its uses and subtype, and store a mapping from that argument to
193   // itself as the current live value.
194   auto func_args = function.getArguments().take_front(
195       function.getNumArguments() - var_handle_shared_names.size());
196   for (BlockArgument& func_arg : func_args) {
197     auto resource_type =
198         getElementTypeOrSelf(func_arg.getType()).dyn_cast<TF::ResourceType>();
199     if (!resource_type) continue;
200     if (failed(ValidateResourceArgument(function, func_arg, resource_type)))
201       return failure();
202 
203     add_resource_argument(func_arg, resource_type);
204   }
205 
206   // Loop through `tf.VarHandleOp` resource arguments in the function and store
207   // a mapping from that argument to itself as the current live value. No
208   // validations are necessary here as these arguments were validated prior to
209   // being added.
210   auto var_handle_args =
211       function.getArguments().take_back(var_handle_shared_names.size());
212   for (BlockArgument& var_handle_arg : var_handle_args) {
213     auto resource_type =
214         getElementTypeOrSelf(var_handle_arg.getType()).cast<TF::ResourceType>();
215     add_resource_argument(var_handle_arg, resource_type);
216   }
217 
218   if (!has_resources) return success();
219 
220   // We initially assign the argument for a resource as the live value for the
221   // resource. We then walk through the operations in the function in their
222   // lexical order, to update the live value for the resource when we see a
223   // store to the resource and replace reads of the resource with uses of its
224   // live value.
225   for (Operation& op : llvm::make_early_inc_range(block)) {
226     if (auto read_op = llvm::dyn_cast<TF::ReadVariableOp>(&op)) {
227       if (auto func_arg = read_op.resource().dyn_cast<BlockArgument>()) {
228         if (func_arg.getOwner() != &block)
229           return read_op.emitOpError(kResourceFunctionMsg);
230 
231         ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
232         resource_info.read = true;
233         read_op.value().replaceAllUsesWith(resource_info.live_value);
234       } else {
235         return read_op.emitOpError(kInvalidResourceMsg);
236       }
237 
238       read_op.erase();
239     } else if (auto write_op = llvm::dyn_cast<TF::AssignVariableOp>(&op)) {
240       if (auto func_arg = write_op.resource().dyn_cast<BlockArgument>()) {
241         if (func_arg.getOwner() != &block)
242           return write_op.emitOpError(kResourceFunctionMsg);
243 
244         ResourceInfo& resource_info = resources[func_arg.getArgNumber()];
245         resource_info.write = true;
246         resource_info.live_value = write_op.value();
247       } else {
248         return read_op.emitOpError(kInvalidResourceMsg);
249       }
250 
251       write_op.erase();
252     }
253   }
254 
255   const int64_t num_results_before = function.getNumResults();
256   auto return_operands = llvm::to_vector<4>(return_op.getOperands());
257   auto result_types = llvm::to_vector<4>(return_op.getOperandTypes());
258   llvm::SmallVector<std::pair<int64_t, llvm::StringRef>, 4>
259       output_only_resources;
260   llvm::SmallVector<std::pair<int64_t, int64_t>, 4> input_output_alias;
261 
262   // Collect new return values for variable writes and either (a) output-only
263   // resource attributes (if the resource is not promoted to an argument) or (b)
264   // mapping from resource input index to output alias (if the resource has been
265   // promoted to an argument). Resource arguments that were originally
266   // `tf.VarHandleOp` but not read are collected and then removed.
267   OpBuilder builder(return_op);
268   const int var_handles_start_idx =
269       function.getNumArguments() - var_handle_shared_names.size();
270   int new_argument_index = 0;
271   llvm::SmallVector<int, 4> argument_indices_to_remove;
272   for (auto resource_and_index : llvm::enumerate(resources)) {
273     const auto& resource = resource_and_index.value();
274     if (!resource.live_value) {
275       // Ignore non resource arguments.
276       ++new_argument_index;
277       continue;
278     }
279 
280     const int64_t index = resource_and_index.index();
281     const bool is_var_handle = index >= var_handles_start_idx;
282     if (resource.write) {
283       if (!is_var_handle || resource.read) {
284         input_output_alias.push_back(
285             {new_argument_index, return_operands.size()});
286       } else if (is_var_handle) {
287         output_only_resources.push_back(
288             {return_operands.size(),
289              var_handle_shared_names[index - var_handles_start_idx]});
290       }
291       return_operands.push_back(resource.live_value);
292       result_types.push_back(resource.live_value.getType());
293     }
294 
295     if (is_var_handle && !resource.read) {
296       assert(block.getArgument(index).getUses().empty());
297       argument_indices_to_remove.push_back(index);
298     } else {
299       if (is_var_handle) {
300         // Add resource_name attribute to VarHandleOp read.
301         function.setArgAttr(
302             new_argument_index, kResourceNameArgAttr,
303             builder.getStringAttr(
304                 var_handle_shared_names[index - var_handles_start_idx]));
305       }
306       ++new_argument_index;
307     }
308   }
309 
310   // Remove unread var handle arguments.
311   for (int argument_index_to_remove :
312        llvm::reverse(argument_indices_to_remove)) {
313     block.eraseArgument(argument_index_to_remove);
314     argument_types.erase(argument_types.begin() + argument_index_to_remove);
315   }
316 
317   // Rewrite return if there are variable writes.
318   const int return_operands_size = return_operands.size();
319   if (return_operands_size > num_results_before) {
320     builder.create<ReturnOp>(return_op.getLoc(), return_operands);
321     return_op.erase();
322   }
323 
324   // Update function argument and result types with new resource subtypes.
325   function.setType(builder.getFunctionType(argument_types, result_types));
326 
327   // Add resource_name attribute to the output for the resources.
328   for (auto& resource : output_only_resources)
329     function.setResultAttr(resource.first, kResourceNameArgAttr,
330                            builder.getStringAttr(resource.second));
331 
332   // Add aliasing_output attribute to the input argument for the resources that
333   // are updated by the function.
334   for (auto& input_output : input_output_alias)
335     function.setArgAttr(input_output.first, "tf.aliasing_output",
336                         builder.getI64IntegerAttr(input_output.second));
337 
338   return success();
339 }
340 
341 class PromoteResourcesToArgsPass
342     : public PromoteResourcesToArgsPassBase<PromoteResourcesToArgsPass> {
343  public:
344   void runOnOperation() override;
345 };
346 
runOnOperation()347 void PromoteResourcesToArgsPass::runOnOperation() {
348   ModuleOp module = getOperation();
349   FuncOp main_func = module.lookupSymbol<FuncOp>("main");
350   if (!main_func) return;
351 
352   // This routine should only be called when control flow operations are still
353   // represented with TF IfOp and WhileOp operations. In this case, there should
354   // be only one basic blocks in the MLIR representation.
355   if (failed(CheckSingleBlockFunction(main_func))) return signalPassFailure();
356 
357   llvm::SmallVector<std::string, 4> var_handle_shared_names;
358   if (failed(ResourceLiftingForFunctionalControlFlow(main_func)) ||
359       failed(PromoteVarHandlesToArguments(main_func, /*add_validation=*/true,
360                                           &var_handle_shared_names)) ||
361       failed(PromoteResourcesToArguments(main_func, var_handle_shared_names)))
362     return signalPassFailure();
363 }
364 
365 class PromoteVarHandlesToArgsPass
366     : public PromoteVarHandlesToArgsPassBase<PromoteVarHandlesToArgsPass> {
367  public:
368   void runOnOperation() override;
369 };
370 
runOnOperation()371 void PromoteVarHandlesToArgsPass::runOnOperation() {
372   ModuleOp module = getOperation();
373   MLIRContext* context = module.getContext();
374   for (auto function : module.getOps<FuncOp>()) {
375     if (failed(CheckSingleBlockFunction(function))) return signalPassFailure();
376 
377     llvm::SmallVector<std::string, 4> var_handle_shared_names;
378     (void)PromoteVarHandlesToArguments(function, /*add_validation=*/false,
379                                        &var_handle_shared_names);
380 
381     // Add resource names for each `tf.VarHandleOp` that were promoted to
382     // resource arguments.
383     const int var_handle_args_offset =
384         function.getNumArguments() - var_handle_shared_names.size();
385     for (auto var_name_and_index : llvm::enumerate(var_handle_shared_names))
386       function.setArgAttr(var_name_and_index.index() + var_handle_args_offset,
387                           kResourceNameArgAttr,
388                           StringAttr::get(context, var_name_and_index.value()));
389   }
390 }
391 
392 }  // namespace
393 
CreatePromoteResourcesToArgsPass()394 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteResourcesToArgsPass() {
395   return std::make_unique<PromoteResourcesToArgsPass>();
396 }
397 
CreatePromoteVarHandlesToArgsPass()398 std::unique_ptr<OperationPass<ModuleOp>> CreatePromoteVarHandlesToArgsPass() {
399   return std::make_unique<PromoteVarHandlesToArgsPass>();
400 }
401 
402 }  // namespace TF
403 }  // namespace mlir
404