1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef TESTS_UT_CPP_DATASET_COMMON_COMMON_H_ 17 #define TESTS_UT_CPP_DATASET_COMMON_COMMON_H_ 18 19 #include "gtest/gtest.h" 20 #include "include/api/status.h" 21 #include "include/api/types.h" 22 #include "minddata/dataset/core/client.h" 23 #include "minddata/dataset/core/tensor_shape.h" 24 #include "minddata/dataset/core/de_tensor.h" 25 #include "minddata/dataset/core/type_id.h" 26 #include "utils/log_adapter.h" 27 #include "minddata/dataset/engine/datasetops/batch_op.h" 28 #include "minddata/dataset/engine/datasetops/repeat_op.h" 29 #include "minddata/dataset/engine/datasetops/source/tf_reader_op.h" 30 31 using mindspore::Status; 32 using mindspore::StatusCode; 33 34 #define ASSERT_OK(_s) \ 35 do { \ 36 Status __rc = (_s); \ 37 if (__rc.IsError()) { \ 38 MS_LOG(ERROR) << __rc.ToString() << "."; \ 39 ASSERT_TRUE(false); \ 40 } \ 41 } while (false) 42 43 #define EXPECT_OK(_s) \ 44 do { \ 45 Status __rc = (_s); \ 46 if (__rc.IsError()) { \ 47 MS_LOG(ERROR) << __rc.ToString() << "."; \ 48 EXPECT_TRUE(false); \ 49 } \ 50 } while (false) 51 52 #define ASSERT_ERROR(_s) \ 53 do { \ 54 Status __rc = (_s); \ 55 if (__rc.IsOk()) { \ 56 MS_LOG(ERROR) << __rc.ToString() << "."; \ 57 ASSERT_TRUE(false); \ 58 } \ 59 } while (false) 60 61 #define EXPECT_ERROR(_s) \ 62 do { \ 63 Status __rc = (_s); \ 64 if (__rc.IsOk()) { \ 65 MS_LOG(ERROR) << __rc.ToString() << "."; \ 66 EXPECT_TRUE(false); \ 67 } \ 68 } while (false) 69 70 // Macro to compare 2 MSTensors; compare shape-size, shape and data 71 #define EXPECT_MSTENSOR_EQ(_mstensor1, _mstensor2) \ 72 do { \ 73 EXPECT_EQ(_mstensor1.Shape().size(), _mstensor2.Shape().size()); \ 74 for (int i = 0; i < _mstensor1.Shape().size(); i++) { \ 75 EXPECT_EQ(_mstensor1.Shape()[i], _mstensor2.Shape()[i]); \ 76 } \ 77 EXPECT_EQ(_mstensor1.DataSize(), _mstensor2.DataSize()); \ 78 EXPECT_EQ(std::memcmp((const void *)_mstensor1.Data().get(), (const void *)_mstensor2.Data().get(), \ 79 _mstensor2.DataSize()), \ 80 0); \ 81 } while (false) 82 83 // Macro to invoke MS_LOG for MSTensor 84 #define TEST_MS_LOG_MSTENSOR(_loglevel, _msg, _mstensor) \ 85 do { \ 86 std::shared_ptr<Tensor> _de_tensor; \ 87 ASSERT_OK(Tensor::CreateFromMSTensor(_mstensor, &_de_tensor)); \ 88 MS_LOG(_loglevel) << _msg << *_de_tensor; \ 89 } while (false) 90 91 namespace UT { 92 class Common : public testing::Test { 93 public: 94 // every TEST_F macro will enter one 95 virtual void SetUp(); 96 97 virtual void TearDown(); 98 }; 99 100 class DatasetOpTesting : public Common { 101 public: 102 // Helper functions for creating datasets 103 std::shared_ptr<mindspore::dataset::BatchOp> Batch(int32_t batch_size = 1, bool drop = false, 104 mindspore::dataset::PadInfo = {}); 105 106 std::shared_ptr<mindspore::dataset::RepeatOp> Repeat(int repeat_cnt = 1); 107 108 std::shared_ptr<mindspore::dataset::TFReaderOp> TFReader(std::string file, int num_works = 8); 109 110 std::shared_ptr<mindspore::dataset::ExecutionTree> Build( 111 std::vector<std::shared_ptr<mindspore::dataset::DatasetOp>> ops); 112 113 std::vector<mindspore::dataset::TensorShape> ToTensorShapeVec(const std::vector<std::vector<int64_t>> &v); 114 std::vector<mindspore::dataset::DataType> ToDETypes(const std::vector<mindspore::DataType> &t); 115 mindspore::MSTensor ReadFileToTensor(const std::string &file); 116 std::string datasets_root_path_; 117 std::string mindrecord_root_path_; 118 void SetUp() override; 119 }; 120 } // namespace UT 121 #endif // TESTS_UT_CPP_DATASET_COMMON_COMMON_H_ 122