• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
17 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
18 #include "mlir/IR/Operation.h"  // from @llvm-project
19 #include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
20 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
21 
init_ops(py::module & m)22 void init_ops(py::module& m) {
23   py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
24       m, "Operation")
25       .def("getRegion", &mlir::Operation::getRegion,
26            py::return_value_policy::reference)
27       .def("getResult", &mlir::Operation::getResult)
28       .def("dump", &mlir::Operation::dump)
29       .def("getNumResults", &mlir::Operation::getNumResults);
30 
31   py::class_<mlir::OperationState>(m, "OperationState")
32       .def(py::init([](mlir::Location loc, std::string name) {
33         return mlir::OperationState(loc, llvm::StringRef(name));
34       }))
35       .def("addTypes",
36            [](mlir::OperationState& state, std::vector<mlir::Type> tys) {
37              state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
38            })
39       .def("addOperands",
40            [](mlir::OperationState& os, std::vector<mlir::Value> ops) {
41              os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
42            })
43       .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
44            py::return_value_policy::reference);
45 
46   py::class_<mlir::ModuleOp>(m, "ModuleOp")
47       .def("create",
48            [](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
49       .def("push_back",
50            [](mlir::ModuleOp& m, mlir::func::FuncOp f) { m.push_back(f); })
51       .def("dump", &mlir::ModuleOp::dump)
52       .def("getAsStr", [](mlir::ModuleOp& m) {
53         std::string str;
54         llvm::raw_string_ostream os(str);
55         m.print(os);
56         return os.str();
57       });
58 
59   py::class_<mlir::func::FuncOp>(m, "FuncOp")
60       .def("create",
61            [](mlir::Location location, std::string name,
62               mlir::FunctionType type) {
63              auto func = mlir::func::FuncOp::create(location, name, type);
64              func.addEntryBlock();
65              return func;
66            })
67       .def(
68           "getBody",
69           [](mlir::func::FuncOp& f) -> mlir::Region& { return f.getBody(); },
70           py::return_value_policy::reference)
71       .def("getArguments",
72            [](mlir::func::FuncOp& f) { return f.getArguments().vec(); })
73       .def("getName", [](mlir::func::FuncOp& f) { return f.getName().str(); })
74       .def("getType", &mlir::func::FuncOp::getFunctionType);
75 
76   py::class_<mlir::func::ReturnOp>(m, "ReturnOp")
77       .def("create",
78            [](mlir::OpBuilder& opb, mlir::Location loc,
79               std::vector<mlir::Value> values) -> mlir::Operation* {
80              return opb
81                  .create<mlir::func::ReturnOp>(
82                      loc, mlir::ArrayRef<mlir::Value>(values))
83                  .getOperation();
84            });
85 
86   // mlir::TF::AddOp
87   py::class_<mlir::TF::AddV2Op>(m, "Tf_AddV2Op")
88       .def("create",
89            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
90               mlir::Value y) -> mlir::Operation* {
91              return opb.create<mlir::TF::AddV2Op>(loc, x, y).getOperation();
92            });
93 
94   py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
95       .def("create",
96            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
97               mlir::Value reduction_indices,
98               bool keep_dims = false) -> mlir::Operation* {
99              return opb
100                  .create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
101                                           reduction_indices, keep_dims)
102                  .getOperation();
103            });
104 
105   // mlir::TF::ConstOp
106   py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
107       .def("create",
108            [](mlir::OpBuilder& opb, mlir::Location loc,
109               mlir::Attribute value) -> mlir::Operation* {
110              return opb.create<mlir::TF::ConstOp>(loc, value).getOperation();
111            });
112 
113   // mlir::TF::EqualOp
114   py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
115       .def("create",
116            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
117               mlir::Value y) -> mlir::Operation* {
118              return opb
119                  .create<mlir::TF::EqualOp>(loc, x, y, opb.getBoolAttr(true))
120                  .getOperation();
121            });
122 
123   // mlir::TF::GreaterEqualOp
124   py::class_<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp")
125       .def("create",
126            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
127               mlir::Value y) -> mlir::Operation* {
128              return opb.create<mlir::TF::GreaterEqualOp>(loc, x, y)
129                  .getOperation();
130            });
131 
132   // mlir::TF::GreaterOp
133   py::class_<mlir::TF::GreaterOp>(m, "Tf_GreaterOp")
134       .def("create",
135            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
136               mlir::Value y) -> mlir::Operation* {
137              return opb.create<mlir::TF::GreaterOp>(loc, x, y).getOperation();
138            });
139 
140   // mlir::TF::LegacyCallOp
141   py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
142       .def("create",
143            [](mlir::OpBuilder& opb, mlir::Location loc,
144               std::vector<mlir::Type> output, std::vector<mlir::Value> args,
145               std::string f) -> mlir::Operation* {
146              return opb
147                  .create<mlir::TF::LegacyCallOp>(
148                      loc, mlir::ArrayRef<mlir::Type>(output),
149                      mlir::ArrayRef<mlir::Value>(args), mlir::StringRef(f))
150                  .getOperation();
151            });
152 
153   // mlir::TF::LessEqualOp
154   py::class_<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp")
155       .def("create",
156            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
157               mlir::Value y) -> mlir::Operation* {
158              return opb.create<mlir::TF::LessEqualOp>(loc, x, y).getOperation();
159            });
160 
161   // mlir::TF::LessOp
162   py::class_<mlir::TF::LessOp>(m, "Tf_LessOp")
163       .def("create",
164            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
165               mlir::Value y) -> mlir::Operation* {
166              return opb.create<mlir::TF::LessOp>(loc, x, y).getOperation();
167            });
168 
169   // mlir::TF::NegOp
170   py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
171       .def("create",
172            [](mlir::OpBuilder& opb, mlir::Location loc,
173               mlir::Value x) -> mlir::Operation* {
174              return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
175            });
176 
177   py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
178       .def("create", [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
179                         mlir::Value y) {
180         return opb
181             .create<mlir::TF::NotEqualOp>(
182                 loc, x, y, mlir::BoolAttr::get(opb.getContext(), true))
183             .getOperation();
184       });
185 
186   // mlir::TF::SubOp
187   py::class_<mlir::TF::SubOp>(m, "Tf_SubOp")
188       .def("create",
189            [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
190               mlir::Value y) -> mlir::Operation* {
191              return opb.create<mlir::TF::SubOp>(loc, x, y).getOperation();
192            });
193 }
194