1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef TESTS_UT_CPP_DATASET_COMMON_COMMON_H_ 17 #define TESTS_UT_CPP_DATASET_COMMON_COMMON_H_ 18 19 #include "gtest/gtest.h" 20 #include "include/api/status.h" 21 #include "include/api/types.h" 22 #include "minddata/dataset/core/client.h" 23 #include "minddata/dataset/core/tensor_shape.h" 24 #include "minddata/dataset/core/de_tensor.h" 25 #include "minddata/dataset/core/type_id.h" 26 #include "utils/log_adapter.h" 27 #include "minddata/dataset/engine/datasetops/batch_op.h" 28 #include "minddata/dataset/engine/datasetops/repeat_op.h" 29 #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" 30 #include "minddata/dataset/engine/ir/datasetops/dataset_node.h" 31 32 using mindspore::Status; 33 using mindspore::StatusCode; 34 using CompressionType = mindspore::dataset::NonMappableLeafOp::CompressionType; 35 36 #define ASSERT_OK(_s) \ 37 do { \ 38 Status __rc = (_s); \ 39 if (__rc.IsError()) { \ 40 MS_LOG(ERROR) << __rc.ToString() << "."; \ 41 ASSERT_TRUE(false); \ 42 } \ 43 } while (false) 44 45 #define EXPECT_OK(_s) \ 46 do { \ 47 Status __rc = (_s); \ 48 if (__rc.IsError()) { \ 49 MS_LOG(ERROR) << __rc.ToString() << "."; \ 50 EXPECT_TRUE(false); \ 51 } \ 52 } while (false) 53 54 #define ASSERT_ERROR(_s) \ 55 do { \ 56 Status __rc = (_s); \ 57 if (__rc.IsOk()) { \ 58 MS_LOG(ERROR) << __rc.ToString() << "."; \ 59 ASSERT_TRUE(false); \ 60 } \ 61 } while (false) 62 63 #define EXPECT_ERROR(_s) \ 64 do { \ 65 Status __rc = (_s); \ 66 if (__rc.IsOk()) { \ 67 MS_LOG(ERROR) << __rc.ToString() << "."; \ 68 EXPECT_TRUE(false); \ 69 } \ 70 } while (false) 71 72 // Macro to compare 2 MSTensors; compare shape-size, shape and data 73 #define EXPECT_MSTENSOR_EQ(_mstensor1, _mstensor2) \ 74 do { \ 75 EXPECT_EQ(_mstensor1.Shape().size(), _mstensor2.Shape().size()); \ 76 for (int i = 0; i < _mstensor1.Shape().size(); i++) { \ 77 EXPECT_EQ(_mstensor1.Shape()[i], _mstensor2.Shape()[i]); \ 78 } \ 79 EXPECT_EQ(_mstensor1.DataSize(), _mstensor2.DataSize()); \ 80 EXPECT_EQ(std::memcmp((const void *)_mstensor1.Data().get(), (const void *)_mstensor2.Data().get(), \ 81 _mstensor2.DataSize()), \ 82 0); \ 83 } while (false) 84 85 // Macro to invoke MS_LOG for MSTensor 86 #define TEST_MS_LOG_MSTENSOR(_loglevel, _msg, _mstensor) \ 87 do { \ 88 std::shared_ptr<Tensor> _de_tensor; \ 89 ASSERT_OK(Tensor::CreateFromMSTensor(_mstensor, &_de_tensor)); \ 90 MS_LOG(_loglevel) << _msg << *_de_tensor; \ 91 } while (false) 92 93 Status GetSessionFromEnv(uint32_t *session_id); 94 95 namespace UT { 96 class Common : public testing::Test { 97 public: 98 // every TEST_F macro will enter one 99 virtual void SetUp(); 100 101 virtual void TearDown(); 102 }; 103 104 class DatasetOpTesting : public Common { 105 public: 106 // Helper functions for creating datasets 107 std::shared_ptr<mindspore::dataset::BatchOp> Batch(int32_t batch_size = 1, bool drop = false, 108 mindspore::dataset::PadInfo = {}); 109 110 std::shared_ptr<mindspore::dataset::RepeatOp> Repeat(int repeat_cnt = 1); 111 112 std::shared_ptr<mindspore::dataset::TFReaderOp> TFReader(std::string file, int num_works = 8); 113 114 std::shared_ptr<mindspore::dataset::ExecutionTree> Build( 115 std::vector<std::shared_ptr<mindspore::dataset::DatasetOp>> ops); 116 117 std::vector<mindspore::dataset::TensorShape> ToTensorShapeVec(const std::vector<std::vector<int64_t>> &v); 118 std::vector<mindspore::dataset::DataType> ToDETypes(const std::vector<mindspore::DataType> &t); 119 mindspore::MSTensor ReadFileToTensor(const std::string &file); 120 std::string datasets_root_path_; 121 std::string mindrecord_root_path_; 122 void SetUp() override; 123 }; 124 } // namespace UT 125 126 namespace mindspore { 127 namespace dataset { 128 // defined in datasets.cc code, and function prototypes added here for UT purposes 129 // convert MSTensorVec to DE TensorRow, return empty if fails 130 TensorRow VecToRow(const MSTensorVec &v); 131 132 // defined in datasets.cc code, and function prototypes added here for UT purposes 133 // convert DE TensorRow to MSTensorVec, won't fail 134 MSTensorVec RowToVec(const TensorRow &v); 135 136 MSTensorVec Predicate1(MSTensorVec in); 137 138 MSTensorVec Predicate2(MSTensorVec in); 139 140 MSTensorVec Predicate3(MSTensorVec in); 141 142 cv::Mat BGRToRGB(const cv::Mat &img); 143 } // namespace dataset 144 } // namespace mindspore 145 #endif // TESTS_UT_CPP_DATASET_COMMON_COMMON_H_ 146