• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "common/common.h"
17 #include "minddata/dataset/include/dataset/datasets.h"
18 
19 namespace common = mindspore::common;
20 
21 using namespace mindspore::dataset;
22 
23 class MindDataTestPipeline : public UT::DatasetOpTesting {
24  protected:
25 };
26 
// Checks that a pull-based iterator observes Batch(): the first row pulled from
// a batched Album dataset (loading only the "label" column) has exactly one
// tensor, whose shape is {batch_size, 2}.
TEST_F(MindDataTestPipeline, TestPullBasedBatch) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPullBasedBatch.";

  std::string folder_path = datasets_root_path_ + "/testAlbum/images";
  std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
  std::vector<std::string> column_names = {"label"};
  // Create an Album Dataset that only loads the "label" column.
  std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
  // ASSERT rather than EXPECT: the pointer is dereferenced next, so a null
  // result must stop this test instead of crashing the test binary.
  ASSERT_NE(ds, nullptr);

  int32_t batch_size = 4;
  // Second argument is presumably drop_remainder — confirm against datasets.h.
  ds = ds->Batch(batch_size, true);
  ASSERT_NE(ds, nullptr);

  auto iter = ds->CreatePullBasedIterator();
  ASSERT_NE(iter, nullptr);

  std::vector<mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  // Fatal check: row[0] is indexed below and must exist.
  ASSERT_EQ(row.size(), 1);
  std::vector<int64_t> expected_shape = {batch_size, 2};
  EXPECT_EQ(row[0].Shape(), expected_shape);
}
51 
// Checks that Project() narrows the columns a pull-based iterator returns:
// an Album dataset loading {"label", "image"} yields two-tensor rows, while the
// same dataset projected to {"image"} yields one-tensor rows.
TEST_F(MindDataTestPipeline, TestPullBasedProject) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPullBasedProject.";

  std::string folder_path = datasets_root_path_ + "/testAlbum/images";
  std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
  std::vector<std::string> column_names = {"label", "image"};

  // Baseline: create an Album Dataset loading both columns.
  std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
  // ASSERT rather than EXPECT: each pointer is dereferenced right after its
  // check, so a null must stop the test instead of crashing the test binary.
  ASSERT_NE(ds, nullptr);

  std::vector<mindspore::MSTensor> row;
  auto iter = ds->CreatePullBasedIterator();
  ASSERT_NE(iter, nullptr);
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_EQ(row.size(), 2);

  // Same source dataset, projected down to the "image" column only.
  std::shared_ptr<Dataset> ds2 = Album(folder_path, schema_file, column_names);
  ASSERT_NE(ds2, nullptr);
  std::vector<std::string> columns_to_project = {"image"};
  ds2 = ds2->Project(columns_to_project);
  ASSERT_NE(ds2, nullptr);

  auto iter2 = ds2->CreatePullBasedIterator();
  ASSERT_NE(iter2, nullptr);

  std::vector<mindspore::MSTensor> new_row;
  ASSERT_OK(iter2->GetNextRow(&new_row));
  EXPECT_EQ(new_row.size(), 1);
}