1 #include <gtest/gtest.h>
2
3 #include <ATen/ATen.h>
4 #include <ATen/DLConvertor.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/ParallelFuture.h>
7
8 #include <iostream>
9 // NOLINTNEXTLINE(modernize-deprecated-headers)
10 #include <string.h>
11 #include <sstream>
12 #if AT_MKL_ENABLED()
13 #include <mkl.h>
14 #include <thread>
15 #endif
16
17 struct NumThreadsGuard {
18 int old_num_threads_;
NumThreadsGuardNumThreadsGuard19 NumThreadsGuard(int nthreads) {
20 old_num_threads_ = at::get_num_threads();
21 at::set_num_threads(nthreads);
22 }
23
~NumThreadsGuardNumThreadsGuard24 ~NumThreadsGuard() {
25 at::set_num_threads(old_num_threads_);
26 }
27 };
28
29 using namespace at;
30
// Reduces a 1x3 tensor over dimension 0 while the thread count is set to 1 and
// compares the result with a hand-filled 3-element tensor.
TEST(TestParallel, TestParallel) {
  manual_seed(123);
  NumThreadsGuard single_thread(1);

  // Allocated through rand(), then every element is overwritten with {1, 0, 0}.
  Tensor input = rand({1, 3});
  for (int64_t col = 0; col < 3; ++col) {
    input[0][col] = (col == 0) ? 1 : 0;
  }

  // Expected reduction result, filled with the same {1, 0, 0} pattern.
  Tensor expected = rand({3});
  for (int64_t idx = 0; idx < 3; ++idx) {
    expected[idx] = (idx == 0) ? 1 : 0;
  }

  ASSERT_TRUE(input.sum(0).equal(expected));
}
45
// Checks that calling sum() from within a parallel block computes the same
// result as calling it outside of one.
TEST(TestParallel, NestedParallel) {
  Tensor input = ones({1024, 1024});
  auto reference_sum = input.sum();

  at::parallel_for(0, 10, 1, [&](int64_t chunk_begin, int64_t /*chunk_end*/) {
    // Only the chunk that starts at index 0 performs the comparison.
    if (chunk_begin != 0) {
      return;
    }
    ASSERT_TRUE(input.sum().equal(reference_sum));
  });
}
56
57 #ifdef TH_BLAS_MKL
TEST(TestParallel,LocalMKLThreadNumber)58 TEST(TestParallel, LocalMKLThreadNumber) {
59 auto master_thread_num = mkl_get_max_threads();
60 auto f = [](int nthreads){
61 set_num_threads(nthreads);
62 };
63 std::thread t(f, 1);
64 t.join();
65 ASSERT_EQ(master_thread_num, mkl_get_max_threads());
66 }
67 #endif
68
// Checks that the thread id observed within a nested parallel block is
// accurate, for both parallel_for and parallel_reduce as the inner region.
TEST(TestParallel, NestedParallelThreadId) {
  // Body of the inner parallel_for region.
  const auto check_inner_region = [](int64_t inner_begin, int64_t inner_end) {
    // Nested parallel regions execute on a single thread, so the inner body
    // is handed the full range.
    ASSERT_EQ(inner_begin, 0);
    ASSERT_EQ(inner_end, 10);

    // Thread id reflects the inner parallel region.
    ASSERT_EQ(at::get_thread_num(), 0);
  };

  at::parallel_for(0, 10, 1, [&](int64_t /*begin*/, int64_t /*end*/) {
    at::parallel_for(0, 10, 1, check_inner_region);
  });

  at::parallel_for(0, 10, 1, [&](int64_t /*begin*/, int64_t /*end*/) {
    // Thread id + 1 should always be 1 inside the nested reduction.
    const auto thread_id_plus_one =
        [](int64_t /*begin*/, int64_t /*end*/, int /*ident*/) {
          return at::get_thread_num() + 1;
        };
    const auto num_threads =
        at::parallel_reduce(0, 10, 1, 0, thread_id_plus_one, std::plus<>{});
    ASSERT_EQ(num_threads, 1);
  });
}
91
// Checks that a std::runtime_error thrown from the body passed to
// at::parallel_for reaches the caller.
TEST(TestParallel, Exceptions) {
  // Loop body shared by both cases; it throws unconditionally.
  const auto always_throws = [](int64_t /*begin*/, int64_t /*end*/) {
    throw std::runtime_error("exception");
  };

  // parallel case: range [0, 10) with grain size 1
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(at::parallel_for(0, 10, 1, always_throws), std::runtime_error);

  // non-parallel case: range [0, 1) with grain size 1000
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
  ASSERT_THROW(at::parallel_for(0, 1, 1000, always_throws), std::runtime_error);
}
109
// Launches two tasks through at::intraop_launch_future, waits on both returned
// futures, and checks that each task wrote its value.
TEST(TestParallel, IntraOpLaunchFuture) {
  int first_value = 0;
  int second_value = 0;

  auto first_done =
      at::intraop_launch_future([&first_value]() { first_value = 1; });
  auto second_done =
      at::intraop_launch_future([&second_value]() { second_value = 2; });

  // The values are read only after both futures have been waited on.
  first_done->wait();
  second_done->wait();

  const bool both_written = (first_value == 1) && (second_value == 2);
  ASSERT_TRUE(both_written);
}
127