1 #pragma once
2 
3 #include <gtest/gtest.h>
4 #include <gmock/gmock.h>
5 
6 #include <ATen/core/Tensor.h>
7 #include <ATen/core/dispatch/Dispatcher.h>
8 #include <ATen/core/ivalue.h>
9 #include <c10/core/CPUAllocator.h>
10 #include <c10/util/irange.h>
11 
12 template<class... Inputs>
makeStack(Inputs &&...inputs)13 inline std::vector<c10::IValue> makeStack(Inputs&&... inputs) {
14   return {std::forward<Inputs>(inputs)...};
15 }
16 
17 inline at::Tensor dummyTensor(c10::DispatchKeySet ks, bool requires_grad=false) {
18   auto* allocator = c10::GetCPUAllocator();
19   int64_t nelements = 1;
20   auto dtype = caffe2::TypeMeta::Make<float>();
21   int64_t size_bytes = nelements * dtype.itemsize();
22   auto storage_impl = c10::make_intrusive<c10::StorageImpl>(
23       c10::StorageImpl::use_byte_size_t(),
24       size_bytes,
25       allocator->allocate(size_bytes),
26       allocator,
27       /*resizable=*/true);
28   at::Tensor t = at::detail::make_tensor<c10::TensorImpl>(storage_impl, ks, dtype);
29   // TODO: We add this to simulate the ideal case where we only have Autograd backend keys
30   //       on Tensor when it requires grad. But currently Autograd keys are added in TensorImpl
31   //       constructor by default.
32   if (!requires_grad) {
33     t.unsafeGetTensorImpl()->remove_autograd_key();
34   }
35   return t;
36 }
37 
38 inline at::Tensor dummyTensor(c10::DispatchKey dispatch_key, bool requires_grad=false) {
39   return dummyTensor(c10::DispatchKeySet(dispatch_key), requires_grad);
40 }
41 
42 template<class... Args>
callOp(const c10::OperatorHandle & op,Args...args)43 inline std::vector<c10::IValue> callOp(const c10::OperatorHandle& op, Args... args) {
44   auto stack = makeStack(std::forward<Args>(args)...);
45   op.callBoxed(&stack);
46   return stack;
47 }
48 
49 template<class Result, class... Args>
callOpUnboxed(const c10::OperatorHandle & op,Args...args)50 inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) {
51   return op.typed<Result(Args...)>().call(std::forward<Args>(args)...);
52 }
53 
54 template<class Result, class... Args>
callOpUnboxedWithDispatchKey(const c10::OperatorHandle & op,c10::DispatchKey dispatchKey,Args...args)55 inline Result callOpUnboxedWithDispatchKey(const c10::OperatorHandle& op, c10::DispatchKey dispatchKey, Args... args) {
56   return op.typed<Result(Args...)>().callWithDispatchKey(dispatchKey, std::forward<Args>(args)...);
57 }
58 
59 template<class Result, class... Args>
callOpUnboxedWithPrecomputedDispatchKeySet(const c10::OperatorHandle & op,c10::DispatchKeySet ks,Args...args)60 inline Result callOpUnboxedWithPrecomputedDispatchKeySet(const c10::OperatorHandle& op, c10::DispatchKeySet ks, Args... args) {
61   return op.typed<Result(Args...)>().redispatch(ks, std::forward<Args>(args)...);
62 }
63 
expectDoesntFindKernel(const char * op_name,c10::DispatchKey dispatch_key)64 inline void expectDoesntFindKernel(const char* op_name, c10::DispatchKey dispatch_key) {
65   auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
66   EXPECT_ANY_THROW(
67     callOp(*op, dummyTensor(dispatch_key), 5);
68   );
69 }
70 
expectDoesntFindOperator(const char * op_name)71 inline void expectDoesntFindOperator(const char* op_name) {
72   auto op = c10::Dispatcher::singleton().findSchema({op_name, ""});
73   EXPECT_FALSE(op.has_value());
74 }
75 
76 template<class Exception, class Functor>
expectThrows(Functor && functor,const char * expectMessageContains)77 inline void expectThrows(Functor&& functor, const char* expectMessageContains) {
78   try {
79     std::forward<Functor>(functor)();
80   } catch (const Exception& e) {
81     EXPECT_THAT(e.what(), testing::HasSubstr(expectMessageContains));
82     return;
83   }
84   ADD_FAILURE() << "Expected to throw exception containing \""
85     << expectMessageContains << "\" but didn't throw";
86 }
87 
88 template<class T, size_t N>
expectListEquals(c10::ArrayRef<T> expected,std::array<T,N> actual)89 void expectListEquals(c10::ArrayRef<T> expected, std::array<T, N> actual) {
90   EXPECT_EQ(expected.size(), actual.size());
91   for (const auto i : c10::irange(expected.size())) {
92     EXPECT_EQ(expected[i], actual[i]);
93   }
94 }
95 
96 template<class T>
expectListEquals(c10::ArrayRef<T> expected,c10::ArrayRef<T> actual)97 void expectListEquals(c10::ArrayRef<T> expected, c10::ArrayRef<T> actual) {
98   EXPECT_EQ(expected.size(), actual.size());
99   for (const auto i : c10::irange(expected.size())) {
100     EXPECT_EQ(expected[i], actual[i]);
101   }
102 }
103 
104 template<class T>
expectListEquals(c10::ArrayRef<T> expected,c10::List<T> actual)105 void expectListEquals(c10::ArrayRef<T> expected, c10::List<T> actual) {
106   EXPECT_EQ(expected.size(), actual.size());
107   for (const auto i : c10::irange(expected.size())) {
108     EXPECT_EQ(expected[i], actual.get(i));
109   }
110 }
111 
112 template<class T>
expectListEquals(c10::ArrayRef<T> expected,std::vector<T> actual)113 void expectListEquals(c10::ArrayRef<T> expected, std::vector<T> actual) {
114   EXPECT_EQ(expected.size(), actual.size());
115   for (const auto i : c10::irange(expected.size())) {
116     EXPECT_EQ(expected[i], actual[i]);
117   }
118 }
119 
120 // NB: This is not really sound, but all of the type sets constructed here
121 // are singletons so it's fine
extractDispatchKey(const at::Tensor & t)122 static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) {
123   return legacyExtractDispatchKey(t.key_set());
124 }
125