#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <type_traits>
#include <vector>

#include "ruy/context.h"
#include "ruy/matrix.h"
#include "ruy/mul_params.h"
#include "ruy/ruy.h"
11
// Returns true when `value` is exactly one of the entries of the
// comma-separated list `allowed_values` (e.g. "row-major, column-major").
// Spaces around entries are ignored. Partial matches such as "row" and the
// empty string are rejected.
static bool is_allowed_value(const char* allowed_values, const char* value) {
  const std::size_t value_len = std::strlen(value);
  if (value_len == 0) {
    return false;
  }
  const char* entry = allowed_values;
  for (;;) {
    while (*entry == ' ') {
      entry++;
    }
    const char* entry_end = std::strchr(entry, ',');
    if (!entry_end) {
      entry_end = entry + std::strlen(entry);
    }
    const char* trimmed_end = entry_end;
    while (trimmed_end > entry && trimmed_end[-1] == ' ') {
      trimmed_end--;
    }
    const std::size_t entry_len =
        static_cast<std::size_t>(trimmed_end - entry);
    if (entry_len == value_len && std::strncmp(entry, value, entry_len) == 0) {
      return true;
    }
    if (*entry_end == '\0') {
      return false;
    }
    entry = entry_end + 1;
  }
}

// Reads the flag `name` (given on the command line as "<name>=<value>") and
// parses its value with sscanf(`format`) into the `dst` pointers.
//
// - help:           when true, only prints a row of the flag table to stderr
//                   and returns without touching `dst`.
// - default_value:  used when no argument "<name>=..." is present.
// - allowed_values: optional comma-separated list; when non-null, the value
//                   must equal one of its entries exactly.
//
// Exits the process with status 1 on an illegal value or when sscanf does not
// fill every `dst` pointer. The first matching argument wins.
template <typename... Dst>
void read_cmdline_args(bool help, int argc, char* argv[], const char* name,
                       const char* format, const char* default_value,
                       const char* allowed_values, Dst... dst) {
  if (help) {
    fprintf(stderr, "%-20s %-12s %-16s %s\n", name, format, default_value,
            allowed_values ? allowed_values : "");
    return;
  }
  const std::size_t name_len = std::strlen(name);
  const char* value = default_value;
  for (int i = 1; i < argc; i++) {
    // Only "<name>=<value>" is a match. An argument that merely starts with
    // `name` (e.g. "--typesX=...") belongs to another flag and must not stop
    // the search.
    if (std::strncmp(argv[i], name, name_len) == 0 &&
        argv[i][name_len] == '=') {
      value = argv[i] + name_len + 1;
      break;
    }
  }
  if (allowed_values && !is_allowed_value(allowed_values, value)) {
    fprintf(stderr, "Illegal value %s. The legal values are %s.\n", value,
            allowed_values);
    exit(1);
  }
  // sscanf returns the number of assigned conversions (or EOF); every
  // destination must have been written.
  if (static_cast<int>(sizeof...(Dst)) != sscanf(value, format, dst...)) {
    fprintf(stderr, "Failed to parse %s\n", value);
    exit(1);
  }
}
43
// Configuration for one run, filled in by main() from command-line flags.
// The in-class initializers only guarantee a defined (zero) state before
// parsing; they are not the command-line defaults.
struct Params {
  // Element types as "<lhs>x<rhs>-><dst>", e.g. "f32xf32->f32".
  char types[100]{};
  // Matmul shape m*k*n: LHS is m x k, RHS is k x n, destination is m x n.
  int m{};
  int k{};
  int n{};
  // Bitmask passed to ruy::Context::set_runtime_enabled_paths (parsed as hex).
  int paths{};
  // Passed to ruy::Context::set_max_num_threads.
  int num_threads{};
  // Number of times ruy::Mul is invoked.
  int repeat{};
  // Cast to ruy::CachePolicy for the LHS / RHS matrices.
  int lhs_cache_policy{};
  int rhs_cache_policy{};
  // Strides; 0 selects the tightly packed stride for the matrix's order.
  int lhs_stride{};
  int rhs_stride{};
  int dst_stride{};
  // Zero points handed to Matrix::set_zero_point.
  int lhs_zero_point{};
  int rhs_zero_point{};
  int dst_zero_point{};
  // Storage orders: "row-major" or "column-major".
  char lhs_order[100]{};
  char rhs_order[100]{};
  char dst_order[100]{};
};
62
63 template <typename LhsType, typename RhsType, typename DstType>
run(const Params & params)64 void run(const Params& params) {
65 using AccumType =
66 typename std::conditional<std::is_floating_point<DstType>::value, DstType,
67 std::int32_t>::type;
68
69 ruy::Matrix<LhsType> lhs;
70 ruy::Matrix<RhsType> rhs;
71 ruy::Matrix<DstType> dst;
72
73 auto parse_order = [](const char* name) {
74 if (!std::strcmp(name, "row-major")) {
75 return ruy::Order::kRowMajor;
76 } else if (!std::strcmp(name, "column-major")) {
77 return ruy::Order::kColMajor;
78 } else {
79 fprintf(stderr, "Failed to parse %s\n", name);
80 exit(1);
81 }
82 };
83
84 auto make_layout = [](int rows, int cols, int stride, ruy::Order order,
85 ruy::Layout* layout) {
86 layout->set_rows(rows);
87 layout->set_cols(cols);
88 layout->set_order(order);
89 int base_stride = order == ruy::Order::kRowMajor ? cols : rows;
90 layout->set_stride(stride ? stride : base_stride);
91 };
92
93 make_layout(params.m, params.k, params.lhs_stride,
94 parse_order(params.lhs_order), lhs.mutable_layout());
95 make_layout(params.k, params.n, params.rhs_stride,
96 parse_order(params.rhs_order), rhs.mutable_layout());
97 make_layout(params.m, params.n, params.dst_stride,
98 parse_order(params.dst_order), dst.mutable_layout());
99
100 lhs.set_zero_point(params.lhs_zero_point);
101 rhs.set_zero_point(params.rhs_zero_point);
102 dst.set_zero_point(params.dst_zero_point);
103
104 lhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.lhs_cache_policy));
105 rhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.rhs_cache_policy));
106
107 auto flat_size = [](const ruy::Layout& layout) {
108 int outer_size =
109 layout.order() == ruy::Order::kRowMajor ? layout.rows() : layout.cols();
110 return outer_size * layout.stride();
111 };
112
113 std::vector<LhsType> lhs_buf(flat_size(lhs.layout()));
114 std::vector<RhsType> rhs_buf(flat_size(rhs.layout()));
115 std::vector<DstType> dst_buf(flat_size(dst.layout()));
116
117 lhs.set_data(lhs_buf.data());
118 rhs.set_data(rhs_buf.data());
119 dst.set_data(dst_buf.data());
120
121 ruy::Context context;
122 context.set_max_num_threads(params.num_threads);
123 context.set_runtime_enabled_paths(static_cast<ruy::Path>(params.paths));
124
125 ruy::MulParams<AccumType, DstType> mul_params;
126 // Here an actual application might set some mul_params fields.
127 // Quantization multipliers, bias-vector, clamp bounds, etc.
128
129 for (int r = 0; r < params.repeat; r++) {
130 ruy::Mul(lhs, rhs, mul_params, &context, &dst);
131 }
132 }
133
main(int argc,char * argv[])134 int main(int argc, char* argv[]) {
135 bool help = argc == 1 || (argc == 2 && !strcmp(argv[1], "--help"));
136 if (help) {
137 fprintf(stderr, "Command-line flags (all in the form --flag=value):\n");
138 fprintf(stderr, "%-20s %-12s %-16s %s\n", "flag", "format", "default",
139 "allowed");
140 }
141 Params params;
142 const char* allowed_types =
143 "f32xf32->f32, i8xi8->i8, i8xi8->i16, i8xi8->i32, u8xu8->i16, u8xi8->u8";
144 const char* allowed_orders = "row-major, column-major";
145 read_cmdline_args(help, argc, argv, "--types", "%s", "f32xf32->f32",
146 allowed_types, ¶ms.types);
147 read_cmdline_args(help, argc, argv, "--shape", "%dx%dx%d", "100x100x100",
148 nullptr, ¶ms.m, ¶ms.k, ¶ms.n);
149 read_cmdline_args(help, argc, argv, "--paths", "%x", "0", nullptr,
150 ¶ms.paths);
151 read_cmdline_args(help, argc, argv, "--num_threads", "%d", "1", nullptr,
152 ¶ms.num_threads);
153 read_cmdline_args(help, argc, argv, "--repeat", "%d", "1", nullptr,
154 ¶ms.repeat);
155 read_cmdline_args(help, argc, argv, "--lhs_cache_policy", "%d", "0",
156 "0, 1, 2, 3", ¶ms.lhs_cache_policy);
157 read_cmdline_args(help, argc, argv, "--rhs_cache_policy", "%d", "0",
158 "0, 1, 2, 3", ¶ms.rhs_cache_policy);
159 read_cmdline_args(help, argc, argv, "--lhs_stride", "%d", "0", nullptr,
160 ¶ms.lhs_stride);
161 read_cmdline_args(help, argc, argv, "--rhs_stride", "%d", "0", nullptr,
162 ¶ms.rhs_stride);
163 read_cmdline_args(help, argc, argv, "--dst_stride", "%d", "0", nullptr,
164 ¶ms.dst_stride);
165 read_cmdline_args(help, argc, argv, "--lhs_zero_point", "%d", "0", nullptr,
166 ¶ms.lhs_zero_point);
167 read_cmdline_args(help, argc, argv, "--rhs_zero_point", "%d", "0", nullptr,
168 ¶ms.rhs_zero_point);
169 read_cmdline_args(help, argc, argv, "--dst_zero_point", "%d", "0", nullptr,
170 ¶ms.dst_zero_point);
171 read_cmdline_args(help, argc, argv, "--lhs_order", "%s", "row-major",
172 allowed_orders, ¶ms.lhs_order);
173 read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
174 allowed_orders, ¶ms.rhs_order);
175 read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
176 allowed_orders, ¶ms.dst_order);
177
178 if (help) {
179 exit(1);
180 }
181
182 if (!strcmp(params.types, "f32xf32->f32")) {
183 run<float, float, float>(params);
184 } else if (!strcmp(params.types, "i8xi8->i8")) {
185 run<std::int8_t, std::int8_t, std::int8_t>(params);
186 } else if (!strcmp(params.types, "i8xi8->i16")) {
187 run<std::int8_t, std::int8_t, std::int16_t>(params);
188 } else if (!strcmp(params.types, "i8xi8->i32")) {
189 run<std::int8_t, std::int8_t, std::int32_t>(params);
190 } else if (!strcmp(params.types, "u8xu8->i16")) {
191 run<std::uint8_t, std::uint8_t, std::int16_t>(params);
192 } else if (!strcmp(params.types, "u8xi8->u8")) {
193 run<std::uint8_t, std::int8_t, std::uint8_t>(params);
194 } else {
195 fprintf(stderr, "Unknown types: %s\n", params.types);
196 exit(1);
197 }
198 }
199