1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/ArrayRef.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/ADT/SmallVector.h"
19 #include "llvm/ADT/StringExtras.h"
20 #include "llvm/ADT/StringRef.h"
21 #include "llvm/Support/Casting.h"
22 #include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
23 #include "mlir/IR/Attributes.h" // from @llvm-project
24 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
25 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
26 #include "mlir/IR/Types.h" // from @llvm-project
27 #include "mlir/IR/Value.h" // from @llvm-project
28 #include "mlir/Pass/Pass.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
31 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
32
33 namespace mlir {
34 namespace TF {
35 namespace {
36
37 // Location attribute.
38 constexpr StringRef kClassAttr = "_class";
39 constexpr StringRef kSharedNameAttr = "shared_name";
40 constexpr StringRef kLocationPrefix = "loc:@";
41
42 // A pass that converts readonly reference variables to the corresponding
43 // resource variables.
44 //
45 // It converts (VariableV2 -> Identity) to (VarHandle -> ReadVariable).
46 //
47 // For the background, this pass is a part of hoisting VariableV2 ops by
48 // re-using the pipeline for hoisting (VarHandle -> ReadVariable) cases, which
49 // can be done by the following passes:
50 // - Capturing resource values into global tensors (importing saved model).
51 // - Promoting VarHandle ops to function input/outputs.
52 // - Freezing global tensor pass.
53 //
54 // This path assumes that all the VariableV2 ops is read-only via verifying the
55 // heuristic method that assumes that all the users of them is Identity op,
56 // fed directly.
57 class ConvertReadonlyReferenceVariablesToResourceVariablesPass
58 : public PassWrapper<
59 ConvertReadonlyReferenceVariablesToResourceVariablesPass,
60 FunctionPass> {
61 public:
getArgument() const62 StringRef getArgument() const final {
63 return "tf-readonly-references-to-resources";
64 }
65
getDescription() const66 StringRef getDescription() const final {
67 return "Convert readonly reference variables to resource variables.";
68 }
69
70 void runOnFunction() override;
71 };
72
73 // Parse node name from "_class" or "shared_name" attributes.
GetNodeNameFromClassAttrOrSharedNameAttr(Operation * op)74 StringRef GetNodeNameFromClassAttrOrSharedNameAttr(Operation *op) {
75 // Parse node name from the `shared_name` attribute first. The variable v2 op
76 // relies on the share name to look up from the TensorFlow's resource manager.
77 StringAttr shared_name_attr = op->getAttrOfType<StringAttr>(kSharedNameAttr);
78 if (shared_name_attr) {
79 auto shared_name = StringRef(shared_name_attr.getValue());
80 if (!shared_name.empty()) {
81 return shared_name;
82 }
83 }
84 // Attempt to parse "_class" attribute if there is no "shared_name"
85 // attribute.
86 ArrayAttr classes_attr = op->getAttrOfType<ArrayAttr>(kClassAttr);
87 if (!classes_attr) {
88 // Attempt to parse "_class" from the IdentityOp that follows VariableV2.
89 // For read-only reference variables, IdentityOp should be the only user of
90 // VariableV2.
91 auto identity_op = op->getUsers().begin();
92 classes_attr = identity_op->getAttrOfType<ArrayAttr>(kClassAttr);
93 if (!classes_attr) {
94 op->emitOpError() << "has no '_class' and 'shared_name' attributes";
95 return StringRef();
96 }
97 }
98
99 StringRef result;
100 for (Attribute class_attr : classes_attr) {
101 StringRef node_name = class_attr.cast<StringAttr>().getValue();
102 if (!node_name.startswith(kLocationPrefix)) {
103 continue;
104 }
105 if (!result.empty()) {
106 // Invalid case since there are multiple loc:@ attributes.
107 op->emitOpError()
108 << "expects only one named location in '_class' attribute, but got "
109 << classes_attr;
110 return StringRef();
111 }
112 result = node_name.drop_front(kLocationPrefix.size());
113 }
114 if (result.empty()) {
115 op->emitOpError() << "expects variable name in '_class' attribute, but got "
116 << classes_attr;
117 }
118 return result;
119 }
120
runOnFunction()121 void ConvertReadonlyReferenceVariablesToResourceVariablesPass::runOnFunction() {
122 FuncOp func = getFunction();
123
124 OpBuilder builder(func.getContext());
125 SmallVector<VariableV2Op, 4> variable_v2s_to_replace;
126
127 // Checks all the VariableV2 ops is read-only via verifying the heuristic
128 // method that assumes that all the users of them is Identity op, feeded
129 // directly.
130 auto read_only_vars_fn = [&variable_v2s_to_replace](
131 VariableV2Op variable_v2_op) {
132 if (variable_v2_op.getResult().use_empty()) {
133 // Erase the op when there is no user.
134 variable_v2_op.erase();
135 return mlir::WalkResult::advance();
136 }
137 if (!all_of(variable_v2_op.getResult().getUsers(), [&variable_v2_op](
138 Operation *user) {
139 if (!isa<IdentityOp>(user)) {
140 variable_v2_op.emitOpError()
141 << "expects all users to be 'tf.Identity', but got user "
142 << user->getName();
143 return false;
144 }
145 return true;
146 })) {
147 return mlir::WalkResult::interrupt();
148 }
149 variable_v2s_to_replace.push_back(variable_v2_op);
150 return mlir::WalkResult::advance();
151 };
152
153 WalkResult walk_res = func.walk(read_only_vars_fn);
154 if (walk_res.wasInterrupted()) return signalPassFailure();
155
156 for (VariableV2Op variable_v2_op : variable_v2s_to_replace) {
157 builder.setInsertionPoint(variable_v2_op);
158 ShapedType shaped_type =
159 variable_v2_op.getResult().getType().cast<ShapedType>();
160 TensorType tensor_type = DropRefType(shaped_type).cast<TensorType>();
161 StringAttr device_attr =
162 variable_v2_op->getAttrOfType<StringAttr>("device");
163 if (!device_attr) device_attr = builder.getStringAttr("");
164 StringRef variable_name =
165 GetNodeNameFromClassAttrOrSharedNameAttr(variable_v2_op);
166 if (variable_name.empty()) {
167 return signalPassFailure();
168 }
169 VarHandleOp var_handle_op = builder.create<VarHandleOp>(
170 variable_v2_op.getLoc(),
171 ArrayRef<Type>{RankedTensorType::get(
172 {}, TF::ResourceType::get(ArrayRef<TensorType>{tensor_type},
173 builder.getContext()))},
174 ArrayRef<Value>{},
175 ArrayRef<NamedAttribute>{
176 builder.getNamedAttr("device", device_attr),
177 builder.getNamedAttr("container", variable_v2_op.containerAttr()),
178 builder.getNamedAttr("shared_name",
179 builder.getStringAttr(variable_name))});
180 for (Operation *user :
181 make_early_inc_range(variable_v2_op.getResult().getUsers())) {
182 builder.setInsertionPoint(user);
183 ReadVariableOp read_variable_op = builder.create<ReadVariableOp>(
184 user->getLoc(), ArrayRef<Type>{tensor_type},
185 ArrayRef<Value>{var_handle_op});
186 user->getResult(0).replaceAllUsesWith(read_variable_op.getResult());
187 user->erase();
188 }
189 variable_v2_op.erase();
190 }
191 }
192
193 } // namespace
194
195 std::unique_ptr<OperationPass<FuncOp>>
CreateConvertReadonlyReferenceVariablesToResourceVariablesPass()196 CreateConvertReadonlyReferenceVariablesToResourceVariablesPass() {
197 return std::make_unique<
198 ConvertReadonlyReferenceVariablesToResourceVariablesPass>();
199 }
200
201 static PassRegistration<
202 ConvertReadonlyReferenceVariablesToResourceVariablesPass>
203 pass;
204
205 } // namespace TF
206
207 } // namespace mlir
208