• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#ifndef TFG_DIALECT
17#define TFG_DIALECT
18
19include "mlir/IR/OpBase.td"
20include "mlir/Interfaces/InferTypeOpInterface.td"
21
22
// ODS definition of the dialect; see https://mlir.llvm.org/docs/OpDefinitions/
// for more information.
25
26
//===----------------------------------------------------------------------===//
// TFGraph dialect definitions
//===----------------------------------------------------------------------===//
30
// TFGraph ("tfg") dialect: models TensorFlow GraphDefs with high fidelity.
//
// For each well-known attribute name, extraClassDeclaration exposes two
// accessors:
//   * get...AttrIdentifier(): returns a StringAttr cached in a private member
//     (the member is assigned outside this file, presumably when the dialect
//     is initialized -- TODO confirm);
//   * get...AttrKey(): a static constexpr accessor returning the raw key
//     string, usable without a dialect instance.
// getLiftedGraphFuncNameKey() omits the "Attr" infix; the name is kept as-is
// for compatibility with existing callers.
def TFGraphDialect : Dialect {
  let name = "tfg";

  let summary = "This dialect models TensorFlow Graphs as encoded in GraphDef.";
  let description = [{
    This dialect models TensorFlow GraphDefs and is intended to provide a high
    level of fidelity.

    The attribute mappings from GraphDef are listed below:

    Graph/Function Attributes:
    FunctionDef.attr entries are prepended with the "tf" prefix
    FunctionDef.signature.name <-> "sym_name"
    FunctionDef.signature.description <-> "description"
    FunctionDef.signature.is_stateful <-> "is_stateful"
    FunctionDef.signature.gradient <-> "gradient"
    FunctionDef.resource_arg_unique_id <-> "resource_arg_unique_ids_keys"
    FunctionDef.resource_arg_unique_id <-> "resource_arg_unique_ids_values"

    Input Attributes:
    FunctionDef.signature.input_arg.name <-> "tfg.name"
    FunctionDef.signature.input_arg.description <-> "tfg.description"
    FunctionDef.signature.input_arg.handle_data <-> "tfg.handle_data"
    FunctionDef.signature.input_arg.is_ref <-> "tfg.is_ref"
    FunctionDef.arg_attr entries are prepended with the "tf" prefix

    Output Attributes:
    FunctionDef.signature.output_arg.name <-> "tfg.name"
    FunctionDef.signature.output_arg.description <-> "tfg.description"
    FunctionDef.signature.output_arg.handle_data <-> "tfg.handle_data"
    FunctionDef.signature.output_arg.type <-> "tfg.dtype"
    FunctionDef.signature.control_output <-> "tfg.control_ret_name_"

    Node Attributes:
    NodeDef.device <-> "_mlir_device"
    NodeDef.name <-> "_mlir_name"
    NodeDef.attr <-> "_output_shape"
    NodeDef.experimental_type <-> "_mlir_fulltype"
  }];

  let extraClassDeclaration = [{
    // --- Node-level attribute keys ("_mlir_*"). -----------------------------
    // NOTE(review): these keys are returned as StringLiteral, whereas the
    // "tfg.*" keys below are returned as StringRef; both convert to StringRef.

    // Node name ("_mlir_name").
    StringAttr getNameAttrIdentifier() const { return name_key_; }
    static constexpr StringLiteral getNameAttrKey() { return {"_mlir_name"}; }

    // Requested device ("_mlir_device").
    StringAttr getDeviceAttrIdentifier() const { return device_key_; }
    static constexpr StringLiteral getDeviceAttrKey() {
      return {"_mlir_device"};
    }

    // Assigned device ("_mlir_assigned_device").
    StringAttr getAssignedDeviceAttrIdentifier() const {
      return assigned_device_key_;
    }
    static constexpr StringLiteral getAssignedDeviceAttrKey() {
      return {"_mlir_assigned_device"};
    }

    // Full type information ("_mlir_fulltype").
    StringAttr getFullTypeAttrIdentifier() const { return fulltype_key_; }
    static constexpr StringLiteral getFullTypeAttrKey() {
      return {"_mlir_fulltype"};
    }

    // --- Argument/result attribute keys ("tfg.*"). -------------------------

    // Argument/result name ("tfg.name").
    StringAttr getTfgNameAttrIdentifier() const { return tfg_name_key_; }
    static constexpr StringRef getTfgNameAttrKey() { return {"tfg.name"}; }

    // Argument/result description ("tfg.description").
    StringAttr getTfgDescriptionAttrIdentifier() const {
      return tfg_description_key_;
    }
    static constexpr StringRef getTfgDescriptionAttrKey() {
      return {"tfg.description"};
    }

    // Reference-argument marker ("tfg.is_ref").
    StringAttr getTfgIsRefAttrIdentifier() const { return tfg_is_ref_key_; }
    static constexpr StringRef getTfgIsRefAttrKey() { return {"tfg.is_ref"}; }

    // Resource handle data ("tfg.handle_data").
    StringAttr getTfgHandleDataAttrIdentifier() const {
      return tfg_handle_data_key_;
    }
    static constexpr StringRef getTfgHandleDataAttrKey() {
      return {"tfg.handle_data"};
    }

    // Argument/result full type ("tfg.experimental_full_type").
    StringAttr getTfgFullTypeAttrIdentifier() const {
      return tfg_full_type_key_;
    }
    static constexpr StringRef getTfgFullTypeAttrKey() {
      return {"tfg.experimental_full_type"};
    }

    // Name of a function lifted from a graph ("_mlir_lifted_graph").
    StringAttr getLiftedGraphFuncNameAttrIdentifier() const {
      return lifted_graph_func_name_;
    }
    static constexpr StringRef getLiftedGraphFuncNameKey() {
      return {"_mlir_lifted_graph"};
    }

    // Cached accessor for the control type.
    ControlType getControlType() const { return control_ty_; }

    // Print an operation that belongs to this dialect if unregistered.
    void printCustomTfOp(Operation *op, OpAsmPrinter &printer) const;

    // Returns the hook to parse an operation belonging to this dialect, even
    // if unregistered.
    Optional<ParseOpHook> getParseOperationHook(StringRef opName) const
      override;

    // Returns the hook to print an operation belonging to this dialect, even
    // if unregistered.
    llvm::unique_function<void(Operation *, OpAsmPrinter &)>
    getOperationPrinter(Operation *op) const override;

    // Functions for checking operation categories; the declarations come
    // from the GET_OP_CATEGORIES section of tf_op_names.inc.
    #define GET_OP_CATEGORIES
    #include "tensorflow/core/ir/tf_op_names.inc"

  private:
    // Fallback implementation of OpAsmOpInterface.
    TFGraphOpAsmInterface *fallbackOpAsmInterface_ = nullptr;

    // Cached TensorFlow operation names; the member declarations come from
    // the GET_OP_NAME_DECLS section of tf_op_names.inc.
    #define GET_OP_NAME_DECLS
    #include "tensorflow/core/ir/tf_op_names.inc"

    // Cached StringAttr identifiers backing the get...AttrIdentifier()
    // accessors above, kept for efficiency.
    StringAttr assigned_device_key_;
    StringAttr device_key_;
    StringAttr fulltype_key_;
    StringAttr lifted_graph_func_name_;
    StringAttr name_key_;
    StringAttr tfg_description_key_;
    StringAttr tfg_full_type_key_;
    StringAttr tfg_handle_data_key_;
    StringAttr tfg_is_ref_key_;
    StringAttr tfg_name_key_;

    // Cached control type returned by getControlType().
    ControlType control_ty_;
  }];

  let cppNamespace = "::mlir::tfg";

  // Use the ODS-generated default attribute parser/printer dispatch.
  let useDefaultAttributePrinterParser = 1;
  // The destructor is user-defined out of line (not shown here); presumably
  // it releases fallbackOpAsmInterface_ -- TODO confirm.
  let hasNonDefaultDestructor = 1;
  // The dialect supplies fallback implementations of operation interfaces,
  // presumably via fallbackOpAsmInterface_ for unregistered ops -- confirm.
  let hasOperationInterfaceFallback = 1;
}
176
177#endif // TFG_DIALECT
178