#include <gtest/gtest.h>

#include <cstdint>
#include <numeric>
#include <thread>
#include <vector>

#include <caffe2/utils/threadpool/thread_pool_guard.h>
#include <caffe2/utils/threadpool/pthreadpool-cpp.h>
5
// Verifies that _NoPThreadPoolGuard hides the shared pthreadpool handle
// (pthreadpool_() yields nullptr) while in scope, that guards nest, and that
// destroying each guard restores whatever state was in effect before it.
TEST(TestThreadPoolGuard, TestThreadPoolGuard) {
  // Handle observed before any guard has been installed.
  const auto original_pool = caffe2::pthreadpool_();
  ASSERT_NE(original_pool, nullptr);

  {
    caffe2::_NoPThreadPoolGuard outer_guard;
    ASSERT_EQ(caffe2::pthreadpool_(), nullptr);

    {
      caffe2::_NoPThreadPoolGuard nested_guard;
      ASSERT_EQ(caffe2::pthreadpool_(), nullptr);
    }

    // The nested guard restores the prior state, which is still "disabled"
    // because outer_guard is alive.
    ASSERT_EQ(caffe2::pthreadpool_(), nullptr);
  }

  // With every guard destroyed, the very same handle must be visible again.
  const auto restored_pool = caffe2::pthreadpool_();
  ASSERT_NE(restored_pool, nullptr);
  ASSERT_EQ(restored_pool, original_pool);
}
31
// Verifies that PThreadPool::run still executes every task while a
// _NoPThreadPoolGuard is active, and that the tasks run on the calling thread.
TEST(TestThreadPoolGuard, TestRunWithGuard) {
  const std::vector<int64_t> array = {1, 2, 3};

  auto pool = caffe2::pthreadpool();
  // Guard against dereferencing a null pool below.
  ASSERT_NE(pool, nullptr);

  int64_t inner = 0;
  // One slot per task so recording thread ids never involves a shared write,
  // even if the guard were ignored and tasks ran on worker threads.
  std::vector<std::thread::id> task_threads(array.size());
  {
    caffe2::_NoPThreadPoolGuard g1;

    // Confirm the guard is on before dispatching any work.
    ASSERT_EQ(caffe2::pthreadpool_(), nullptr);

    auto fn = [&array, &inner, &task_threads](const size_t task_id) {
      inner += array[task_id];
      task_threads[task_id] = std::this_thread::get_id();
    };
    // Derive the task count from the data instead of hard-coding it.
    pool->run(fn, array.size());
  }

  // Every element must have been visited exactly once.
  ASSERT_EQ(inner, std::accumulate(array.begin(), array.end(), int64_t{0}));

  // "Run on same thread": each task must have executed on this thread.
  // A task that never ran leaves a default-constructed id, which also fails.
  const auto caller = std::this_thread::get_id();
  for (const auto& id : task_threads) {
    EXPECT_EQ(id, caller);
  }
}
51