1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
17 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
18 #include "mlir/IR/Operation.h" // from @llvm-project
19 #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
21
init_ops(py::module & m)22 void init_ops(py::module& m) {
23 py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
24 m, "Operation")
25 .def("getRegion", &mlir::Operation::getRegion,
26 py::return_value_policy::reference)
27 .def("getResult", &mlir::Operation::getResult)
28 .def("dump", &mlir::Operation::dump)
29 .def("getNumResults", &mlir::Operation::getNumResults);
30
31 py::class_<mlir::OperationState>(m, "OperationState")
32 .def(py::init([](mlir::Location loc, std::string name) {
33 return mlir::OperationState(loc, llvm::StringRef(name));
34 }))
35 .def("addTypes",
36 [](mlir::OperationState& state, std::vector<mlir::Type> tys) {
37 state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
38 })
39 .def("addOperands",
40 [](mlir::OperationState& os, std::vector<mlir::Value> ops) {
41 os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
42 })
43 .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
44 py::return_value_policy::reference);
45
46 py::class_<mlir::ModuleOp>(m, "ModuleOp")
47 .def("create",
48 [](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
49 .def("push_back",
50 [](mlir::ModuleOp& m, mlir::func::FuncOp f) { m.push_back(f); })
51 .def("dump", &mlir::ModuleOp::dump)
52 .def("getAsStr", [](mlir::ModuleOp& m) {
53 std::string str;
54 llvm::raw_string_ostream os(str);
55 m.print(os);
56 return os.str();
57 });
58
59 py::class_<mlir::func::FuncOp>(m, "FuncOp")
60 .def("create",
61 [](mlir::Location location, std::string name,
62 mlir::FunctionType type) {
63 auto func = mlir::func::FuncOp::create(location, name, type);
64 func.addEntryBlock();
65 return func;
66 })
67 .def(
68 "getBody",
69 [](mlir::func::FuncOp& f) -> mlir::Region& { return f.getBody(); },
70 py::return_value_policy::reference)
71 .def("getArguments",
72 [](mlir::func::FuncOp& f) { return f.getArguments().vec(); })
73 .def("getName", [](mlir::func::FuncOp& f) { return f.getName().str(); })
74 .def("getType", &mlir::func::FuncOp::getFunctionType);
75
76 py::class_<mlir::func::ReturnOp>(m, "ReturnOp")
77 .def("create",
78 [](mlir::OpBuilder& opb, mlir::Location loc,
79 std::vector<mlir::Value> values) -> mlir::Operation* {
80 return opb
81 .create<mlir::func::ReturnOp>(
82 loc, mlir::ArrayRef<mlir::Value>(values))
83 .getOperation();
84 });
85
86 // mlir::TF::AddOp
87 py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
88 .def("create",
89 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
90 mlir::Value y) -> mlir::Operation* {
91 return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
92 });
93
94 py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
95 .def("create",
96 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
97 mlir::Value reduction_indices,
98 bool keep_dims = false) -> mlir::Operation* {
99 return opb
100 .create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
101 reduction_indices, keep_dims)
102 .getOperation();
103 });
104
105 // mlir::TF::ConstOp
106 py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
107 .def("create",
108 [](mlir::OpBuilder& opb, mlir::Location loc,
109 mlir::Attribute value) -> mlir::Operation* {
110 return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
111 });
112
113 // mlir::TF::EqualOp
114 py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
115 .def("create",
116 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
117 mlir::Value y) -> mlir::Operation* {
118 return opb
119 .create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
120 .getOperation();
121 });
122
123 // mlir::TF::GreaterEqualOp
124 py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
125 .def("create",
126 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
127 mlir::Value y) -> mlir::Operation* {
128 return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
129 .getOperation();
130 });
131
132 // mlir::TF::GreaterOp
133 py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
134 .def("create",
135 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
136 mlir::Value y) -> mlir::Operation* {
137 return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
138 });
139
140 // mlir::TF::LegacyCallOp
141 py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
142 .def("create",
143 [](mlir::OpBuilder& opb, mlir::Location loc,
144 std::vector<mlir::Type> output, std::vector<mlir::Value> args,
145 std::string f) -> mlir::Operation* {
146 return opb
147 .create<mlir::TF::LegacyCallOp>(
148 loc, mlir::ArrayRef<mlir::Type>(output),
149 mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
150 .getOperation();
151 });
152
153 // mlir::TF::LessEqualOp
154 py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
155 .def("create",
156 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
157 mlir::Value y) -> mlir::Operation* {
158 return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
159 });
160
161 // mlir::TF::LessOp
162 py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
163 .def("create",
164 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
165 mlir::Value y) -> mlir::Operation* {
166 return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
167 });
168
169 // mlir::TF::NegOp
170 py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
171 .def("create",
172 [](mlir::OpBuilder& opb, mlir::Location loc,
173 mlir::Value x) -> mlir::Operation* {
174 return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
175 });
176
177 py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
178 .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
179 mlir::Value y) {
180 return opb
181 .create<mlir::TF::NotEqualOp>(
182 loc, x, y, mlir::BoolAttr::get(opb.getContext(), true))
183 .getOperation();
184 });
185
186 // mlir::TF::SubOp
187 py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
188 .def("create",
189 [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
190 mlir::Value y) -> mlir::Operation* {
191 return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
192 });
193 }
194