#include <benchmark/benchmark.h>
#include <torch/csrc/jit/runtime/static/impl.h>
#include "deep_wide_pt.h"

const int embedding_size = 32;
const int num_features = 50;

using namespace torch;

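// Hand-written eager baselines: DeepAndWide is the reference implementation
// of the model, and DeepAndWideFast appears to be a hand-optimized variant
// of it (both are defined in deep_wide_pt.h).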
static void BM_deep_wide_base(benchmark::State& state) {
  std::shared_ptr<DeepAndWide> net =
      std::make_shared<DeepAndWide>(num_features);

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});
  // warmup
  net->forward(ad_emb_packed, user_emb, wide);
  for (auto _ : state) {
    net->forward(ad_emb_packed, user_emb, wide);
  }
}

static void BM_deep_wide_fast(benchmark::State& state) {
  std::shared_ptr<DeepAndWideFast> net =
      std::make_shared<DeepAndWideFast>(num_features);

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});
  // warmup
  net->forward(ad_emb_packed, user_emb, wide);
  for (auto _ : state) {
    net->forward(ad_emb_packed, user_emb, wide);
  }
}

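// The next two benchmarks run the scripted model through the stock
// TorchScript executors: setting TORCH_JIT_DISABLE_NEW_EXECUTOR selects the
// legacy graph executor, and unsetting it restores the default profiling
// executor.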
static void BM_deep_wide_jit_graph_executor(benchmark::State& state) {
  auto mod = getDeepAndWideSciptModel();

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});

  TORCH_CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0);

  // warmup
  mod.forward(inputs);
  for (auto _ : state) {
    mod.forward(inputs);
  }
}

static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) {
  auto mod = getDeepAndWideSciptModel();

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});

  TORCH_CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0);

  // warmup
  mod.forward(inputs);
  for (auto _ : state) {
    mod.forward(inputs);
  }
}

static void BM_deep_wide_static(benchmark::State& state) {
  auto mod = getDeepAndWideSciptModel();
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});

  // warmup
  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

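// A single StaticModule is shared by all threads; each thread in
// BM_deep_wide_static_threaded (registered below with ->Threads(8)) then
// wraps it in its own StaticRuntime instance.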
std::shared_ptr<torch::jit::StaticModule> getStaticModule() {
  static auto smod =
      std::make_shared<torch::jit::StaticModule>(getDeepAndWideSciptModel());
  return smod;
}

static void BM_deep_wide_static_threaded(benchmark::State& state) {
  auto sm = getStaticModule();
  torch::jit::StaticRuntime sr(*sm);

  const int batch_size = 1; // state.range(0);
  auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
  auto user_emb = torch::randn({batch_size, 1, embedding_size});
  auto wide = torch::randn({batch_size, num_features});

  std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});

  // warmup
  sr(inputs, {});
  for (auto _ : state) {
    sr(inputs, {});
  }
}

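// Two leaky_relu variants: the "const" model presumably bakes the negative
// slope into the graph as a constant, while the other takes it as a runtime
// input (note the extra neg_slope element in `inputs`).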
static void BM_leaky_relu_const(benchmark::State& state) {
  auto mod = getLeakyReLUConstScriptModel();
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto data = torch::randn({batch_size, num_features});
  std::vector<c10::IValue> inputs({data});

  // warmup
  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

static void BM_leaky_relu(benchmark::State& state) {
  auto mod = getLeakyReLUScriptModel();
  torch::jit::StaticModule smod(mod);

  const int batch_size = state.range(0);
  auto neg_slope = torch::randn(1);
  auto data = torch::randn({batch_size, num_features});
  std::vector<c10::IValue> inputs({data, neg_slope[0]});

  // warmup
  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

BENCHMARK(BM_leaky_relu)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_leaky_relu_const)->RangeMultiplier(8)->Ranges({{1, 20}});

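// Presumably computes sign(x) * log1p(|x|), per the model name in
// deep_wide_pt.h; parameterized by element count rather than batch size.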
static void BM_signed_log1p(benchmark::State& state) {
  auto mod = getSignedLog1pModel();
  torch::jit::StaticModule smod(mod);

  const int num_elements = state.range(0);
  auto data = torch::randn({num_elements});
  std::vector<c10::IValue> inputs({data});

  // warmup
  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

BENCHMARK(BM_signed_log1p)->RangeMultiplier(8)->Ranges({{16, 65536}});

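// The second benchmark argument toggles StaticModuleOptions::optimize_memory
// (Static Runtime's memory planning), so each matrix size N below is
// measured both with and without it.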
static void BM_long_static_memory_optimization(benchmark::State& state) {
  auto mod = getLongScriptModel();
  torch::jit::StaticModuleOptions opts;
  opts.optimize_memory = state.range(1);
  torch::jit::StaticModule smod(mod, false, opts);

  const auto N = state.range(0);
  auto a = torch::randn({N, N});
  auto b = torch::randn({N, N});
  auto c = torch::randn({N, N});
  std::vector<c10::IValue> inputs({a, b, c});

  // warmup
  smod(inputs, {});
  for (auto _ : state) {
    smod(inputs, {});
  }
}

BENCHMARK(BM_deep_wide_base)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_deep_wide_fast)->RangeMultiplier(8)->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_jit_graph_executor)
    ->RangeMultiplier(8)
    ->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_jit_profiling_executor)
    ->RangeMultiplier(8)
    ->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_static)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_deep_wide_static_threaded)->Threads(8);

BENCHMARK(BM_long_static_memory_optimization)
    ->Args({2 << 0, 0})
    ->Args({2 << 2, 0})
    ->Args({2 << 4, 0})
    ->Args({2 << 8, 0})
    ->Args({2 << 0, 1})
    ->Args({2 << 2, 1})
    ->Args({2 << 4, 1})
    ->Args({2 << 8, 1});

int main(int argc, char** argv) {
  c10::ParseCommandLineFlags(&argc, &argv);
  ::benchmark::Initialize(&argc, argv);
  ::benchmark::RunSpecifiedBenchmarks();
}
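
// Example invocation (binary name assumed here; --benchmark_filter is a
// standard Google Benchmark flag):
//   ./deep_wide_pt_bench --benchmark_filter=BM_deep_wide_static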