1 #include <benchmark/benchmark.h>
2 #include <torch/csrc/jit/runtime/static/impl.h>
3 #include "deep_wide_pt.h"
4
// Model dimensions shared by every benchmark in this file: each embedding
// vector is 32 wide and the wide branch consumes 50 dense features.
const int embedding_size = 32;
const int num_features = 50;

using namespace torch;
9
BM_deep_wide_base(benchmark::State & state)10 static void BM_deep_wide_base(benchmark::State& state) {
11 std::shared_ptr<DeepAndWide> net =
12 std::make_shared<DeepAndWide>(num_features);
13
14 const int batch_size = state.range(0);
15 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
16 auto user_emb = torch::randn({batch_size, 1, embedding_size});
17 auto wide = torch::randn({batch_size, num_features});
18 // warmup
19 net->forward(ad_emb_packed, user_emb, wide);
20 for (auto _ : state) {
21 net->forward(ad_emb_packed, user_emb, wide);
22 }
23 }
24
BM_deep_wide_fast(benchmark::State & state)25 static void BM_deep_wide_fast(benchmark::State& state) {
26 std::shared_ptr<DeepAndWideFast> net =
27 std::make_shared<DeepAndWideFast>(num_features);
28
29 const int batch_size = state.range(0);
30 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
31 auto user_emb = torch::randn({batch_size, 1, embedding_size});
32 auto wide = torch::randn({batch_size, num_features});
33 // warmup
34 net->forward(ad_emb_packed, user_emb, wide);
35 for (auto _ : state) {
36 net->forward(ad_emb_packed, user_emb, wide);
37 }
38 }
39
BM_deep_wide_jit_graph_executor(benchmark::State & state)40 static void BM_deep_wide_jit_graph_executor(benchmark::State& state) {
41 auto mod = getDeepAndWideSciptModel();
42
43 const int batch_size = state.range(0);
44 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
45 auto user_emb = torch::randn({batch_size, 1, embedding_size});
46 auto wide = torch::randn({batch_size, num_features});
47
48 std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});
49
50 TORCH_CHECK_EQ(setenv("TORCH_JIT_DISABLE_NEW_EXECUTOR", "1", 1), 0);
51
52 mod.forward(inputs);
53 for (auto _ : state) {
54 mod.forward(inputs);
55 }
56 }
57
BM_deep_wide_jit_profiling_executor(benchmark::State & state)58 static void BM_deep_wide_jit_profiling_executor(benchmark::State& state) {
59 auto mod = getDeepAndWideSciptModel();
60
61 const int batch_size = state.range(0);
62 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
63 auto user_emb = torch::randn({batch_size, 1, embedding_size});
64 auto wide = torch::randn({batch_size, num_features});
65
66 std::vector<IValue> inputs({ad_emb_packed, user_emb, wide});
67
68 TORCH_CHECK_EQ(unsetenv("TORCH_JIT_DISABLE_NEW_EXECUTOR"), 0);
69
70 mod.forward(inputs);
71 for (auto _ : state) {
72 mod.forward(inputs);
73 }
74 }
75
BM_deep_wide_static(benchmark::State & state)76 static void BM_deep_wide_static(benchmark::State& state) {
77 auto mod = getDeepAndWideSciptModel();
78 torch::jit::StaticModule smod(mod);
79
80 const int batch_size = state.range(0);
81 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
82 auto user_emb = torch::randn({batch_size, 1, embedding_size});
83 auto wide = torch::randn({batch_size, num_features});
84
85 std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});
86
87 smod(inputs, {});
88 for (auto _ : state) {
89 smod(inputs, {});
90 }
91 }
92
getStaticModule()93 std::shared_ptr<torch::jit::StaticModule> getStaticModule() {
94 static auto smod =
95 std::make_shared<torch::jit::StaticModule>(getDeepAndWideSciptModel());
96 return smod;
97 }
98
BM_deep_wide_static_threaded(benchmark::State & state)99 static void BM_deep_wide_static_threaded(benchmark::State& state) {
100 auto sm = getStaticModule();
101 torch::jit::StaticRuntime sr(*sm);
102
103 const int batch_size = 1; // state.range(0);
104 auto ad_emb_packed = torch::randn({batch_size, 1, embedding_size});
105 auto user_emb = torch::randn({batch_size, 1, embedding_size});
106 auto wide = torch::randn({batch_size, num_features});
107
108 std::vector<c10::IValue> inputs({ad_emb_packed, user_emb, wide});
109
110 sr(inputs, {});
111 for (auto _ : state) {
112 sr(inputs, {});
113 }
114 }
115
BM_leaky_relu_const(benchmark::State & state)116 static void BM_leaky_relu_const(benchmark::State& state) {
117 auto mod = getLeakyReLUConstScriptModel();
118 torch::jit::StaticModule smod(mod);
119
120 const int batch_size = state.range(0);
121 auto data = torch::randn({batch_size, num_features});
122 std::vector<c10::IValue> inputs({data});
123
124 smod(inputs, {});
125 for (auto _ : state) {
126 smod(inputs, {});
127 }
128 }
129
BM_leaky_relu(benchmark::State & state)130 static void BM_leaky_relu(benchmark::State& state) {
131 auto mod = getLeakyReLUScriptModel();
132 torch::jit::StaticModule smod(mod);
133
134 const int batch_size = state.range(0);
135 auto neg_slope = torch::randn(1);
136 auto data = torch::randn({batch_size, num_features});
137 std::vector<c10::IValue> inputs({data, neg_slope[0]});
138
139 smod(inputs, {});
140 for (auto _ : state) {
141 smod(inputs, {});
142 }
143 }
144
// LeakyReLU benchmarks: sweep batch sizes from 1 to 20 (multiplier 8).
BENCHMARK(BM_leaky_relu)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_leaky_relu_const)->RangeMultiplier(8)->Ranges({{1, 20}});
147
BM_signed_log1p(benchmark::State & state)148 static void BM_signed_log1p(benchmark::State& state) {
149 auto mod = getSignedLog1pModel();
150 torch::jit::StaticModule smod(mod);
151
152 const int num_elements = state.range(0);
153 auto data = torch::randn({num_elements});
154 std::vector<c10::IValue> inputs({data});
155
156 smod(inputs, {});
157 for (auto _ : state) {
158 smod(inputs, {});
159 }
160 }
161
162 BENCHMARK(BM_signed_log1p)->RangeMultiplier(8)->Ranges({{16, 65536}});
163
BM_long_static_memory_optimization(benchmark::State & state)164 static void BM_long_static_memory_optimization(benchmark::State& state) {
165 auto mod = getLongScriptModel();
166 torch::jit::StaticModuleOptions opts;
167 opts.optimize_memory = state.range(1);
168 torch::jit::StaticModule smod(mod, false, opts);
169
170 const auto N = state.range(0);
171 auto a = torch::randn({N, N});
172 auto b = torch::randn({N, N});
173 auto c = torch::randn({N, N});
174 std::vector<c10::IValue> inputs({a, b, c});
175
176 smod(inputs, {});
177 for (auto _ : state) {
178 smod(inputs, {});
179 }
180 }
181
// Deep&Wide benchmarks: sweep batch sizes from 1 to 20 (multiplier 8).
BENCHMARK(BM_deep_wide_base)->RangeMultiplier(8)->Ranges({{1, 20}});
BENCHMARK(BM_deep_wide_fast)->RangeMultiplier(8)->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_jit_graph_executor)
    ->RangeMultiplier(8)
    ->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_jit_profiling_executor)
    ->RangeMultiplier(8)
    ->Ranges({{1, 20}});

BENCHMARK(BM_deep_wide_static)->RangeMultiplier(8)->Ranges({{1, 20}});
// Fixed batch size (1); measures behavior across 8 concurrent threads.
BENCHMARK(BM_deep_wide_static_threaded)->Threads(8);

// Args are {N, optimize_memory}: each matrix size N (2, 8, 32, 512) runs
// once with memory optimization off (0) and once with it on (1).
BENCHMARK(BM_long_static_memory_optimization)
    ->Args({2 << 0, 0})
    ->Args({2 << 2, 0})
    ->Args({2 << 4, 0})
    ->Args({2 << 8, 0})
    ->Args({2 << 0, 1})
    ->Args({2 << 2, 1})
    ->Args({2 << 4, 1})
    ->Args({2 << 8, 1});
205
main(int argc,char ** argv)206 int main(int argc, char** argv) {
207 c10::ParseCommandLineFlags(&argc, &argv);
208 ::benchmark::Initialize(&argc, argv);
209 ::benchmark::RunSpecifiedBenchmarks();
210 }
211