• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 
2 #include <torch/torch.h>
3 #include <ATen/record_function.h>
4 
5 #include "c10/util/Flags.h"
6 
7 #include <chrono>
8 #include <iostream>
9 #include <ctime>
10 
11 C10_DEFINE_int(iter, 10000, "Number of iterations");
12 C10_DEFINE_int(sampled_iter, 10e6,
13     "Number of iterations for the sampled observer benchmark");
14 
namespace {
// Side length of the larger square matrix used by the GEMM benchmark.
constexpr int kTensorSize = 16;
// Side length of the smaller square matrix used by the GEMM benchmark.
constexpr int kSmallTensorSize = 1;
// Sampling probability given to the sampled observers registered in main().
constexpr float kLowSamplingProb = 0.0001f;
} // namespace
20 
addTestCallback(double sampling_prob=1.0,at::RecordFunctionCallback::StartCallback fn=[](const at::RecordFunction &)->std::unique_ptr<at::ObserverContext>{})21 void addTestCallback(
22     double sampling_prob = 1.0,
23     at::RecordFunctionCallback::StartCallback fn =
24         [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
25   auto cb = at::RecordFunctionCallback(
26       fn,
__anon4c64e09d0302(const at::RecordFunction&, at::ObserverContext*) 27       [](const at::RecordFunction&, at::ObserverContext*) {})
28     .needsInputs(false);
29   if (sampling_prob < 1.0) {
30     cb.samplingProb(sampling_prob);
31   }
32   at::addGlobalCallback(cb);
33 }
34 
runTensorGEMMBench(int tensor_size,int iter)35 float runTensorGEMMBench(int tensor_size, int iter) {
36   typedef std::chrono::high_resolution_clock clock;
37   typedef std::chrono::microseconds us;
38   std::chrono::time_point<clock> start_time = clock::now();
39   auto inp = torch::randn({tensor_size, tensor_size});
40   for (auto idx = 0; idx < iter; ++idx) {
41     torch::mm(inp, inp);
42   }
43   auto duration = static_cast<float>(
44       std::chrono::duration_cast<us>(clock::now() - start_time).count());
45   return duration;
46 }
47 
runPureRecordFunctionBench(int iter)48 float runPureRecordFunctionBench(int iter) {
49   typedef std::chrono::high_resolution_clock clock;
50   typedef std::chrono::microseconds us;
51   std::chrono::time_point<clock> start_time = clock::now();
52   for (auto idx = 0; idx < iter; ++idx) {
53     auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
54     if (step_callbacks.has_value()) {
55       at::RecordFunction guard(std::move(*step_callbacks));
56       guard.before("Test", -1);
57     }
58   }
59   auto duration = static_cast<float>(
60       std::chrono::duration_cast<us>(clock::now() - start_time).count());
61   return duration;
62 }
63 
runBenchmark()64 void runBenchmark() {
65   float duration = 0;
66   for (auto tensor_size : std::set<int>({kSmallTensorSize, kTensorSize})) {
67     duration = runTensorGEMMBench(tensor_size, FLAGS_iter);
68     std::cout << "Tensor GEMM benchmark ("
69               << tensor_size
70               << "x"
71               << tensor_size
72               << ", " << FLAGS_iter << "): " << duration
73               << " us." << std::endl;
74   }
75   duration = runPureRecordFunctionBench(FLAGS_iter);
76   std::cout << "Pure RecordFunction benchmark ("
77             << FLAGS_iter << "): "
78             << duration
79             << " us." << std::endl;
80 }
81 
main(int argc,char ** argv)82 int main(int argc, char** argv) {
83   if (!c10::ParseCommandLineFlags(&argc, &argv)) {
84     std::cout << "Failed to parse command line flags" << std::endl;
85     return -1;
86   }
87 
88   at::enableRecordFunction();
89   at::clearCallbacks();
90 
91   std::cout << "Warm up" << std::endl;
92   runBenchmark();
93 
94   std::cout << "Running without observers" << std::endl;
95   runBenchmark();
96 
97   addTestCallback();
98   std::cout << "Running with empty non-sampled observer" << std::endl;
99   runBenchmark();
100   at::clearCallbacks();
101 
102   addTestCallback(kLowSamplingProb);
103   std::cout << "Running with empty sampled observer" << std::endl;
104   runBenchmark();
105   at::clearCallbacks();
106 
107   std::cout << "Checking number of sampled observer invocations" << std::endl;
108   static int cb_count = 0;
109   addTestCallback(
110       kLowSamplingProb,
111       [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
112         ++cb_count;
113         return nullptr;
114       }
115   );
116 
117   auto duration = runPureRecordFunctionBench(FLAGS_sampled_iter);
118 
119   std::cout << "Pure RecordFunction runtime of " << FLAGS_sampled_iter
120             << " iterations: " << duration
121             << " us, number of callback invocations: " << cb_count
122             << ", expected number: ~" << (int)(FLAGS_sampled_iter * kLowSamplingProb)
123             << " invocations" << std::endl;
124 
125   at::clearCallbacks();
126   return 0;
127 }
128