• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 #include <cstdint>
2 #include <cstdio>
3 #include <cstdlib>
4 #include <cstring>
5 #include <type_traits>
6 
7 #include "ruy/context.h"
8 #include "ruy/matrix.h"
9 #include "ruy/mul_params.h"
10 #include "ruy/ruy.h"
11 
// Returns true when `value` is exactly one of the entries of `allowed_values`,
// a comma-separated list whose entries may be padded with spaces
// (e.g. "row-major, column-major"). An empty `value` is never accepted.
static bool is_allowed_value(const char* allowed_values, const char* value) {
  const std::size_t value_len = std::strlen(value);
  if (value_len == 0) {
    return false;
  }
  const char* entry = allowed_values;
  while (true) {
    // Skip the padding that follows each comma.
    while (*entry == ' ') {
      ++entry;
    }
    const char* comma = std::strchr(entry, ',');
    std::size_t entry_len =
        comma ? static_cast<std::size_t>(comma - entry) : std::strlen(entry);
    // Ignore padding in front of the comma as well.
    while (entry_len > 0 && entry[entry_len - 1] == ' ') {
      --entry_len;
    }
    if (entry_len == value_len && std::strncmp(entry, value, value_len) == 0) {
      return true;
    }
    if (!comma) {
      return false;
    }
    entry = comma + 1;
  }
}

// Reads one command-line flag of the form `name=value` and stores the parsed
// result through the `dst` pointers.
//
//   help            When true, only prints a usage line for this flag to
//                   stderr and returns without touching `dst`.
//   argc, argv      The program's command line; argv[0] is skipped.
//   name            Flag name including leading dashes, e.g. "--shape".
//   format          sscanf format used to parse the value.
//   default_value   Value used when the flag is absent from the command line.
//   allowed_values  Optional comma-separated list of legal values, or nullptr
//                   to skip that check.
//   dst             One pointer per conversion in `format`.
//
// The first argument that is exactly `name` followed by '=' supplies the
// value. Arguments that merely start with `name` (another flag sharing the
// prefix, or the bare flag without '=') are skipped instead of ending the
// search. Exits the process with status 1 when the value is not in
// `allowed_values` or when sscanf does not convert every `dst`.
template <typename... Dst>
void read_cmdline_args(bool help, int argc, char* argv[], const char* name,
                       const char* format, const char* default_value,
                       const char* allowed_values, Dst... dst) {
  if (help) {
    fprintf(stderr, "%-20s %-12s %-16s %s\n", name, format, default_value,
            allowed_values ? allowed_values : "");
    return;
  }
  const char* value = default_value;
  const std::size_t name_len = std::strlen(name);
  for (int i = 1; i < argc; i++) {
    // strncmp succeeding over name_len characters guarantees argv[i] has at
    // least name_len characters, so indexing argv[i][name_len] is in bounds.
    if (std::strncmp(argv[i], name, name_len) == 0 &&
        argv[i][name_len] == '=') {
      value = argv[i] + name_len + 1;
      break;
    }
  }
  if (allowed_values && !is_allowed_value(allowed_values, value)) {
    fprintf(stderr, "Illegal value %s. The legal values are %s.\n", value,
            allowed_values);
    exit(1);
  }
  // sscanf returns an int (EOF on input failure); compare as signed values.
  const int expected_conversions = static_cast<int>(sizeof...(Dst));
  if (sscanf(value, format, dst...) != expected_conversions) {
    fprintf(stderr, "Failed to parse %s\n", value);
    exit(1);
  }
}
43 
// Settings for one run of the example, filled in from the command line.
struct Params {
  // Type combination, e.g. "f32xf32->f32".
  char types[100];

  // Matmul shape m*k*n.
  int m;
  int k;
  int n;

  // Value handed to ruy::Context::set_runtime_enabled_paths (cast to
  // ruy::Path).
  int paths;
  // Value handed to ruy::Context::set_max_num_threads.
  int num_threads;
  // Number of times ruy::Mul is invoked.
  int repeat;

  // Cache policies, cast to ruy::CachePolicy.
  int lhs_cache_policy;
  int rhs_cache_policy;

  // Strides; 0 means "derive the stride from the matrix shape and order".
  int lhs_stride;
  int rhs_stride;
  int dst_stride;

  // Zero points of the three matrices.
  int lhs_zero_point;
  int rhs_zero_point;
  int dst_zero_point;

  // Storage orders: "row-major" or "column-major".
  char lhs_order[100];
  char rhs_order[100];
  char dst_order[100];
};
62 
63 template <typename LhsType, typename RhsType, typename DstType>
run(const Params & params)64 void run(const Params& params) {
65   using AccumType =
66       typename std::conditional<std::is_floating_point<DstType>::value, DstType,
67                                 std::int32_t>::type;
68 
69   ruy::Matrix<LhsType> lhs;
70   ruy::Matrix<RhsType> rhs;
71   ruy::Matrix<DstType> dst;
72 
73   auto parse_order = [](const char* name) {
74     if (!std::strcmp(name, "row-major")) {
75       return ruy::Order::kRowMajor;
76     } else if (!std::strcmp(name, "column-major")) {
77       return ruy::Order::kColMajor;
78     } else {
79       fprintf(stderr, "Failed to parse %s\n", name);
80       exit(1);
81     }
82   };
83 
84   auto make_layout = [](int rows, int cols, int stride, ruy::Order order,
85                         ruy::Layout* layout) {
86     layout->set_rows(rows);
87     layout->set_cols(cols);
88     layout->set_order(order);
89     int base_stride = order == ruy::Order::kRowMajor ? cols : rows;
90     layout->set_stride(stride ? stride : base_stride);
91   };
92 
93   make_layout(params.m, params.k, params.lhs_stride,
94               parse_order(params.lhs_order), lhs.mutable_layout());
95   make_layout(params.k, params.n, params.rhs_stride,
96               parse_order(params.rhs_order), rhs.mutable_layout());
97   make_layout(params.m, params.n, params.dst_stride,
98               parse_order(params.dst_order), dst.mutable_layout());
99 
100   lhs.set_zero_point(params.lhs_zero_point);
101   rhs.set_zero_point(params.rhs_zero_point);
102   dst.set_zero_point(params.dst_zero_point);
103 
104   lhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.lhs_cache_policy));
105   rhs.set_cache_policy(static_cast<ruy::CachePolicy>(params.rhs_cache_policy));
106 
107   auto flat_size = [](const ruy::Layout& layout) {
108     int outer_size =
109         layout.order() == ruy::Order::kRowMajor ? layout.rows() : layout.cols();
110     return outer_size * layout.stride();
111   };
112 
113   std::vector<LhsType> lhs_buf(flat_size(lhs.layout()));
114   std::vector<RhsType> rhs_buf(flat_size(rhs.layout()));
115   std::vector<DstType> dst_buf(flat_size(dst.layout()));
116 
117   lhs.set_data(lhs_buf.data());
118   rhs.set_data(rhs_buf.data());
119   dst.set_data(dst_buf.data());
120 
121   ruy::Context context;
122   context.set_max_num_threads(params.num_threads);
123   context.set_runtime_enabled_paths(static_cast<ruy::Path>(params.paths));
124 
125   ruy::MulParams<AccumType, DstType> mul_params;
126   // Here an actual application might set some mul_params fields.
127   // Quantization multipliers, bias-vector, clamp bounds, etc.
128 
129   for (int r = 0; r < params.repeat; r++) {
130     ruy::Mul(lhs, rhs, mul_params, &context, &dst);
131   }
132 }
133 
main(int argc,char * argv[])134 int main(int argc, char* argv[]) {
135   bool help = argc == 1 || (argc == 2 && !strcmp(argv[1], "--help"));
136   if (help) {
137     fprintf(stderr, "Command-line flags (all in the form --flag=value):\n");
138     fprintf(stderr, "%-20s %-12s %-16s %s\n", "flag", "format", "default",
139             "allowed");
140   }
141   Params params;
142   const char* allowed_types =
143       "f32xf32->f32, i8xi8->i8, i8xi8->i16, i8xi8->i32, u8xu8->i16, u8xi8->u8";
144   const char* allowed_orders = "row-major, column-major";
145   read_cmdline_args(help, argc, argv, "--types", "%s", "f32xf32->f32",
146                     allowed_types, &params.types);
147   read_cmdline_args(help, argc, argv, "--shape", "%dx%dx%d", "100x100x100",
148                     nullptr, &params.m, &params.k, &params.n);
149   read_cmdline_args(help, argc, argv, "--paths", "%x", "0", nullptr,
150                     &params.paths);
151   read_cmdline_args(help, argc, argv, "--num_threads", "%d", "1", nullptr,
152                     &params.num_threads);
153   read_cmdline_args(help, argc, argv, "--repeat", "%d", "1", nullptr,
154                     &params.repeat);
155   read_cmdline_args(help, argc, argv, "--lhs_cache_policy", "%d", "0",
156                     "0, 1, 2, 3", &params.lhs_cache_policy);
157   read_cmdline_args(help, argc, argv, "--rhs_cache_policy", "%d", "0",
158                     "0, 1, 2, 3", &params.rhs_cache_policy);
159   read_cmdline_args(help, argc, argv, "--lhs_stride", "%d", "0", nullptr,
160                     &params.lhs_stride);
161   read_cmdline_args(help, argc, argv, "--rhs_stride", "%d", "0", nullptr,
162                     &params.rhs_stride);
163   read_cmdline_args(help, argc, argv, "--dst_stride", "%d", "0", nullptr,
164                     &params.dst_stride);
165   read_cmdline_args(help, argc, argv, "--lhs_zero_point", "%d", "0", nullptr,
166                     &params.lhs_zero_point);
167   read_cmdline_args(help, argc, argv, "--rhs_zero_point", "%d", "0", nullptr,
168                     &params.rhs_zero_point);
169   read_cmdline_args(help, argc, argv, "--dst_zero_point", "%d", "0", nullptr,
170                     &params.dst_zero_point);
171   read_cmdline_args(help, argc, argv, "--lhs_order", "%s", "row-major",
172                     allowed_orders, &params.lhs_order);
173   read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
174                     allowed_orders, &params.rhs_order);
175   read_cmdline_args(help, argc, argv, "--rhs_order", "%s", "row-major",
176                     allowed_orders, &params.dst_order);
177 
178   if (help) {
179     exit(1);
180   }
181 
182   if (!strcmp(params.types, "f32xf32->f32")) {
183     run<float, float, float>(params);
184   } else if (!strcmp(params.types, "i8xi8->i8")) {
185     run<std::int8_t, std::int8_t, std::int8_t>(params);
186   } else if (!strcmp(params.types, "i8xi8->i16")) {
187     run<std::int8_t, std::int8_t, std::int16_t>(params);
188   } else if (!strcmp(params.types, "i8xi8->i32")) {
189     run<std::int8_t, std::int8_t, std::int32_t>(params);
190   } else if (!strcmp(params.types, "u8xu8->i16")) {
191     run<std::uint8_t, std::uint8_t, std::int16_t>(params);
192   } else if (!strcmp(params.types, "u8xi8->u8")) {
193     run<std::uint8_t, std::int8_t, std::uint8_t>(params);
194   } else {
195     fprintf(stderr, "Unknown types: %s\n", params.types);
196     exit(1);
197   }
198 }
199