• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
#include "minddata/dataset/kernels/plugin_op.h"

#include <cstring>
#include <memory>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "minddata/dataset/core/tensor.h"
#include "minddata/dataset/plugin/plugin_loader.h"
20 
21 namespace mindspore {
22 namespace dataset {
PluginToTensorRow(const std::vector<plugin::Tensor> & in_row,TensorRow * out_row)23 Status PluginOp::PluginToTensorRow(const std::vector<plugin::Tensor> &in_row, TensorRow *out_row) {
24   CHECK_FAIL_RETURN_UNEXPECTED(out_row != nullptr && out_row->empty(), "null/empty out_row received!");
25   out_row->reserve(in_row.size());
26   for (const auto &tensor : in_row) {
27     std::shared_ptr<Tensor> output;
28     auto tp = DataType(tensor.type_);
29     CHECK_FAIL_RETURN_UNEXPECTED(tp.IsNumeric() && tp != DataType::DE_UNKNOWN,
30                                  "Input datatype should be numeric, got Unsupported type: " + tensor.type_);
31     RETURN_IF_NOT_OK(Tensor::CreateFromMemory(TensorShape(tensor.shape_), tp, tensor.buffer_.data(), &output));
32     out_row->emplace_back(output);
33   }
34   return Status::OK();
35 }
36 
TensorRowToPlugin(const TensorRow & in_row,std::vector<plugin::Tensor> * out_row)37 Status PluginOp::TensorRowToPlugin(const TensorRow &in_row, std::vector<plugin::Tensor> *out_row) {
38   CHECK_FAIL_RETURN_UNEXPECTED(out_row != nullptr && out_row->empty(), "null/empty out_row received!");
39   out_row->resize(in_row.size());
40   for (size_t ind = 0; ind < in_row.size(); ind++) {
41     plugin::Tensor &tensor = (*out_row)[ind];
42     if (in_row[ind]->type().IsNumeric()) {
43       dsize_t buffer_size = in_row[ind]->SizeInBytes();
44       tensor.buffer_.resize(buffer_size);
45       if (buffer_size < SECUREC_MEM_MAX_LEN) {
46         int ret_code = memcpy_s(tensor.buffer_.data(), tensor.buffer_.size(), in_row[ind]->GetBuffer(), buffer_size);
47         CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into plugin tensor.");
48       } else {
49         int ret_code = memcpy_s(tensor.buffer_.data(), buffer_size, in_row[ind]->GetBuffer(), buffer_size);
50         CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "Failed to copy data into plugin tensor.");
51       }
52     } else {  // string tensor, for now, only tensor with 1 string is supported!
53       CHECK_FAIL_RETURN_UNEXPECTED(in_row[ind]->shape().NumOfElements() == 1,
54                                    "String tensor with more than 1 element is not yet supported.");
55       // get the first and only string in this tensor
56       std::string str1(*(in_row[ind]->begin<std::string_view>()));
57       tensor.buffer_.resize(str1.size());
58       auto ret_code = memcpy_s(tensor.buffer_.data(), tensor.buffer_.size(), str1.data(), str1.size());
59       CHECK_FAIL_RETURN_UNEXPECTED(ret_code == 0, "memcpy_s failed when copying string tensor.");
60     }
61     tensor.shape_ = in_row[ind]->shape().AsVector();
62     tensor.type_ = in_row[ind]->type().ToString();
63   }
64   return Status::OK();
65 }
66 
Compute(const TensorRow & input,TensorRow * output)67 Status PluginOp::Compute(const TensorRow &input, TensorRow *output) {
68   IO_CHECK_VECTOR(input, output);
69   // Compute should quit if init fails. Error code has already been logged, no need to repeat
70   RETURN_IF_NOT_OK(init_code_);
71   std::vector<plugin::Tensor> in_row;
72   std::vector<plugin::Tensor> out_row;
73   RETURN_IF_NOT_OK(TensorRowToPlugin(input, &in_row));
74   plugin::Status rc = plugin_op_->Compute(&in_row, &out_row);
75   CHECK_FAIL_RETURN_UNEXPECTED(rc.IsOk(), rc.ToString());
76   RETURN_IF_NOT_OK(PluginToTensorRow(out_row, output));
77   return Status::OK();
78 }
79 
PluginOp(const std::string & lib_path,const std::string & func_name,const std::string & user_args)80 PluginOp::PluginOp(const std::string &lib_path, const std::string &func_name, const std::string &user_args)
81     : plugin_op_(nullptr), lib_path_(lib_path), func_name_(func_name), user_args_(user_args) {
82   init_code_ = Init();
83 }
84 
Init()85 Status PluginOp::Init() {
86   plugin::PluginManagerBase *plugin = nullptr;
87   RETURN_IF_NOT_OK(PluginLoader::GetInstance()->LoadPlugin(lib_path_, &plugin));
88   // casting a void pointer to specific type
89   plugin_op_ = dynamic_cast<plugin::TensorOp *>(plugin->GetModule(func_name_));
90   RETURN_UNEXPECTED_IF_NULL(plugin_op_);
91   plugin::Status rc = plugin_op_->ParseSerializedArgs(user_args_);
92   CHECK_FAIL_RETURN_UNEXPECTED(rc.IsOk(), rc.ToString());
93   return Status::OK();
94 }
95 }  // namespace dataset
96 }  // namespace mindspore
97