1 /**
2 * Copyright 2019-2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/py_func_op.h"
17
18 #include <memory>
19 #include <vector>
20
21 #include "minddata/dataset/core/tensor.h"
22 #include "minddata/dataset/kernels/tensor_op.h"
23 #include "minddata/dataset/kernels/ir/data/transforms_ir.h"
24 #include "minddata/dataset/util/status.h"
25
26 namespace mindspore {
27 namespace dataset {
Compute(const TensorRow & input,TensorRow * output)28 Status PyFuncOp::Compute(const TensorRow &input, TensorRow *output) {
29 IO_CHECK_VECTOR(input, output);
30 Status ret = Status(StatusCode::kSuccess, "PyFunc Call Succeed");
31 {
32 // Acquire Python GIL
33 py::gil_scoped_acquire gil_acquire;
34 if (Py_IsInitialized() == 0) {
35 ret = Status(StatusCode::kMDPythonInterpreterFailure, "Python Interpreter is finalized");
36 goto ComputeReturn;
37 }
38 try {
39 // Transform input tensor vector into numpy array vector
40 py::tuple input_args(input.size());
41 py::object ret_py_obj;
42 if (input.size() > 0) {
43 for (size_t i = 0; i < input.size(); i++) {
44 py::array new_data;
45 RETURN_IF_NOT_OK(input.at(i)->GetDataAsNumpy(&new_data));
46 // possible memcpy here
47 input_args[i] = new_data;
48 }
49 // Invoke python function
50 ret_py_obj = this->py_func_ptr_(*input_args);
51 } else {
52 ret_py_obj = this->py_func_ptr_();
53 }
54 if (output_type_ != DataType::DE_UNKNOWN) {
55 RETURN_IF_NOT_OK(CastOutput(ret_py_obj, output));
56 } else {
57 if (py::isinstance<py::tuple>(ret_py_obj)) {
58 // In case of a n-m mapping, the return value will be a tuple of numpy arrays
59 py::tuple ret_py_tuple = ret_py_obj.cast<py::tuple>();
60 // Iterate over two containers simultaneously for memory copy
61 for (size_t i = 0; i < ret_py_tuple.size(); i++) {
62 py::object ret_py_ele = ret_py_tuple[i];
63 // Object is none if pyfunc timeout
64 if (ret_py_ele.is_none()) {
65 MS_LOG(INFO) << "Expected that PyFunc should return numpy array, got None. If python_multiprocessing is "
66 "True, PyFunc may execute time out.";
67 goto TimeoutError;
68 }
69 if (!py::isinstance<py::array>(ret_py_ele)) {
70 goto ShapeMisMatch;
71 }
72 std::shared_ptr<Tensor> out;
73 RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_ele.cast<py::array>(), &out));
74 output->push_back(out);
75 }
76 } else if (py::isinstance<py::array>(ret_py_obj)) {
77 // In case of a n-1 mapping, the return value will be a numpy array
78 std::shared_ptr<Tensor> out;
79 RETURN_IF_NOT_OK(Tensor::CreateFromNpArray(ret_py_obj.cast<py::array>(), &out));
80 output->push_back(out);
81 } else {
82 goto ShapeMisMatch;
83 }
84 }
85 } catch (const py::error_already_set &e) {
86 ret = Status(StatusCode::kMDPyFuncException, e.what());
87 }
88 }
89
90 ComputeReturn:
91 return ret;
92
93 ShapeMisMatch:
94 ret = Status(StatusCode::kMDShapeMisMatch, __LINE__, __FILE__,
95 "PyFunc should return a numpy array or a numpy array tuple");
96 goto ComputeReturn;
97
98 TimeoutError:
99 ret = Status(StatusCode::kMDTimeOut, __LINE__, __FILE__,
100 "Expected that PyFunc should return numpy array, got None. If \'python_multiprocessing\' is True, "
101 "PyFunc may execute time out.");
102 goto ComputeReturn;
103 }
104
CastOutput(const py::object & ret_py_obj,TensorRow * output)105 Status PyFuncOp::CastOutput(const py::object &ret_py_obj, TensorRow *output) {
106 try {
107 std::shared_ptr<Tensor> out;
108 switch (output_type_) {
109 case DataType::DE_INT32:
110 RETURN_IF_NOT_OK(Tensor::CreateEmpty(TensorShape({1}), DataType(DataType::DE_INT32), &out));
111 RETURN_IF_NOT_OK(out->SetItemAt({0}, ret_py_obj.cast<int32_t>()));
112 break;
113 case DataType::DE_BOOL:
114 RETURN_IF_NOT_OK(Tensor::CreateScalar(ret_py_obj.cast<bool>(), &out));
115 break;
116 default:
117 RETURN_STATUS_UNEXPECTED("No cast for the specified DataType was found.");
118 }
119 output->push_back(out);
120 } catch (const std::exception &e) {
121 return Status(StatusCode::kMDUnexpectedError, e.what());
122 }
123 return Status::OK();
124 }
125
to_json(nlohmann::json * out_json)126 Status PyFuncOp::to_json(nlohmann::json *out_json) {
127 nlohmann::json args;
128 if (py_func_ptr_.attr("to_json")) {
129 args = nlohmann::json::parse(py_func_ptr_.attr("to_json")().cast<std::string>());
130 }
131 *out_json = args;
132 return Status::OK();
133 }
134
from_json(nlohmann::json json_obj,std::vector<std::shared_ptr<TensorOperation>> * result)135 Status PyFuncOp::from_json(nlohmann::json json_obj, std::vector<std::shared_ptr<TensorOperation>> *result) {
136 std::vector<std::shared_ptr<TensorOperation>> output;
137 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("tensor_op_name") != json_obj.end(), "Failed to find tensor_op_name");
138 CHECK_FAIL_RETURN_UNEXPECTED(json_obj.find("tensor_op_params") != json_obj.end(), "Failed to find tensor_op_params");
139 std::string op_name = json_obj["tensor_op_name"];
140 nlohmann::json op_params = json_obj["tensor_op_params"];
141 std::string python_module = json_obj["python_module"];
142 std::shared_ptr<TensorOperation> operation = nullptr;
143 py::function py_func =
144 py::module::import(python_module.c_str()).attr(op_name.c_str()).attr("from_json")(op_params.dump());
145 operation = std::make_shared<transforms::PreBuiltOperation>(std::make_shared<PyFuncOp>(py_func));
146 output.push_back(operation);
147 *result = output;
148 return Status::OK();
149 }
150
IsRandom()151 bool PyFuncOp::IsRandom() {
152 bool random = true;
153 if (py::hasattr(py_func_ptr_, "random") && py::reinterpret_borrow<py::bool_>(py_func_ptr_.attr("random")) == false)
154 random = false;
155 return random;
156 }
157 } // namespace dataset
158 } // namespace mindspore
159