1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include <chrono> 17 #include <cstdlib> 18 #include <cstring> 19 #include <functional> 20 #include <iostream> 21 #include <memory> 22 #include <string> 23 #include "minddata/dataset/core/client.h" 24 #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" 25 #include "minddata/dataset/engine/jagged_connector.h" 26 #include "gtest/gtest.h" 27 #include "minddata/dataset/core/global_context.h" 28 #include "minddata/dataset/util/status.h" 29 #include "minddata/dataset/core/client.h" 30 #include "common/common.h" 31 #include "gtest/gtest.h" 32 #include "utils/log_adapter.h" 33 #include <memory> 34 #include <vector> 35 #include <iostream> 36 37 using namespace mindspore::dataset; 38 using mindspore::LogStream; 39 using mindspore::ExceptionType::NoExceptionType; 40 using mindspore::MsLogLevel::INFO; 41 42 class MindDataTestClientConfig : public UT::DatasetOpTesting { 43 protected: 44 }; 45 46 TEST_F(MindDataTestClientConfig, TestClientConfig1) { 47 std::shared_ptr<ConfigManager> my_conf = GlobalContext::config_manager(); 48 49 ASSERT_EQ(my_conf->num_parallel_workers(), kCfgParallelWorkers); 50 ASSERT_EQ(my_conf->worker_connector_size(), kCfgWorkerConnectorSize); 51 ASSERT_EQ(my_conf->op_connector_size(), kCfgOpConnectorSize); 52 ASSERT_EQ(my_conf->seed(), kCfgDefaultSeed); 53 54 my_conf->set_num_parallel_workers(2); 55 my_conf->set_worker_connector_size(3); 56 my_conf->set_op_connector_size(4); 57 my_conf->set_seed(5); 58 my_conf->set_enable_shared_mem(false); 59 60 ASSERT_EQ(my_conf->num_parallel_workers(), 2); 61 ASSERT_EQ(my_conf->worker_connector_size(), 3); 62 ASSERT_EQ(my_conf->op_connector_size(), 4); 63 ASSERT_EQ(my_conf->seed(), 5); 64 ASSERT_EQ(my_conf->enable_shared_mem(), false); 65 66 std::string file = datasets_root_path_ + "/declient.cfg"; 67 ASSERT_TRUE(my_conf->LoadFile(file)); 68 69 ASSERT_EQ(my_conf->num_parallel_workers(), kCfgParallelWorkers); 70 ASSERT_EQ(my_conf->worker_connector_size(), kCfgWorkerConnectorSize); 71 ASSERT_EQ(my_conf->op_connector_size(), kCfgOpConnectorSize); 72 ASSERT_EQ(my_conf->seed(), kCfgDefaultSeed); 73 } 74 75 TEST_F(MindDataTestClientConfig, TestClientConfig2) { 76 std::shared_ptr<ConfigManager> my_conf = GlobalContext::config_manager(); 77 78 my_conf->set_num_parallel_workers(8); 79 80 Status rc; 81 82 // Start with an empty execution tree 83 auto my_tree = std::make_shared<ExecutionTree>(); 84 85 // Test info: 86 // Dataset from testDataset1 has 10 rows, 2 columns. 87 std::string dataset_path; 88 dataset_path = datasets_root_path_ + "/testDataset1/testDataset1.data"; 89 // get defaults for tf_reader 90 std::shared_ptr<ConfigManager> config_manager = GlobalContext::config_manager(); 91 auto op_connector_size = config_manager->op_connector_size(); 92 std::vector<std::string> columns_to_load = {}; 93 std::vector<std::string> files = {dataset_path}; 94 std::unique_ptr<DataSchema> schema = std::make_unique<DataSchema>(); 95 std::shared_ptr<TFReaderOp> my_tfreader_op = std::make_shared<TFReaderOp>( 96 1, 2, 0, files, std::move(schema), op_connector_size, columns_to_load, false, 1, 0, false); 97 rc = my_tfreader_op->Init(); 98 ASSERT_OK(rc); 99 ASSERT_EQ(my_tfreader_op->NumWorkers(), 1); 100 my_tree->AssociateNode(my_tfreader_op); 101 102 // Set children/root layout. 103 my_tree->AssignRoot(my_tfreader_op); 104 105 my_tree->Prepare(); 106 my_tree->Launch(); 107 108 // Start the loop of reading tensors from our pipeline 109 DatasetIterator di(my_tree); 110 TensorRow tensor_list; 111 rc = di.FetchNextTensorRow(&tensor_list); 112 ASSERT_TRUE(rc.IsOk()); 113 114 int row_count = 0; 115 while (!tensor_list.empty()) { 116 rc = di.FetchNextTensorRow(&tensor_list); 117 ASSERT_TRUE(rc.IsOk()); 118 row_count++; 119 } 120 ASSERT_EQ(row_count, 10); // Should be 10 rows fetched 121 ASSERT_EQ(my_tfreader_op->NumWorkers(), 1); 122 } 123