• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef TESTS_UT_CPP_DATASET_COMMON_COMMON_H_
17 #define TESTS_UT_CPP_DATASET_COMMON_COMMON_H_
18 
#include <cstring>
#include <memory>
#include <string>
#include <vector>

#include "gtest/gtest.h"
#include "include/api/status.h"
#include "include/api/types.h"
#include "minddata/dataset/core/client.h"
#include "minddata/dataset/core/tensor_shape.h"
#include "minddata/dataset/core/de_tensor.h"
#include "minddata/dataset/core/type_id.h"
#include "utils/log_adapter.h"
#include "minddata/dataset/engine/datasetops/batch_op.h"
#include "minddata/dataset/engine/datasetops/repeat_op.h"
#include "minddata/dataset/engine/datasetops/source/tf_reader_op.h"
#include "minddata/dataset/engine/ir/datasetops/dataset_node.h"

// Header-scope using-declarations: visible in every file that includes this header.
using mindspore::Status;
using mindspore::StatusCode;
using CompressionType = mindspore::dataset::NonMappableLeafOp::CompressionType;
35 
// Fatal check that the Status expression `_s` succeeded; `_s` is evaluated exactly once.
// On error the status text is logged through MS_LOG(ERROR), and a gtest fatal failure naming
// the expression is raised. Like every gtest ASSERT_* macro this returns from the enclosing
// function, so it is only usable in functions returning void.
// The local is named `_ut_status` because identifiers containing "__" are reserved.
#define ASSERT_OK(_s)                                                  \
  do {                                                                 \
    Status _ut_status = (_s);                                          \
    if (_ut_status.IsError()) {                                        \
      MS_LOG(ERROR) << _ut_status.ToString() << ".";                   \
      ASSERT_TRUE(false) << "Expected OK status from: " << #_s << "\n" \
                         << _ut_status.ToString();                     \
    }                                                                  \
  } while (false)
44 
// Non-fatal check that the Status expression `_s` succeeded; `_s` is evaluated exactly once.
// On error the status text is logged through MS_LOG(ERROR), and a gtest non-fatal failure
// naming the expression is recorded; the test keeps running.
// The local is named `_ut_status` because identifiers containing "__" are reserved.
#define EXPECT_OK(_s)                                                  \
  do {                                                                 \
    Status _ut_status = (_s);                                          \
    if (_ut_status.IsError()) {                                        \
      MS_LOG(ERROR) << _ut_status.ToString() << ".";                   \
      EXPECT_TRUE(false) << "Expected OK status from: " << #_s << "\n" \
                         << _ut_status.ToString();                     \
    }                                                                  \
  } while (false)
53 
// Fatal check that the Status expression `_s` failed; `_s` is evaluated exactly once.
// When it unexpectedly returns OK, the offending expression (not the uninformative OK status
// text) is logged through MS_LOG(ERROR), and a gtest fatal failure is raised. Like every gtest
// ASSERT_* macro this returns from the enclosing function (void functions only).
// The local is named `_ut_status` because identifiers containing "__" are reserved.
#define ASSERT_ERROR(_s)                                                           \
  do {                                                                             \
    Status _ut_status = (_s);                                                      \
    if (_ut_status.IsOk()) {                                                       \
      MS_LOG(ERROR) << "Expected an error status from " << #_s << ", but got OK."; \
      ASSERT_TRUE(false) << "Expected an error status from: " << #_s;             \
    }                                                                              \
  } while (false)
62 
// Non-fatal check that the Status expression `_s` failed; `_s` is evaluated exactly once.
// When it unexpectedly returns OK, the offending expression (not the uninformative OK status
// text) is logged through MS_LOG(ERROR), and a gtest non-fatal failure is recorded.
// The local is named `_ut_status` because identifiers containing "__" are reserved.
#define EXPECT_ERROR(_s)                                                           \
  do {                                                                             \
    Status _ut_status = (_s);                                                      \
    if (_ut_status.IsOk()) {                                                       \
      MS_LOG(ERROR) << "Expected an error status from " << #_s << ", but got OK."; \
      EXPECT_TRUE(false) << "Expected an error status from: " << #_s;             \
    }                                                                              \
  } while (false)
71 
// Macro to compare 2 MSTensors: shape (rank and every dimension), data size and raw data bytes.
// All checks are non-fatal gtest EXPECT_EQs, and each argument is evaluated exactly once.
// The byte comparison only runs when both tensors report the same non-zero DataSize().
// This avoids two problems with the earlier version:
//   - mismatched tensors could make memcmp read past the smaller buffer;
//   - empty tensors could hand memcmp null pointers.
// Comparing the shape vectors as a whole also avoids indexing past the shorter shape.
#define EXPECT_MSTENSOR_EQ(_mstensor1, _mstensor2)                                                     \
  do {                                                                                                 \
    const auto &_expect_mstensor_lhs = (_mstensor1);                                                   \
    const auto &_expect_mstensor_rhs = (_mstensor2);                                                   \
    EXPECT_EQ(_expect_mstensor_lhs.Shape(), _expect_mstensor_rhs.Shape());                             \
    EXPECT_EQ(_expect_mstensor_lhs.DataSize(), _expect_mstensor_rhs.DataSize());                       \
    if (_expect_mstensor_lhs.DataSize() == _expect_mstensor_rhs.DataSize() &&                          \
        _expect_mstensor_rhs.DataSize() > 0) {                                                         \
      EXPECT_EQ(std::memcmp(_expect_mstensor_lhs.Data().get(), _expect_mstensor_rhs.Data().get(),      \
                            _expect_mstensor_rhs.DataSize()),                                          \
                0);                                                                                    \
    }                                                                                                  \
  } while (false)
84 
// Macro to invoke MS_LOG for an MSTensor.
// Converts `_mstensor` into a dataset Tensor and streams `_msg` followed by that tensor at log
// level `_loglevel`. A failed conversion is reported through ASSERT_OK. That macro uses a gtest
// ASSERT_*, which returns from the enclosing function, so this macro only works in functions
// returning void.
// Tensor is spelled fully qualified so callers no longer need
// `using namespace mindspore::dataset` for the macro to compile.
#define TEST_MS_LOG_MSTENSOR(_loglevel, _msg, _mstensor)                                  \
  do {                                                                                    \
    std::shared_ptr<::mindspore::dataset::Tensor> _de_tensor;                             \
    ASSERT_OK(::mindspore::dataset::Tensor::CreateFromMSTensor(_mstensor, &_de_tensor)); \
    MS_LOG(_loglevel) << _msg << *_de_tensor;                                             \
  } while (false)
92 
93 Status GetSessionFromEnv(uint32_t *session_id);
94 
95 namespace UT {
96 class Common : public testing::Test {
97  public:
98   // every TEST_F macro will enter one
99   virtual void SetUp();
100 
101   virtual void TearDown();
102 };
103 
104 class DatasetOpTesting : public Common {
105  public:
106   // Helper functions for creating datasets
107   std::shared_ptr<mindspore::dataset::BatchOp> Batch(int32_t batch_size = 1, bool drop = false,
108                                                      mindspore::dataset::PadInfo = {});
109 
110   std::shared_ptr<mindspore::dataset::RepeatOp> Repeat(int repeat_cnt = 1);
111 
112   std::shared_ptr<mindspore::dataset::TFReaderOp> TFReader(std::string file, int num_works = 8);
113 
114   std::shared_ptr<mindspore::dataset::ExecutionTree> Build(
115     std::vector<std::shared_ptr<mindspore::dataset::DatasetOp>> ops);
116 
117   std::vector<mindspore::dataset::TensorShape> ToTensorShapeVec(const std::vector<std::vector<int64_t>> &v);
118   std::vector<mindspore::dataset::DataType> ToDETypes(const std::vector<mindspore::DataType> &t);
119   mindspore::MSTensor ReadFileToTensor(const std::string &file);
120   std::string datasets_root_path_;
121   std::string mindrecord_root_path_;
122   void SetUp() override;
123 };
124 }  // namespace UT
125 
126 namespace mindspore {
127 namespace dataset {
128 // defined in datasets.cc code, and function prototypes added here for UT purposes
129 // convert MSTensorVec to DE TensorRow, return empty if fails
130 TensorRow VecToRow(const MSTensorVec &v);
131 
132 // defined in datasets.cc code, and function prototypes added here for UT purposes
133 // convert DE TensorRow to MSTensorVec, won't fail
134 MSTensorVec RowToVec(const TensorRow &v);
135 
136 MSTensorVec Predicate1(MSTensorVec in);
137 
138 MSTensorVec Predicate2(MSTensorVec in);
139 
140 MSTensorVec Predicate3(MSTensorVec in);
141 
142 cv::Mat BGRToRGB(const cv::Mat &img);
143 }  // namespace dataset
144 }  // namespace mindspore
145 #endif  // TESTS_UT_CPP_DATASET_COMMON_COMMON_H_
146