/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <memory>
#include <vector>

#include "common/common.h"
#include "utils/ms_utils.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/engine/jagged_connector.h"
#include "gtest/gtest.h"
#include "utils/log_adapter.h"

namespace common = mindspore::common;

using namespace mindspore::dataset;
using mindspore::LogStream;
using mindspore::ExceptionType::NoExceptionType;
using mindspore::MsLogLevel::INFO;

class MindDataTestProjectOp : public UT::DatasetOpTesting {};

TEST_F(MindDataTestProjectOp, TestProjectProject) {
  // Start with an empty execution tree
  auto my_tree = std::make_shared<ExecutionTree>();
  Status rc;
  std::string dataset_path;
  dataset_path = datasets_root_path_ + "/testTFTestAllTypes/test.data";

  std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager();
  auto op_connector_size = config_manager->op_connector_size();
  auto num_workers = 1;  // one file, one worker
  std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>();
  std::vector<std::string> columns_to_load = {};
  std::vector<std::string> files = {dataset_path};
  schema->LoadSchemaFile(datasets_root_path_ + "/testTFTestAllTypes/datasetSchema.json", {});
  std::shared_ptr<TFReaderOp> my_tfreader_op = std::make_shared<TFReaderOp>(
    num_workers, 16, 0, files, std::move(schema), op_connector_size, columns_to_load, false, 1, 0, false);
  rc = my_tfreader_op->Init();
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssociateNode(my_tfreader_op);
  ASSERT_TRUE(rc.IsOk());

  // ProjectOp
  std::vector<std::string> columns_to_project = {"col_sint16", "col_float", "col_2d"};
  std::shared_ptr<ProjectOp> my_project_op = std::make_shared<ProjectOp>(columns_to_project);
  rc = my_tree->AssociateNode(my_project_op);
  ASSERT_TRUE(rc.IsOk());

  // Set children/root layout.
  rc = my_project_op->AddChild(my_tfreader_op);
  ASSERT_TRUE(rc.IsOk());
  rc = my_tree->AssignRoot(my_project_op);
  ASSERT_TRUE(rc.IsOk());

  MS_LOG(INFO) << "Launching tree and begin iteration.";
  rc = my_tree->Prepare();
  ASSERT_TRUE(rc.IsOk());

  rc = my_tree->Launch();
  ASSERT_TRUE(rc.IsOk());

  // Start the loop of reading tensors from our pipeline
  DatasetIterator di(my_tree);
  TensorRow tensor_list;
  rc = di.FetchNextTensorRow(&tensor_list);
  ASSERT_TRUE(rc.IsOk());

  int row_count = 0;
  while (!tensor_list.empty()) {
    MS_LOG(INFO) << "Row display for row #: " << row_count << ".";

    ASSERT_EQ(tensor_list.size(), columns_to_project.size());

    // Display the tensor by calling the printer on it
    for (int i = 0; i < tensor_list.size(); i++) {
      std::ostringstream ss;
      ss << "(" << tensor_list[i] << "): " << *tensor_list[i] << std::endl;
      MS_LOG(INFO) << "Tensor print: " << ss.str() << ".";
    }

    rc = di.FetchNextTensorRow(&tensor_list);
    ASSERT_TRUE(rc.IsOk());
    row_count++;
  }

  ASSERT_EQ(row_count, 12);
}