1 /** 2 * Copyright 2019 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #include "common/common.h" 18 #include "gtest/gtest.h" 19 #include "minddata/dataset/util/task_manager.h" 20 21 using namespace mindspore::dataset; 22 23 class MindDataTestTaskManager : public UT::Common { 24 public: 25 MindDataTestTaskManager() {} 26 27 void SetUp() { Services::CreateInstance(); } 28 }; 29 30 TEST_F(MindDataTestTaskManager, Test1) { 31 // Clear the rc of the master thread if any 32 (void)TaskManager::GetMasterThreadRc(); 33 TaskGroup vg; 34 Status vg_rc = vg.CreateAsyncTask("Test error", []() -> Status { 35 TaskManager::FindMe()->Post(); 36 throw std::bad_alloc(); 37 }); 38 ASSERT_TRUE(vg_rc.IsOk() || vg_rc == StatusCode::kMDOutOfMemory); 39 ASSERT_TRUE(vg.join_all().IsOk()); 40 ASSERT_TRUE(vg.GetTaskErrorIfAny() == StatusCode::kMDOutOfMemory); 41 // Test the error is passed back to the master thread if vg_rc above is OK. 42 // If vg_rc is kOutOfMemory, the group error is already passed back. 43 // Some compiler may choose to run the next line in parallel with the above 3 lines 44 // and this will cause some mismatch once a while. 45 // To block this racing condition, we need to create a dependency that the next line 46 // depends on previous lines. 47 if (vg.GetTaskErrorIfAny().IsError() && vg_rc.IsOk()) { 48 Status rc = TaskManager::GetMasterThreadRc(); 49 ASSERT_TRUE(rc == StatusCode::kMDOutOfMemory); 50 } 51 } 52 53 TEST_F(MindDataTestTaskManager, Test2) { 54 // This testcase will spawn about 100 threads and block on a conditional variable. 55 // The master thread will try to interrupt them almost at the same time. This can 56 // cause a racing condition that some threads may miss the interrupt and blocked. 57 // The new logic of Task::Join() will do a time-out join and wake up all those 58 // threads that miss the interrupt. 59 // Clear the rc of the master thread if any 60 (void)TaskManager::GetMasterThreadRc(); 61 TaskGroup vg; 62 CondVar cv; 63 std::mutex mux; 64 Status rc; 65 rc = cv.Register(vg.GetIntrpService()); 66 EXPECT_TRUE(rc.IsOk()); 67 auto block_forever = [&cv, &mux]() -> Status { 68 std::unique_lock<std::mutex> lck(mux); 69 TaskManager::FindMe()->Post(); 70 std::this_thread::sleep_for(std::chrono::milliseconds(1)); 71 RETURN_IF_NOT_OK(cv.Wait(&lck, []() -> bool { return false; })); 72 return Status::OK(); 73 }; 74 auto f = [&vg, &block_forever]() -> Status { 75 for (auto i = 0; i < 100; ++i) { 76 RETURN_IF_NOT_OK(vg.CreateAsyncTask("Spawn block threads", block_forever)); 77 } 78 return Status::OK(); 79 }; 80 rc = f(); 81 vg.interrupt_all(); 82 EXPECT_TRUE(rc.IsOk()); 83 // Now we test the async Join 84 ASSERT_TRUE(vg.join_all(Task::WaitFlag::kNonBlocking).IsOk()); 85 } 86