Lines Matching refs:params
26 void Gemm(const Params& params);
31 static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, in EstimateScratchSize() argument
35 params.left_stream, kernel_m, kernel_k); in EstimateScratchSize()
36 const int rhs_chunks = ((params.n + kernel_n - 1) / kernel_n); in EstimateScratchSize()
40 params.right_stream, kernel_n, kernel_k); in EstimateScratchSize()
46 static void ExecuteDispatch3D(const P& params) { in ExecuteDispatch3D() argument
91 << k_leftovers << " -- " << params.m << "x" << params.n << "x" in ExecuteDispatch3D()
92 << params.k << std::endl; in ExecuteDispatch3D()
93 LeftStreamF::Debug(params.left_stream); in ExecuteDispatch3D()
94 LeftStreamL::Debug(params.left_stream); in ExecuteDispatch3D()
96 RightStreamF::Debug(params.right_stream); in ExecuteDispatch3D()
97 RightStreamL::Debug(params.right_stream); in ExecuteDispatch3D()
99 OutputStreamFF::Debug(params.fused_kernel.output_stream); in ExecuteDispatch3D()
100 OutputStreamLF::Debug(params.fused_kernel.output_stream); in ExecuteDispatch3D()
102 KernelFF::Debug(params.fused_kernel); in ExecuteDispatch3D()
103 KernelFL::Debug(params.fused_kernel); in ExecuteDispatch3D()
104 KernelLF::Debug(params.fused_kernel); in ExecuteDispatch3D()
105 KernelLL::Debug(params.fused_kernel); in ExecuteDispatch3D()
109 int lhs_chunks = params.m / m; in ExecuteDispatch3D()
110 int rhs_chunks = params.n / n; in ExecuteDispatch3D()
114 std::uint8_t* packed_lhs = params.scratch; in ExecuteDispatch3D()
116 params.scratch + LeftStreamF::Scratch(params.left_stream); in ExecuteDispatch3D()
122 RightStreamF::PackedStride(params.right_stream); in ExecuteDispatch3D()
126 reinterpret_cast<const std::uint8_t*>(params.rhs); in ExecuteDispatch3D()
128 RightStreamF::UnpackedStride(params.right_stream); in ExecuteDispatch3D()
132 params.right_stream, in ExecuteDispatch3D()
140 params.right_stream, in ExecuteDispatch3D()
147 reinterpret_cast<const std::uint8_t*>(params.lhs); in ExecuteDispatch3D()
148 std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result); in ExecuteDispatch3D()
153 LeftStreamF::UnpackedStride(params.left_stream); in ExecuteDispatch3D()
155 OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream); in ExecuteDispatch3D()
157 OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream); in ExecuteDispatch3D()
161 params.left_stream, in ExecuteDispatch3D()
170 params.fused_kernel, in ExecuteDispatch3D()
179 params.fused_kernel, in ExecuteDispatch3D()
190 OutputStreamLF::UnpackedAdvance(params.fused_kernel.output_stream); in ExecuteDispatch3D()
193 params.left_stream, in ExecuteDispatch3D()
202 params.fused_kernel, in ExecuteDispatch3D()
211 params.fused_kernel, in ExecuteDispatch3D()
220 static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, in EstimateScratchSize() argument
222 const int lhs_chunks = ((params.m + kernel_m - 1) / kernel_m); in EstimateScratchSize()
226 params.left_stream, kernel_m, kernel_k); in EstimateScratchSize()
229 params.right_stream, kernel_n, kernel_k); in EstimateScratchSize()
235 static void ExecuteDispatch3D(const P& params) { in ExecuteDispatch3D() argument
279 << k_leftovers << " -- " << params.m << "x" << params.n << "x" in ExecuteDispatch3D()
280 << params.k << std::endl; in ExecuteDispatch3D()
281 LeftStreamF::Debug(params.left_stream); in ExecuteDispatch3D()
282 LeftStreamL::Debug(params.left_stream); in ExecuteDispatch3D()
284 RightStreamF::Debug(params.right_stream); in ExecuteDispatch3D()
285 RightStreamL::Debug(params.right_stream); in ExecuteDispatch3D()
287 OutputStreamFF::Debug(params.fused_kernel.output_stream); in ExecuteDispatch3D()
288 OutputStreamFL::Debug(params.fused_kernel.output_stream); in ExecuteDispatch3D()
290 KernelFF::Debug(params.fused_kernel); in ExecuteDispatch3D()
291 KernelFL::Debug(params.fused_kernel); in ExecuteDispatch3D()
292 KernelLF::Debug(params.fused_kernel); in ExecuteDispatch3D()
293 KernelLL::Debug(params.fused_kernel); in ExecuteDispatch3D()
297 int lhs_chunks = params.m / m; in ExecuteDispatch3D()
298 int rhs_chunks = params.n / n; in ExecuteDispatch3D()
301 std::uint8_t* packed_rhs = params.scratch; in ExecuteDispatch3D()
303 params.scratch + RightStreamF::Scratch(params.right_stream); in ExecuteDispatch3D()
309 LeftStreamF::PackedStride(params.left_stream); in ExecuteDispatch3D()
313 reinterpret_cast<const std::uint8_t*>(params.lhs); in ExecuteDispatch3D()
315 LeftStreamF::UnpackedStride(params.left_stream); in ExecuteDispatch3D()
319 params.left_stream, in ExecuteDispatch3D()
327 params.left_stream, in ExecuteDispatch3D()
334 reinterpret_cast<const std::uint8_t*>(params.rhs); in ExecuteDispatch3D()
335 std::uint8_t* result_strip = reinterpret_cast<std::uint8_t*>(params.result); in ExecuteDispatch3D()
340 RightStreamF::UnpackedStride(params.right_stream); in ExecuteDispatch3D()
342 OutputStreamFF::UnpackedAdvance(params.fused_kernel.output_stream); in ExecuteDispatch3D()
344 OutputStreamFF::UnpackedStride(params.fused_kernel.output_stream); in ExecuteDispatch3D()
348 params.right_stream, in ExecuteDispatch3D()
357 params.fused_kernel, in ExecuteDispatch3D()
366 params.fused_kernel, in ExecuteDispatch3D()
377 OutputStreamFL::UnpackedStride(params.fused_kernel.output_stream); in ExecuteDispatch3D()
380 params.right_stream, in ExecuteDispatch3D()
389 params.fused_kernel, in ExecuteDispatch3D()
398 params.fused_kernel, in ExecuteDispatch3D()
418 const Params& params, Params* task_params) { in UpdateCacheFriendlyTask() argument
422 params.left_stream, params.lhs, m_offset, 0); in UpdateCacheFriendlyTask()
427 params.right_stream, params.rhs, n_offset, 0); in UpdateCacheFriendlyTask()
431 Offset(params.fused_kernel.output_stream, params.result, m_offset, in UpdateCacheFriendlyTask()
441 static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, in EstimateScratchSize() argument
448 static void ExecuteDispatch3D(const P& params) { in ExecuteDispatch3D() argument
457 const int lhs_scratch = LeftStream::Scratch(params.left_stream); in ExecuteDispatch3D()
458 const int rhs_scratch = RightStream::Scratch(params.right_stream); in ExecuteDispatch3D()
462 rhs_scratch, params.n, n); in ExecuteDispatch3D()
466 n_leftovers, k_leftovers>(params); in ExecuteDispatch3D()
470 const int cache_friendly_dim = params.n / cache_friendly_tasks_count; in ExecuteDispatch3D()
472 P task_params = params; in ExecuteDispatch3D()
474 internal::UpdateCacheFriendlyTask(0, params.m, i * cache_friendly_dim, in ExecuteDispatch3D()
475 cache_friendly_dim, params, in ExecuteDispatch3D()
480 internal::UpdateCacheFriendlyTask(0, params.m, dim_sum, params.n - dim_sum, in ExecuteDispatch3D()
481 params, &task_params); in ExecuteDispatch3D()
490 static int EstimateScratchSize(const P& params, int kernel_m, int kernel_n, in EstimateScratchSize() argument
497 static void ExecuteDispatch3D(const P& params) { in ExecuteDispatch3D() argument
506 const int lhs_scratch = LeftStream::Scratch(params.left_stream); in ExecuteDispatch3D()
507 const int rhs_scratch = RightStream::Scratch(params.right_stream); in ExecuteDispatch3D()
511 lhs_scratch, params.m, m); in ExecuteDispatch3D()
515 n_leftovers, k_leftovers>(params); in ExecuteDispatch3D()
519 const int cache_friendly_dim = params.m / cache_friendly_tasks_count; in ExecuteDispatch3D()
521 P task_params = params; in ExecuteDispatch3D()
524 cache_friendly_dim, 0, params.n, params, in ExecuteDispatch3D()
529 internal::UpdateCacheFriendlyTask(dim_sum, params.m - dim_sum, 0, params.n, in ExecuteDispatch3D()
530 params, &task_params); in ExecuteDispatch3D()
542 static void Execute(const P& params, int k) { in Execute()
553 variable_k>(params); in Execute()
556 variable_k - 1>::Execute(params, k); in Execute()
564 static void Execute(const P& params, int k) {
574 0>(params);
589 static void Execute(const P& params, int n, int k) {
599 dim_k - 1>::Execute(params, k);
602 variable_n - 1>::Execute(params, n, k);
609 static void Execute(const P& params, int n, int k) {
619 dim_k - 1>::Execute(params, k);
634 static void Execute(const P& params, int m, int n, int k) {
644 dim_n - 1>::Execute(params, n, k);
647 params, m, n, k);
654 static void Execute(const P& params, int m, int n, int k) {
663 Dispatch3DStage2<E, P, dim_m, dim_n, dim_k, 0, dim_n - 1>::Execute(params,
678 inline void Gemm(const Params& params) {
680 kernel_m - 1>::Execute(params, params.m % kernel_m,
681 params.n % kernel_n,
682 params.k % kernel_k);