1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "common/common.h"
17 #include "minddata/dataset/include/dataset/datasets.h"
18
19 namespace common = mindspore::common;
20
21 using namespace mindspore::dataset;
22
23 class MindDataTestPipeline : public UT::DatasetOpTesting {
24 protected:
25 };
26
// Feature: pull-based iterator.
// Description: builds an Album dataset from testAlbum with only the "label" column,
//              batches it with batch_size 4, and pulls a single row through the
//              iterator returned by CreatePullBasedIterator().
// Expectation: the row contains exactly one tensor whose shape is {batch_size, 2}.
TEST_F(MindDataTestPipeline, TestPullBasedBatch) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPullBasedBatch.";

  std::string folder_path = datasets_root_path_ + "/testAlbum/images";
  std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
  std::vector<std::string> column_names = {"label"};
  // Create an Album Dataset
  std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
  // Fatal assertion: ds is dereferenced on the next statement, so a null pointer
  // must stop the test here instead of crashing the whole test binary.
  ASSERT_NE(ds, nullptr);

  int32_t batch_size = 4;
  // NOTE(review): the second argument is presumably drop_remainder - confirm against Dataset::Batch.
  ds = ds->Batch(batch_size, true);
  ASSERT_NE(ds, nullptr);

  auto iter = ds->CreatePullBasedIterator();
  ASSERT_NE(iter, nullptr);

  std::vector<mindspore::MSTensor> row;
  ASSERT_OK(iter->GetNextRow(&row));
  // Fatal assertion: row[0] is indexed below, which would be out of bounds
  // if the row did not hold exactly one column.
  ASSERT_EQ(row.size(), 1);
  std::vector<int64_t> result = {batch_size, 2};
  EXPECT_EQ(row[0].Shape(), result);
}
51
// Feature: pull-based iterator.
// Description: builds an Album dataset from testAlbum with the "label" and "image"
//              columns and pulls one row; then builds the same dataset again, projects
//              it down to the "image" column, and pulls one row from that pipeline.
// Expectation: the unprojected row has 2 columns, the projected row has 1 column.
TEST_F(MindDataTestPipeline, TestPullBasedProject) {
  MS_LOG(INFO) << "Doing MindDataTestPipeline-TestPullBasedProject.";

  std::string folder_path = datasets_root_path_ + "/testAlbum/images";
  std::string schema_file = datasets_root_path_ + "/testAlbum/datasetSchema.json";
  std::vector<std::string> column_names = {"label", "image"};
  // Create an Album Dataset
  std::shared_ptr<Dataset> ds = Album(folder_path, schema_file, column_names);
  // Fatal assertions are used wherever the checked pointer is dereferenced next,
  // so a failed creation fails this test instead of crashing the test binary.
  ASSERT_NE(ds, nullptr);

  std::vector<mindspore::MSTensor> row;
  auto iter = ds->CreatePullBasedIterator();
  ASSERT_NE(iter, nullptr);
  ASSERT_OK(iter->GetNextRow(&row));
  EXPECT_EQ(row.size(), 2);

  // Second, independent pipeline over the same data with a Project op on top.
  std::shared_ptr<Dataset> ds2 = Album(folder_path, schema_file, column_names);
  ASSERT_NE(ds2, nullptr);
  std::vector<std::string> columns_to_project = {"image"};
  ds2 = ds2->Project(columns_to_project);
  ASSERT_NE(ds2, nullptr);

  auto iter2 = ds2->CreatePullBasedIterator();
  ASSERT_NE(iter2, nullptr);

  std::vector<mindspore::MSTensor> new_row;
  ASSERT_OK(iter2->GetNextRow(&new_row));
  EXPECT_EQ(new_row.size(), 1);
}