1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
17
18 #include "llvm/ADT/DenseMap.h"
19 #include "llvm/ADT/DenseSet.h"
20 #include "llvm/ADT/STLExtras.h"
21 #include "llvm/ADT/StringRef.h"
22 #include "llvm/ADT/Twine.h"
23 #include "llvm/Support/Casting.h"
24 #include "llvm/Support/raw_ostream.h"
25 #include "mlir/IR/Attributes.h" // from @llvm-project
26 #include "mlir/IR/Builders.h" // from @llvm-project
27 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
28 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
29 #include "mlir/IR/Identifier.h" // from @llvm-project
30 #include "mlir/IR/OpImplementation.h" // from @llvm-project
31 #include "mlir/IR/PatternMatch.h" // from @llvm-project
32 #include "mlir/IR/SymbolTable.h" // from @llvm-project
33 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
34 #include "mlir/Support/LogicalResult.h" // from @llvm-project
35 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
36 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
37
38 namespace mlir {
39 namespace tf_saved_model {
40
41 //===----------------------------------------------------------------------===//
42 // Utilities
43 //===----------------------------------------------------------------------===//
44
IsStrArrayAttr(Attribute attr)45 static bool IsStrArrayAttr(Attribute attr) {
46 auto array = attr.dyn_cast<ArrayAttr>();
47 if (!array) return false;
48
49 return llvm::all_of(array,
50 [](Attribute attr) { return attr.isa<StringAttr>(); });
51 }
52
53 //===----------------------------------------------------------------------===//
54 // TensorFlowSavedModelDialect Op's
55 //===----------------------------------------------------------------------===//
56
VerifyTensorTypesCompatible(Type t1,Type t2)57 LogicalResult VerifyTensorTypesCompatible(Type t1, Type t2) {
58 if (!t1.isa<TensorType>() || !t2.isa<TensorType>()) {
59 return failure();
60 }
61 return verifyCompatibleShape(t1.cast<TensorType>(), t2.cast<TensorType>());
62 }
63
Verify(GlobalTensorOp global_tensor)64 static LogicalResult Verify(GlobalTensorOp global_tensor) {
65 if (failed(VerifyTensorTypesCompatible(
66 global_tensor.type(), global_tensor.value().Attribute::getType()))) {
67 return global_tensor.emitError() << "'type' and 'value' attributes should "
68 "have compatible tensor types";
69 }
70 if (!global_tensor.is_mutable()) {
71 if (!global_tensor.type().cast<TensorType>().hasStaticShape()) {
72 return global_tensor.emitError()
73 << "'type' attribute for immutable 'tf_saved_model.global_tensor' "
74 "should have a static shape";
75 }
76 }
77 return success();
78 }
79
Verify(SessionInitializerOp session_initializer)80 static LogicalResult Verify(SessionInitializerOp session_initializer) {
81 mlir::SymbolTable symbol_table(
82 session_initializer->getParentOfType<ModuleOp>());
83
84 for (auto sym_ref : session_initializer.initializers()) {
85 auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
86 sym_ref.cast<FlatSymbolRefAttr>().getValue());
87
88 if (!init_func_op)
89 return session_initializer.emitOpError()
90 << "the initializer function does not exist";
91
92 if (!init_func_op.getType().getResults().empty())
93 return session_initializer.emitOpError()
94 << "the initializer function should have no output";
95
96 auto exported_names = GetExportedNames(init_func_op);
97
98 if (exported_names.empty())
99 return session_initializer.emitOpError()
100 << "the initializer function should be exported";
101
102 if (exported_names.size() != 1)
103 return session_initializer.emitOpError()
104 << "the initializer function should have only one exported names";
105 }
106
107 return success();
108 }
109
110 } // namespace tf_saved_model
111 } // namespace mlir
112
113 #define GET_OP_CLASSES
114 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
115
116 namespace mlir {
117 namespace tf_saved_model {
118
119 //===----------------------------------------------------------------------===//
120 // TensorFlowSavedModelDialect Dialect
121 //===----------------------------------------------------------------------===//
122
// Constructs the 'tf_saved_model' dialect in `context`: loads the TensorFlow
// dialect it depends on and registers this dialect's operations.
TensorFlowSavedModelDialect::TensorFlowSavedModelDialect(MLIRContext *context)
    : Dialect(/*name=*/"tf_saved_model", context,
              TypeID::get<TensorFlowSavedModelDialect>()) {
  // The TensorFlow Dialect is needed in the verifier and other routines
  // associated to this dialect. It makes little sense anyway to use the
  // SavedModel dialect without the TensorFlow Dialect.
  context->loadDialect<TF::TensorFlowDialect>();

  // The generated .inc file supplies the list of op classes when GET_OP_LIST
  // is defined; it is expanded directly into the template argument list.
  addOperations<
#define GET_OP_LIST
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc.inc"
      >();
}
136
VerifyIndexPath(Operation * op,NamedAttribute named_attr)137 static LogicalResult VerifyIndexPath(Operation *op, NamedAttribute named_attr) {
138 auto attr = named_attr.second.dyn_cast<ArrayAttr>();
139 if (!attr) {
140 return op->emitError()
141 << "'tf_saved_model.index_path' attribute should be an ArrayAttr";
142 }
143 for (auto element : attr) {
144 if (element.isa<StringAttr>()) {
145 continue;
146 }
147 if (auto integer = element.dyn_cast<IntegerAttr>()) {
148 if (integer.getValue().getBitWidth() == 64) {
149 continue;
150 }
151 }
152 return op->emitError() << "'tf_saved_model.index_path' elements should "
153 "be strings or 64-bit integers";
154 }
155 return mlir::success();
156 }
157
GetBoundInputArgTypeFor(mlir::Operation * op)158 Type GetBoundInputArgTypeFor(mlir::Operation *op) {
159 if (auto global_tensor = llvm::dyn_cast<GlobalTensorOp>(op)) {
160 auto type = global_tensor.type().cast<TensorType>();
161 return RankedTensorType::get(
162 {}, TF::ResourceType::get({type}, type.getContext()));
163 }
164
165 if (auto asset = llvm::dyn_cast<AssetOp>(op)) {
166 return RankedTensorType::get({}, TF::StringType::get(asset.getContext()));
167 }
168
169 op->emitError() << "unknown symbol operation";
170 return {};
171 }
172
VerifyBoundInputArgType(Operation * op_for_diagnostics,Type arg_type,mlir::Operation * symbol_op)173 static LogicalResult VerifyBoundInputArgType(Operation *op_for_diagnostics,
174 Type arg_type,
175 mlir::Operation *symbol_op) {
176 auto expected_type = GetBoundInputArgTypeFor(symbol_op);
177 if (!expected_type) return failure();
178
179 if (arg_type != expected_type) {
180 return op_for_diagnostics->emitError()
181 << "bound input with type " << arg_type << " expected to have type "
182 << expected_type;
183 }
184 return success();
185 }
186
verifyRegionArgAttribute(Operation * op,unsigned region_index,unsigned arg_index,NamedAttribute named_attr)187 LogicalResult TensorFlowSavedModelDialect::verifyRegionArgAttribute(
188 Operation *op, unsigned region_index, unsigned arg_index,
189 NamedAttribute named_attr) {
190 if (named_attr.first == "tf_saved_model.bound_input") {
191 if (!named_attr.second.isa<FlatSymbolRefAttr>()) {
192 return op->emitError() << "'tf_saved_model.bound_input' attribute should "
193 "be a FlatSymbolRefAttr";
194 }
195 auto symbol_name = named_attr.second.cast<FlatSymbolRefAttr>().getValue();
196 auto module = op->getParentOfType<ModuleOp>();
197 mlir::Operation *symbol_op = module.lookupSymbol(symbol_name);
198 if (!symbol_op) {
199 return op->emitError() << "'tf_saved_model.bound_input' attribute must "
200 "reference a valid symbol, got invalid symbol '"
201 << symbol_name << "'";
202 }
203 auto arg_type = cast<FuncOp>(op).getArgument(arg_index).getType();
204 return VerifyBoundInputArgType(op, arg_type, symbol_op);
205 }
206 if (named_attr.first == "tf_saved_model.index_path") {
207 return VerifyIndexPath(op, named_attr);
208 }
209
210 return op->emitError() << "unknown tf_saved_model dialect arg attribute '"
211 << named_attr.first << "'";
212 }
213
verifyRegionResultAttribute(Operation * op,unsigned region_index,unsigned result_index,NamedAttribute named_attr)214 LogicalResult TensorFlowSavedModelDialect::verifyRegionResultAttribute(
215 Operation *op, unsigned region_index, unsigned result_index,
216 NamedAttribute named_attr) {
217 if (named_attr.first == "tf_saved_model.index_path") {
218 return VerifyIndexPath(op, named_attr);
219 }
220
221 return op->emitError() << "unknown tf_saved_model dialect result attribute '"
222 << named_attr.first << "'";
223 }
224
HasAnyTfSavedModelArgAttr(FuncOp func)225 static bool HasAnyTfSavedModelArgAttr(FuncOp func) {
226 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
227 if (func.getArgAttr(i, "tf_saved_model.index_path") ||
228 func.getArgAttr(i, "tf_saved_model.bound_input")) {
229 return true;
230 }
231 }
232 for (int i = 0, e = func.getNumResults(); i < e; i++) {
233 if (func.getResultAttr(i, "tf_saved_model.index_path") ||
234 func.getResultAttr(i, "tf_saved_model.bound_input")) {
235 return true;
236 }
237 }
238 return false;
239 }
240
VerifySavedModelModule(ModuleOp module,TensorFlowSavedModelDialect * dialect)241 static LogicalResult VerifySavedModelModule(
242 ModuleOp module, TensorFlowSavedModelDialect *dialect) {
243 auto exported_names_ident =
244 Identifier::get("tf_saved_model.exported_names", dialect->getContext());
245 // Check that there are no duplicated exported_names.
246 DenseMap<StringRef, Operation *> exported_name_to_op;
247 for (auto &op : module) {
248 auto attr = op.getAttr(exported_names_ident);
249 if (!attr) continue;
250 // If this verifier is called before we verify the
251 // 'tf_saved_model.exported_names' attribute, then it might be invalid.
252 // Forward to the dialect's verification to establish that precondition.
253 if (failed(dialect->verifyOperationAttribute(
254 &op, {exported_names_ident, attr}))) {
255 return failure();
256 }
257 for (auto str : attr.cast<ArrayAttr>()) {
258 auto exported_name = str.cast<StringAttr>().getValue();
259 auto p = exported_name_to_op.insert({exported_name, &op});
260 if (!p.second) {
261 return op.emitError()
262 .append("duplicate exported name '", exported_name, "'")
263 .attachNote(p.first->getSecond()->getLoc())
264 .append("previously seen here");
265 }
266 }
267 }
268 for (auto func : module.getOps<FuncOp>()) {
269 const bool is_exported = IsExported(func);
270
271 if (is_exported && !func.isPublic()) {
272 return func.emitError()
273 << "exported function @" << func.getName() << " should be public";
274 }
275
276 if (!is_exported && func.isPublic()) {
277 return func.emitError() << "non-exported function @" << func.getName()
278 << " should be private";
279 }
280
281 if (!is_exported && HasAnyTfSavedModelArgAttr(func)) {
282 return func.emitError() << "can only apply 'tf_saved_model' argument "
283 "attributes to exported functions";
284 }
285 }
286
287 auto session_initializers = module.getOps<SessionInitializerOp>();
288 if (!session_initializers.empty() &&
289 !llvm::hasSingleElement(session_initializers)) {
290 return (*++session_initializers.begin()).emitError()
291 << "there must be no more than one session_initializer op";
292 }
293
294 auto is_init = [&session_initializers](mlir::FuncOp func) {
295 if (session_initializers.empty()) return false;
296 auto init_syms = (*session_initializers.begin()).initializers();
297 return std::any_of(
298 init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) {
299 return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName();
300 });
301 };
302
303 SymbolTable symbol_table(module);
304 auto symbol_uses = SymbolTable::getSymbolUses(&module.getBodyRegion());
305 if (!symbol_uses.hasValue()) {
306 return module.emitError() << "modules with 'tf_saved_model.semantics' must "
307 "have analyzable symbol uses";
308 }
309 for (auto symbol_use : *symbol_uses) {
310 auto func = symbol_table.lookup<FuncOp>(
311 symbol_use.getSymbolRef().cast<FlatSymbolRefAttr>().getValue());
312 if (func && IsExported(func)) {
313 // If it is an init function, then it can be used by the unique
314 // session_initializer op.
315 if (is_init(func) &&
316 llvm::isa<SessionInitializerOp>(symbol_use.getUser()))
317 continue;
318
319 return symbol_use.getUser()
320 ->emitError("exported function cannot be internally referenced")
321 .attachNote(func.getLoc())
322 .append("references this exported function");
323 }
324 }
325 return success();
326 }
327
VerifyExportedFunc(FuncOp func)328 LogicalResult VerifyExportedFunc(FuncOp func) {
329 bool reached_bound_inputs = false;
330 auto module = func->getParentOfType<ModuleOp>();
331 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
332 if (func.getArgAttr(i, "tf_saved_model.bound_input")) {
333 reached_bound_inputs = true;
334 continue;
335 }
336 if (func.getArgAttr(i, "tf_saved_model.index_path")) {
337 if (reached_bound_inputs) {
338 return func.emitError()
339 << "all 'tf_saved_model.index_path' arg attributes should "
340 "precede all 'tf_saved_model.bound_input' arg attributes";
341 }
342 continue;
343 }
344 if (func.getArgAttr(i, "tf.resource_name")) {
345 if (module->getAttr("tf_saved_model.under_construction")) continue;
346 return func.emitError() << "'tf.resource_name' attribute is not allowed "
347 "unless it is being under construction";
348 }
349 return func.emitError()
350 << "all arguments should have 'tf_saved_model.index_path', "
351 "'tf_saved_model.bound_input' or 'tf.resource_name' attributes";
352 }
353 llvm::SmallDenseSet<StringRef, 8> unique_bound_inputs;
354 for (int i = 0, e = func.getNumArguments(); i < e; i++) {
355 if (auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
356 i, "tf_saved_model.bound_input")) {
357 if (!unique_bound_inputs.insert(attr.getValue()).second) {
358 if (module->getAttr("tf_saved_model.under_construction")) continue;
359 return func.emitError()
360 << "duplicate 'tf_saved_model.bound_input' binding";
361 }
362 }
363 }
364
365 for (int i = 0, e = func.getNumResults(); i < e; i++) {
366 if (!func.getResultAttr(i, "tf_saved_model.index_path")) {
367 return func.emitError() << "all results should have "
368 "'tf_saved_model.index_path' attributes";
369 }
370 }
371
372 return success();
373 }
374
verifyOperationAttribute(Operation * op,NamedAttribute named_attr)375 LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute(
376 Operation *op, NamedAttribute named_attr) {
377 if (named_attr.first == "tf_saved_model.exported_names") {
378 if (!isa<FuncOp, GlobalTensorOp>(op)) {
379 return op->emitError() << "'tf_saved_model.exported_names' must be on a "
380 "'func' or 'tf_saved_model.global_tensor' op";
381 }
382 if (!IsStrArrayAttr(named_attr.second)) {
383 return op->emitError()
384 << "'tf_saved_model.exported_names' must be an array of strings";
385 }
386 if (!op->getParentOp()->getAttr("tf_saved_model.semantics")) {
387 return op->emitError()
388 << "'tf_saved_model.exported_names' must be on an op "
389 "whose immediate parent has attribute "
390 "'tf_saved_model.semantics'";
391 }
392 if (auto func = dyn_cast<FuncOp>(op)) {
393 if (failed(VerifyExportedFunc(func))) {
394 return failure();
395 }
396 }
397 return success();
398 }
399 if (named_attr.first == "tf_saved_model.semantics") {
400 auto module = dyn_cast<ModuleOp>(op);
401 if (!module) {
402 return op->emitError() << "'tf_saved_model.semantics' must "
403 "be on a module op";
404 }
405 return VerifySavedModelModule(module, this);
406 }
407 if (named_attr.first == "tf_saved_model.under_construction") {
408 return success();
409 }
410
411 return op->emitError() << "unknown tf_saved_model dialect attribute '"
412 << named_attr.first << "'";
413 }
414
GetExportedNames(Operation * op)415 SmallVector<StringRef, 2> GetExportedNames(Operation *op) {
416 SmallVector<StringRef, 2> ret;
417 auto exported_names =
418 op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
419 if (exported_names) {
420 for (auto name : exported_names) {
421 ret.push_back(name.cast<StringAttr>().getValue());
422 }
423 }
424 return ret;
425 }
426
IsExported(Operation * op)427 bool IsExported(Operation *op) {
428 auto exported_names =
429 op->getAttrOfType<ArrayAttr>("tf_saved_model.exported_names");
430 return exported_names && !exported_names.empty();
431 }
432
HasTfSavedModelSemantics(ModuleOp module)433 bool HasTfSavedModelSemantics(ModuleOp module) {
434 return module->getAttr("tf_saved_model.semantics") != nullptr;
435 }
436
LookupBoundInput(FuncOp func,int arg_index,const SymbolTable & symbol_table)437 Operation *LookupBoundInput(FuncOp func, int arg_index,
438 const SymbolTable &symbol_table) {
439 auto attr = func.getArgAttrOfType<FlatSymbolRefAttr>(
440 arg_index, "tf_saved_model.bound_input");
441 if (!attr) return nullptr;
442 return symbol_table.lookup(attr.getValue());
443 }
444
GetSessionInitializerOp(mlir::ModuleOp op)445 SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op) {
446 auto initializers = op.getOps<SessionInitializerOp>();
447 if (initializers.empty()) return {};
448 return *initializers.begin();
449 }
450
451 class OptimizeSessionInitializerPattern
452 : public OpRewritePattern<SessionInitializerOp> {
453 public:
454 using OpRewritePattern::OpRewritePattern;
455
matchAndRewrite(SessionInitializerOp op,PatternRewriter & rewriter) const456 LogicalResult matchAndRewrite(SessionInitializerOp op,
457 PatternRewriter &rewriter) const override {
458 SymbolTable symbol_table(op->getParentOfType<ModuleOp>());
459
460 SmallVector<FuncOp, 2> to_remove;
461 SmallVector<mlir::Attribute, 2> to_keep;
462 for (auto sym_ref : op.initializers()) {
463 auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
464 sym_ref.cast<FlatSymbolRefAttr>().getValue());
465
466 // The init function can only be referenced from the SessionInitializerOp.
467 // And there is at most one SessionInitializerOp in the module. So if both
468 // ops have no other uses or have one NoOp only, they can be simply
469 // erased.
470 auto &operations = init_func_op.front().getOperations();
471 if ((operations.size() == 1 &&
472 operations.front().hasTrait<OpTrait::IsTerminator>()) ||
473 (operations.size() == 2 &&
474 dyn_cast<mlir::TF::NoOp>(operations.front()) &&
475 operations.back().hasTrait<OpTrait::IsTerminator>())) {
476 to_remove.push_back(init_func_op);
477 } else {
478 to_keep.push_back(sym_ref);
479 }
480 }
481
482 for (auto func_op : to_remove) rewriter.eraseOp(func_op);
483
484 if (to_keep.empty())
485 rewriter.eraseOp(op);
486 else
487 op->setAttr("initializers", rewriter.getArrayAttr(to_keep));
488
489 return success();
490 }
491 };
492
// Registers the canonicalization patterns for
// 'tf_saved_model.session_initializer': currently only
// OptimizeSessionInitializerPattern.
void SessionInitializerOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<OptimizeSessionInitializerPattern>(context);
}
497
GetSessionInitializerExportedName(ModuleOp op)498 SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) {
499 auto session_initializer_op = GetSessionInitializerOp(op);
500 if (!session_initializer_op) return {};
501
502 SymbolTable symbol_table(op);
503
504 SmallVector<StringRef, 2> results;
505 for (auto sym_ref : session_initializer_op.initializers()) {
506 auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
507 sym_ref.cast<FlatSymbolRefAttr>().getValue());
508 auto exported_names = GetExportedNames(init_func_op);
509 assert(exported_names.size() == 1);
510 results.push_back(exported_names[0]);
511 }
512
513 return results;
514 }
515
516 } // namespace tf_saved_model
517 } // namespace mlir
518