• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h"  // from @llvm-project
23 #include "mlir/IR/Attributes.h"  // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
26 #include "mlir/IR/Types.h"  // from @llvm-project
27 #include "mlir/IR/Value.h"  // from @llvm-project
28 #include "mlir/Pass/Pass.h"  // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32 
33 namespace mlir {
34 namespace TF {
35 namespace {
36 
37 // Location attribute.
38 constexpr StringRef kClassAttr = "_class";
39 constexpr StringRef kSharedNameAttr = "shared_name";
40 constexpr StringRef kLocationPrefix = "loc:@";
41 
42 // A pass that converts readonly reference variables to the corresponding
43 // resource variables.
44 //
45 // It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
46 //
47 // For the background, this pass is a part of hoisting VariableV2 ops by
48 // re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
49 //  can be done by the following passes:
50 //  - Capturing resource values into global tensors (importing saved model).
51 //  - Promoting VarHandle ops to function input/outputs.
52 //  - Freezing global tensor pass.
53 //
54 // This path assumes that all the VariableV2 ops is read-only via verifying the
55 // heuristic method that assumes that all the users of them is Identity op,
56 // fed directly.
57 class ConvertReadonlyReferenceVariablesToResourceVariablesPass
58     : public PassWrapper<
59           ConvertReadonlyReferenceVariablesToResourceVariablesPass,
60           FunctionPass> {
61  public:
getArgument() const62   StringRef getArgument() const final {
63     return "tf-readonly-references-to-resources";
64   }
65 
getDescription() const66   StringRef getDescription() const final {
67     return "Convert readonly reference variables to resource variables.";
68   }
69 
70   void runOnFunction() override;
71 };
72 
73 // Parse node name from "_class" or "shared_name" attributes.
GetNodeNameFromClassAttrOrSharedNameAttr(Operation * op)74 StringRef GetNodeNameFromClassAttrOrSharedNameAttr(Operation *op) {
75   // Parse node name from the `shared_name` attribute first. The variable v2 op
76   // relies on the share name to look up from the TensorFlow's resource manager.
77   StringAttr shared_name_attr = op->getAttrOfType<StringAttr>(kSharedNameAttr);
78   if (shared_name_attr) {
79     auto shared_name = StringRef(shared_name_attr.getValue());
80     if (!shared_name.empty()) {
81       return shared_name;
82     }
83   }
84   // Attempt to parse "_class" attribute if there is no "shared_name"
85   // attribute.
86   ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
87   if (!classes_attr) {
88     // Attempt to parse "_class" from the IdentityOp that follows VariableV2.
89     // For read-only reference variables, IdentityOp should be the only user of
90     // VariableV2.
91     auto identity_op = op->getUsers().begin();
92     classes_attr = identity_op->getAttrOfType<ArrayAttr>(kClassAttr);
93     if (!classes_attr) {
94       op->emitOpError() << "has no '_class' and 'shared_name' attributes";
95       return StringRef();
96     }
97   }
98 
99   StringRef result;
100   for (Attribute class_attr : classes_attr) {
101     StringRef node_name = class_attr.cast<StringAttr>().getValue();
102     if (!node_name.startswith(kLocationPrefix)) {
103       continue;
104     }
105     if (!result.empty()) {
106       // Invalid case since there are multiple loc:@ attributes.
107       op->emitOpError()
108           << "expects only one named location in '_class' attribute, but got "
109           << classes_attr;
110       return StringRef();
111     }
112     result = node_name.drop_front(kLocationPrefix.size());
113   }
114   if (result.empty()) {
115     op->emitOpError() << "expects variable name in '_class' attribute, but got "
116                       << classes_attr;
117   }
118   return result;
119 }
120 
runOnFunction()121 void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
122   FuncOp func = getFunction();
123 
124   OpBuilder builder(func.getContext());
125   SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
126 
127   // Checks all the VariableV2 ops is read-only via verifying the heuristic
128   // method that assumes that all the users of them is Identity op, feeded
129   // directly.
130   auto read_only_vars_fn = [&variable_v2s_to_replace](
131                                VariableV2Op variable_v2_op) {
132     if (variable_v2_op.getResult().use_empty()) {
133       // Erase the op when there is no user.
134       variable_v2_op.erase();
135       return mlir::WalkResult::advance();
136     }
137     if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
138                                                            Operation *user) {
139           if (!isa<IdentityOp>(user)) {
140             variable_v2_op.emitOpError()
141                 << "expects all users to be 'tf.Identity', but got user "
142                 << user->getName();
143             return false;
144           }
145           return true;
146         })) {
147       return mlir::WalkResult::interrupt();
148     }
149     variable_v2s_to_replace.push_back(variable_v2_op);
150     return mlir::WalkResult::advance();
151   };
152 
153   WalkResult walk_res = func.walk(read_only_vars_fn);
154   if (walk_res.wasInterrupted()) return signalPassFailure();
155 
156   for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
157     builder.setInsertionPoint(variable_v2_op);
158     ShapedType shaped_type =
159         variable_v2_op.getResult().getType().cast<ShapedType>();
160     TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
161     StringAttr device_attr =
162         variable_v2_op->getAttrOfType<StringAttr>("device");
163     if (!device_attr) device_attr = builder.getStringAttr("");
164     StringRef variable_name =
165         GetNodeNameFromClassAttrOrSharedNameAttr(variable_v2_op);
166     if (variable_name.empty()) {
167       return signalPassFailure();
168     }
169     VarHandleOp var_handle_op = builder.create<VarHandleOp>(
170         variable_v2_op.getLoc(),
171         ArrayRef<Type>{RankedTensorType::get(
172             {}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
173                                       builder.getContext()))},
174         ArrayRef<Value>{},
175         ArrayRef<NamedAttribute>{
176             builder.getNamedAttr("device", device_attr),
177             builder.getNamedAttr("container", variable_v2_op.containerAttr()),
178             builder.getNamedAttr("shared_name",
179                                  builder.getStringAttr(variable_name))});
180     for (Operation *user :
181          make_early_inc_range(variable_v2_op.getResult().getUsers())) {
182       builder.setInsertionPoint(user);
183       ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
184           user->getLoc(), ArrayRef<Type>{tensor_type},
185           ArrayRef<Value>{var_handle_op});
186       user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
187       user->erase();
188     }
189     variable_v2_op.erase();
190   }
191 }
192 
193 }  // namespace
194 
195 std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()196 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
197   return std::make_unique<
198       ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
199 }
200 
201 static PassRegistration<
202     ConvertReadonlyReferenceVariablesToResourceVariablesPass>
203     pass;
204 
205 }  // namespace TF
206 
207 }  // namespace mlir
208