1
2 #include <torch/torch.h>
3 #include <ATen/record_function.h>
4
5 #include "c10/util/Flags.h"
6
7 #include <chrono>
8 #include <iostream>
9 #include <ctime>
10
11 C10_DEFINE_int(iter, 10000, "Number of iterations");
12 C10_DEFINE_int(sampled_iter, 10e6,
13 "Number of iterations for the sampled observer benchmark");
14
namespace {

// Side length of the larger square matrix fed to the GEMM benchmark.
constexpr int kTensorSize = 16;

// Side length of the smaller (1x1) matrix fed to the GEMM benchmark.
constexpr int kSmallTensorSize = 1;

// Sampling probability passed to addTestCallback() for the sampled observers.
constexpr float kLowSamplingProb = 0.0001;

} // namespace
20
addTestCallback(double sampling_prob=1.0,at::RecordFunctionCallback::StartCallback fn=[](const at::RecordFunction &)->std::unique_ptr<at::ObserverContext>{})21 void addTestCallback(
22 double sampling_prob = 1.0,
23 at::RecordFunctionCallback::StartCallback fn =
24 [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
25 auto cb = at::RecordFunctionCallback(
26 fn,
__anon4c64e09d0302(const at::RecordFunction&, at::ObserverContext*) 27 [](const at::RecordFunction&, at::ObserverContext*) {})
28 .needsInputs(false);
29 if (sampling_prob < 1.0) {
30 cb.samplingProb(sampling_prob);
31 }
32 at::addGlobalCallback(cb);
33 }
34
runTensorGEMMBench(int tensor_size,int iter)35 float runTensorGEMMBench(int tensor_size, int iter) {
36 typedef std::chrono::high_resolution_clock clock;
37 typedef std::chrono::microseconds us;
38 std::chrono::time_point<clock> start_time = clock::now();
39 auto inp = torch::randn({tensor_size, tensor_size});
40 for (auto idx = 0; idx < iter; ++idx) {
41 torch::mm(inp, inp);
42 }
43 auto duration = static_cast<float>(
44 std::chrono::duration_cast<us>(clock::now() - start_time).count());
45 return duration;
46 }
47
runPureRecordFunctionBench(int iter)48 float runPureRecordFunctionBench(int iter) {
49 typedef std::chrono::high_resolution_clock clock;
50 typedef std::chrono::microseconds us;
51 std::chrono::time_point<clock> start_time = clock::now();
52 for (auto idx = 0; idx < iter; ++idx) {
53 auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
54 if (step_callbacks.has_value()) {
55 at::RecordFunction guard(std::move(*step_callbacks));
56 guard.before("Test", -1);
57 }
58 }
59 auto duration = static_cast<float>(
60 std::chrono::duration_cast<us>(clock::now() - start_time).count());
61 return duration;
62 }
63
runBenchmark()64 void runBenchmark() {
65 float duration = 0;
66 for (auto tensor_size : std::set<int>({kSmallTensorSize, kTensorSize})) {
67 duration = runTensorGEMMBench(tensor_size, FLAGS_iter);
68 std::cout << "Tensor GEMM benchmark ("
69 << tensor_size
70 << "x"
71 << tensor_size
72 << ", " << FLAGS_iter << "): " << duration
73 << " us." << std::endl;
74 }
75 duration = runPureRecordFunctionBench(FLAGS_iter);
76 std::cout << "Pure RecordFunction benchmark ("
77 << FLAGS_iter << "): "
78 << duration
79 << " us." << std::endl;
80 }
81
main(int argc,char ** argv)82 int main(int argc, char** argv) {
83 if (!c10::ParseCommandLineFlags(&argc, &argv)) {
84 std::cout << "Failed to parse command line flags" << std::endl;
85 return -1;
86 }
87
88 at::enableRecordFunction();
89 at::clearCallbacks();
90
91 std::cout << "Warm up" << std::endl;
92 runBenchmark();
93
94 std::cout << "Running without observers" << std::endl;
95 runBenchmark();
96
97 addTestCallback();
98 std::cout << "Running with empty non-sampled observer" << std::endl;
99 runBenchmark();
100 at::clearCallbacks();
101
102 addTestCallback(kLowSamplingProb);
103 std::cout << "Running with empty sampled observer" << std::endl;
104 runBenchmark();
105 at::clearCallbacks();
106
107 std::cout << "Checking number of sampled observer invocations" << std::endl;
108 static int cb_count = 0;
109 addTestCallback(
110 kLowSamplingProb,
111 [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
112 ++cb_count;
113 return nullptr;
114 }
115 );
116
117 auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter);
118
119 std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter
120 << " iterations: " << duration
121 << " us, number of callback invocations: " << cb_count
122 << ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb)
123 << " invocations" << std::endl;
124
125 at::clearCallbacks();
126 return 0;
127 }
128