/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>
#include <string>
#include <vector>

#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/python/mlir_wrapper/mlir_wrapper.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace {

// Binds `OpTy` as the Python class `name`. The class exposes one static
// factory, `create(opb, loc, x, y)`, which builds `OpTy` from two operands
// via `opb.create<OpTy>(loc, x, y)` and returns the new op's Operation*.
template <typename OpTy>
void BindBinaryOp(py::module& m, const char* name) {
  py::class_<OpTy>(m, name).def_static(
      "create",
      [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
         mlir::Value y) -> mlir::Operation* {
        return opb.create<OpTy>(loc, x, y).getOperation();
      });
}

}  // namespace

// Registers Python bindings on module `m` for the core MLIR IR classes
// Operation, OperationState, ModuleOp, func::FuncOp and func::ReturnOp, plus a
// subset of TF dialect ops.
//
// Each op class exposes `create` as a static factory: it takes no `self`, so
// it is bound with `def_static` and may be called either through the class or
// through an instance. The ReturnOp and TF dialect factories build the op with
// the supplied OpBuilder and return the resulting `mlir::Operation*` rather
// than the typed op wrapper.
void init_ops(py::module& m) {
  // The `py::nodelete` holder means a Python wrapper never deletes the
  // Operation it refers to.
  py::class_<mlir::Operation, std::unique_ptr<mlir::Operation, py::nodelete>>(
      m, "Operation")
      // Returned by reference: Python gets a non-owning handle to the Region.
      .def("getRegion", &mlir::Operation::getRegion,
           py::return_value_policy::reference)
      .def("getResult", &mlir::Operation::getResult)
      .def("dump", &mlir::Operation::dump)
      .def("getNumResults", &mlir::Operation::getNumResults);

  py::class_<mlir::OperationState>(m, "OperationState")
      .def(py::init([](mlir::Location loc, std::string name) {
        return mlir::OperationState(loc, llvm::StringRef(name));
      }))
      .def("addTypes",
           [](mlir::OperationState& state, std::vector<mlir::Type> tys) {
             state.addTypes(mlir::ArrayRef<mlir::Type>(tys));
           })
      .def("addOperands",
           [](mlir::OperationState& os, std::vector<mlir::Value> ops) {
             os.addOperands(mlir::ArrayRef<mlir::Value>(ops));
           })
      // Returned by reference: Python gets a non-owning handle to the Region.
      .def("addRegion", py::overload_cast<>(&mlir::OperationState::addRegion),
           py::return_value_policy::reference);

  py::class_<mlir::ModuleOp>(m, "ModuleOp")
      .def_static("create",
                  [](mlir::Location loc) { return mlir::ModuleOp::create(loc); })
      .def("push_back",
           [](mlir::ModuleOp& module_op, mlir::func::FuncOp f) {
             module_op.push_back(f);
           })
      .def("dump", &mlir::ModuleOp::dump)
      // Prints the module's textual IR into a std::string.
      .def("getAsStr", [](mlir::ModuleOp& module_op) {
        std::string str;
        llvm::raw_string_ostream os(str);
        module_op.print(os);
        return os.str();
      });

  py::class_<mlir::func::FuncOp>(m, "FuncOp")
      // Creates a function and gives it an entry block. The function is not
      // attached to a module here; presumably callers add it with
      // ModuleOp.push_back.
      .def_static("create",
                  [](mlir::Location location, std::string name,
                     mlir::FunctionType type) {
                    auto func =
                        mlir::func::FuncOp::create(location, name, type);
                    func.addEntryBlock();
                    return func;
                  })
      .def(
          "getBody",
          [](mlir::func::FuncOp& f) -> mlir::Region& { return f.getBody(); },
          py::return_value_policy::reference)
      .def("getArguments",
           [](mlir::func::FuncOp& f) { return f.getArguments().vec(); })
      .def("getName", [](mlir::func::FuncOp& f) { return f.getName().str(); })
      .def("getType", &mlir::func::FuncOp::getFunctionType);

  py::class_<mlir::func::ReturnOp>(m, "ReturnOp")
      .def_static("create",
                  [](mlir::OpBuilder& opb, mlir::Location loc,
                     std::vector<mlir::Value> values) -> mlir::Operation* {
                    return opb
                        .create<mlir::func::ReturnOp>(
                            loc, mlir::ArrayRef<mlir::Value>(values))
                        .getOperation();
                  });

  // mlir::TF::AddV2Op
  BindBinaryOp<mlir::TF::AddV2Op>(m, "Tf_AddV2Op");

  // mlir::TF::AnyOp
  // `keep_dims` is declared through py::arg so that Python callers may omit
  // it. A C++ default argument on the lambda is invisible to pybind11.
  py::class_<mlir::TF::AnyOp>(m, "Tf_AnyOp")
      .def_static(
          "create",
          [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value input,
             mlir::Value reduction_indices,
             bool keep_dims) -> mlir::Operation* {
            // NOTE(review): the result type is the scalar `i1`, not a tensor
            // type; confirm the op still passes TF dialect verification.
            return opb
                .create<mlir::TF::AnyOp>(loc, opb.getI1Type(), input,
                                         reduction_indices, keep_dims)
                .getOperation();
          },
          py::arg("opb"), py::arg("loc"), py::arg("input"),
          py::arg("reduction_indices"), py::arg("keep_dims") = false);

  // mlir::TF::ConstOp
  py::class_<mlir::TF::ConstOp>(m, "Tf_ConstOp")
      .def_static("create",
                  [](mlir::OpBuilder& opb, mlir::Location loc,
                     mlir::Attribute value) -> mlir::Operation* {
                    return opb.create<mlir::TF::ConstOp>(loc, value)
                        .getOperation();
                  });

  // mlir::TF::EqualOp
  // NOTE(review): the trailing `true` attribute is presumably
  // `incompatible_shape_error`; confirm against the op definition.
  py::class_<mlir::TF::EqualOp>(m, "Tf_EqualOp")
      .def_static("create",
                  [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
                     mlir::Value y) -> mlir::Operation* {
                    return opb
                        .create<mlir::TF::EqualOp>(loc, x, y,
                                                   opb.getBoolAttr(true))
                        .getOperation();
                  });

  // mlir::TF::GreaterEqualOp
  BindBinaryOp<mlir::TF::GreaterEqualOp>(m, "Tf_GreaterEqualOp");

  // mlir::TF::GreaterOp
  BindBinaryOp<mlir::TF::GreaterOp>(m, "Tf_GreaterOp");

  // mlir::TF::LegacyCallOp
  // Builds a call to the function named `f` with result types `output` and
  // operands `args`.
  py::class_<mlir::TF::LegacyCallOp>(m, "Tf_LegacyCallOp")
      .def_static("create",
                  [](mlir::OpBuilder& opb, mlir::Location loc,
                     std::vector<mlir::Type> output,
                     std::vector<mlir::Value> args,
                     std::string f) -> mlir::Operation* {
                    return opb
                        .create<mlir::TF::LegacyCallOp>(
                            loc, mlir::ArrayRef<mlir::Type>(output),
                            mlir::ArrayRef<mlir::Value>(args),
                            mlir::StringRef(f))
                        .getOperation();
                  });

  // mlir::TF::LessEqualOp
  BindBinaryOp<mlir::TF::LessEqualOp>(m, "Tf_LessEqualOp");

  // mlir::TF::LessOp
  BindBinaryOp<mlir::TF::LessOp>(m, "Tf_LessOp");

  // mlir::TF::NegOp
  py::class_<mlir::TF::NegOp>(m, "Tf_NegOp")
      .def_static("create",
                  [](mlir::OpBuilder& opb, mlir::Location loc,
                     mlir::Value x) -> mlir::Operation* {
                    return opb.create<mlir::TF::NegOp>(loc, x).getOperation();
                  });

  // mlir::TF::NotEqualOp
  // Same attribute handling as Tf_EqualOp above: `opb.getBoolAttr(true)` is
  // equivalent to `mlir::BoolAttr::get(opb.getContext(), true)`.
  py::class_<mlir::TF::NotEqualOp>(m, "Tf_NotEqualOp")
      .def_static("create",
                  [](mlir::OpBuilder& opb, mlir::Location loc, mlir::Value x,
                     mlir::Value y) -> mlir::Operation* {
                    return opb
                        .create<mlir::TF::NotEqualOp>(loc, x, y,
                                                      opb.getBoolAttr(true))
                        .getOperation();
                  });

  // mlir::TF::SubOp
  BindBinaryOp<mlir::TF::SubOp>(m, "Tf_SubOp");
}