1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_
17 #define TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_
18
19 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
20 #include "mlir/IR/Dialect.h" // from @llvm-project
21 #include "mlir/IR/OpDefinition.h" // from @llvm-project
22
23 namespace mlir {
24 namespace tf_saved_model {
25
26 class TensorFlowSavedModelDialect : public Dialect {
27 public:
28 explicit TensorFlowSavedModelDialect(MLIRContext *context);
29 LogicalResult verifyRegionArgAttribute(Operation *op, unsigned region_index,
30 unsigned arg_index,
31 NamedAttribute named_attr) override;
32 LogicalResult verifyRegionResultAttribute(Operation *op,
33 unsigned region_index,
34 unsigned result_index,
35 NamedAttribute named_attr) override;
36 LogicalResult verifyOperationAttribute(Operation *op,
37 NamedAttribute named_attr) override;
38
getDialectNamespace()39 static StringRef getDialectNamespace() { return "tf_saved_model"; }
40 };
41
42 } // namespace tf_saved_model
43 } // namespace mlir
44
45 // Declares the operations for this dialect using the generated header.
46 #define GET_OP_CLASSES
47 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h.inc"
48
49 namespace mlir {
50 namespace tf_saved_model {
51
52 // Returns the list of exported names for `op`.
53 // An empty list means `op` is not exported.
54 SmallVector<StringRef, 2> GetExportedNames(Operation *op);
55
56 // Returns true if `op` is exported.
57 bool IsExported(Operation *op);
58
59 // Returns true if `module` has tf_saved_model linkage semantics.
60 bool HasTfSavedModelSemantics(ModuleOp module);
61
62 // Returns the tf_saved_model.global_tensor op that func's arg_index'th argument
63 // refers to as a bound input, or null.
64 Operation *LookupBoundInput(FuncOp func, int arg_index,
65 const SymbolTable &symbol_table);
66
67 template <typename T>
LookupBoundInputOfType(FuncOp func,int arg_index,const SymbolTable & symbol_table)68 T LookupBoundInputOfType(FuncOp func, int arg_index,
69 const SymbolTable &symbol_table) {
70 return llvm::dyn_cast_or_null<T>(
71 LookupBoundInput(func, arg_index, symbol_table));
72 }
73
74 // Gets the type that an exported function arg that is bound to symbol ops such
75 // as `global_tensor` and `asset` should have.
76 Type GetBoundInputArgTypeFor(mlir::Operation *op);
77
78 // Returns the session initializer of this module if it exists. Returns null
79 // otherwise.
80 SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op);
81
82 // Returns the exported name for the session initializer function.
83 SmallVector<StringRef, 2> GetSessionInitializerExportedName(mlir::ModuleOp op);
84
85 } // namespace tf_saved_model
86 } // namespace mlir
87
88 #endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_IR_TF_SAVED_MODEL_H_
89