1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #include "common/common.h" 17 #include "minddata/dataset/include/dataset/datasets.h" 18 19 using namespace mindspore::dataset; 20 using mindspore::dataset::Tensor; 21 22 class MindDataTestPipeline : public UT::DatasetOpTesting { 23 protected: 24 }; 25 26 TEST_F(MindDataTestPipeline, TestRepeatSetNumWorkers) { 27 MS_LOG(INFO) << "Doing MindDataTestRepeat-TestRepeatSetNumWorkers."; 28 29 std::string file_path = datasets_root_path_ + "/testTFTestAllTypes/test.data"; 30 std::shared_ptr<Dataset> ds = TFRecord({file_path}); 31 ds = ds->SetNumWorkers(8); 32 ds = ds->Repeat(32); 33 34 // Create an iterator over the result of the above dataset 35 std::shared_ptr<Iterator> iter = ds->CreateIterator(); 36 // Expect a valid iterator 37 ASSERT_NE(iter, nullptr); 38 39 // Iterate the dataset and get each row 40 std::unordered_map<std::string, mindspore::MSTensor> row; 41 ASSERT_OK(iter->GetNextRow(&row)); 42 43 uint64_t i = 0; 44 while (row.size() != 0) { 45 i++; 46 ASSERT_OK(iter->GetNextRow(&row)); 47 } 48 49 // Verify correct number of rows fetched 50 EXPECT_EQ(i, 12 * 32); 51 52 // Manually terminate the pipeline 53 iter->Stop(); 54 } 55