1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <iostream> 17 #include <memory> 18 #include <vector> 19 20 #include "minddata/dataset/core/client.h" 21 #include "common/common.h" 22 #include "utils/ms_utils.h" 23 #include "gtest/gtest.h" 24 #include "utils/log_adapter.h" 25 #include "minddata/dataset/engine/datasetops/source/clue_op.h" 26 #include "minddata/dataset/util/status.h" 27 28 namespace common = mindspore::common; 29 30 using namespace mindspore::dataset; 31 using mindspore::LogStream; 32 using mindspore::ExceptionType::NoExceptionType; 33 using mindspore::MsLogLevel::INFO; 34 35 class MindDataTestCLUEOp : public UT::DatasetOpTesting {}; 36 37 std::shared_ptr<ClueOp> Clue(std::vector<std::string> file_list, int32_t op_connector_size, 38 std::map<std::string, std::string> key_map) { 39 std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); 40 auto worker_connector_size = config_manager->worker_connector_size(); 41 auto num_workers = config_manager->num_parallel_workers(); 42 int64_t num_samples = 0; 43 bool shuffle = false; 44 int32_t num_devices = 1; 45 int32_t device_id = 0; 46 ColKeyMap ck_map; 47 for (auto &p : key_map) { 48 std::vector<std::string> res = {}; 49 std::stringstream ss(p.second); 50 std::string item = ""; 51 52 while (getline(ss, item, '/')) { 53 res.push_back(item); 54 } 55 56 ck_map.insert({p.first, res}); 57 } 58 59 std::shared_ptr<ClueOp> so = std::make_shared<ClueOp>(num_workers, num_samples, worker_connector_size, ck_map, 60 file_list, op_connector_size, shuffle, num_devices, device_id); 61 so->Init(); 62 return so; 63 } 64 65 TEST_F(MindDataTestCLUEOp, TestCLUEBasic) { 66 // Start with an empty execution tree 67 auto tree = std::make_shared<ExecutionTree>(); 68 Status rc; 69 std::string dataset_path; 70 dataset_path = datasets_root_path_ + "/testCLUE/afqmc/train.json"; 71 std::map<std::string, std::string> key_map; 72 key_map["sentence1"] = "sentence1"; 73 key_map["sentence2"] = "sentence2"; 74 key_map["label"] = "label"; 75 76 std::shared_ptr<ClueOp> op = Clue({dataset_path}, 2, key_map); 77 78 rc = tree->AssociateNode(op); 79 ASSERT_TRUE(rc.IsOk()); 80 81 rc = tree->AssignRoot(op); 82 ASSERT_TRUE(rc.IsOk()); 83 84 MS_LOG(INFO) << "Launching tree and begin iteration."; 85 rc = tree->Prepare(); 86 ASSERT_TRUE(rc.IsOk()); 87 88 rc = tree->Launch(); 89 ASSERT_TRUE(rc.IsOk()); 90 91 // Start the loop of reading tensors from our pipeline 92 DatasetIterator di(tree); 93 TensorRow tensor_list; 94 rc = di.FetchNextTensorRow(&tensor_list); 95 ASSERT_TRUE(rc.IsOk()); 96 97 int row_count = 0; 98 while (!tensor_list.empty()) { 99 // Display the tensor by calling the printer on it 100 for (int i = 0; i < tensor_list.size(); i++) { 101 std::ostringstream ss; 102 ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl; 103 MS_LOG(INFO) << "Tensor print: " << ss.str() << "."; 104 } 105 106 rc = di.FetchNextTensorRow(&tensor_list); 107 ASSERT_TRUE(rc.IsOk()); 108 row_count++; 109 } 110 111 ASSERT_EQ(row_count, 3); 112 } 113 114 TEST_F(MindDataTestCLUEOp, TestTotalRows) { 115 std::string tf_file1 = datasets_root_path_ + "/testCLUE/afqmc/train.json"; 116 std::string tf_file2 = datasets_root_path_ + "/testCLUE/afqmc/dev.json"; 117 std::vector<std::string> files; 118 files.push_back(tf_file1); 119 int64_t total_rows = 0; 120 ClueOp::CountAllFileRows(files, &total_rows); 121 ASSERT_EQ(total_rows, 3); 122 files.clear(); 123 124 files.push_back(tf_file2); 125 ClueOp::CountAllFileRows(files, &total_rows); 126 ASSERT_EQ(total_rows, 3); 127 files.clear(); 128 129 files.push_back(tf_file1); 130 files.push_back(tf_file2); 131 ClueOp::CountAllFileRows(files, &total_rows); 132 ASSERT_EQ(total_rows, 6); 133 files.clear(); 134 } 135