• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <gtest/gtest.h>
2 
3 #include <ATen/ATen.h>
4 #include <ATen/DLConvertor.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/ParallelFuture.h>
7 
8 #include <iostream>
9 // NOLINTNEXTLINE(modernize-deprecated-headers)
10 #include <string.h>
11 #include <sstream>
12 #if AT_MKL_ENABLED()
13 #include <mkl.h>
14 #include <thread>
15 #endif
16 
17 struct NumThreadsGuard {
18   int old_num_threads_;
NumThreadsGuardNumThreadsGuard19   NumThreadsGuard(int nthreads) {
20     old_num_threads_ = at::get_num_threads();
21     at::set_num_threads(nthreads);
22   }
23 
~NumThreadsGuardNumThreadsGuard24   ~NumThreadsGuard() {
25     at::set_num_threads(old_num_threads_);
26   }
27 };
28 
29 using namespace at;
30 
// Sums a 1x3 tensor over dim 0 while a NumThreadsGuard(1) is active and checks
// that the result equals a 3-element tensor holding the same values.
TEST(TestParallel, TestParallel) {
  manual_seed(123);
  NumThreadsGuard single_thread(1);

  // Every element produced by rand() is overwritten right away, so the input
  // is the fixed row [1, 0, 0].
  Tensor input = rand({1, 3});
  input[0][0] = 1;
  input[0][1] = 0;
  input[0][2] = 0;

  // Expected reduction result: the same three values as a 1-D tensor.
  Tensor expected = rand({3});
  expected[0] = 1;
  expected[1] = 0;
  expected[2] = 0;

  ASSERT_TRUE(input.sum(0).equal(expected));
}
45 
// Verifies that Tensor::sum() called from inside an at::parallel_for body
// produces the same result as the call made outside of it.
TEST(TestParallel, NestedParallel) {
  Tensor all_ones = ones({1024, 1024});
  auto reference_sum = all_ones.sum();

  at::parallel_for(0, 10, 1, [&](int64_t chunk_begin, int64_t /*chunk_end*/) {
    // Only the chunk that starts at index 0 repeats the reduction.
    if (chunk_begin != 0) {
      return;
    }
    ASSERT_TRUE(all_ones.sum().equal(reference_sum));
  });
}
56 
57 #ifdef TH_BLAS_MKL
// Calls set_num_threads(1) on a separate std::thread and checks that the value
// reported by mkl_get_max_threads() on this thread is the same before and
// after that thread has run.
TEST(TestParallel, LocalMKLThreadNumber) {
  const auto threads_before = mkl_get_max_threads();

  std::thread worker([](int nthreads) { set_num_threads(nthreads); }, 1);
  worker.join();

  ASSERT_EQ(threads_before, mkl_get_max_threads());
}
67 #endif
68 
// Checks that the thread id observed within a nested parallel block is
// accurate, for both parallel_for and parallel_reduce as the inner region.
TEST(TestParallel, NestedParallelThreadId) {
  // Body of the inner parallel_for region.
  const auto check_inner_region = [](int64_t inner_begin, int64_t inner_end) {
    // Nested parallel regions execute on a single thread, so the inner body
    // receives the whole [0, 10) range in one call.
    ASSERT_EQ(inner_begin, 0);
    ASSERT_EQ(inner_end, 10);

    // Thread id reflects the inner parallel region.
    ASSERT_EQ(at::get_thread_num(), 0);
  };

  at::parallel_for(0, 10, 1, [&](int64_t /*begin*/, int64_t /*end*/) {
    at::parallel_for(0, 10, 1, check_inner_region);
  });

  // Same check with parallel_reduce as the inner region.
  at::parallel_for(0, 10, 1, [&](int64_t /*begin*/, int64_t /*end*/) {
    // Thread id + 1 should always be 1.
    const auto thread_id_plus_one =
        [](int64_t /*begin*/, int64_t /*end*/, int /*ident*/) {
          return at::get_thread_num() + 1;
        };
    auto reduced =
        at::parallel_reduce(0, 10, 1, 0, thread_id_plus_one, std::plus<>{});
    ASSERT_EQ(reduced, 1);
  });
}
91 
// Checks that a std::runtime_error thrown from the body passed to
// at::parallel_for reaches the caller, for two range/grain-size combinations.
TEST(TestParallel, Exceptions) {
  // Loop body shared by both cases; it ignores its range and always throws.
  const auto always_throws = [](int64_t /*begin*/, int64_t /*end*/) {
    throw std::runtime_error("exception");
  };

  // parallel case: range [0, 10) with grain size 1
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(at::parallel_for(0, 10, 1, always_throws), std::runtime_error);

  // non-parallel case: range [0, 1) with grain size 1000
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(at::parallel_for(0, 1, 1000, always_throws), std::runtime_error);
}
109 
// Launches two tasks through at::intraop_launch_future, waits on both returned
// futures, and checks that each task wrote its value.
TEST(TestParallel, IntraOpLaunchFuture) {
  int first_result = 0;
  int second_result = 0;

  // Both tasks are launched before either future is waited on.
  auto first_future =
      at::intraop_launch_future([&first_result]() { first_result = 1; });
  auto second_future =
      at::intraop_launch_future([&second_result]() { second_result = 2; });

  first_future->wait();
  second_future->wait();

  ASSERT_TRUE(first_result == 1 && second_result == 2);
}
127