1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // kernel_reference.h: a reference kernel for CPU architectures where we don't 16 // have optimized kernels yet. Also useful for testing, as it's templatized 17 // to have any arbitrary format, allowing tests to cover all sorts of corner 18 // cases. 19 20 #ifndef GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 21 #define GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 22 23 #include "kernel.h" 24 25 #include <cstdio> 26 #include <cstring> 27 28 namespace gemmlowp { 29 30 // This kernel is templatized in an arbitrary Format template parameter, 31 // allowing it to have any arbitrary format. 32 template <typename tFormat> 33 struct ReferenceKernel : KernelBase { 34 typedef tFormat Format; 35 NameReferenceKernel36 const char* Name() const override { 37 static char buf[256]; 38 snprintf(buf, sizeof(buf), 39 "reference(Lhs: %d cells %dx%d %s, Rhs: %d cells %dx%d %s)", 40 Format::Lhs::kCells, Format::Lhs::Cell::kWidth, 41 Format::Lhs::Cell::kDepth, 42 CellOrderName(Format::Lhs::Cell::kOrder), Format::Rhs::kCells, 43 Format::Rhs::Cell::kDepth, Format::Rhs::Cell::kWidth, 44 CellOrderName(Format::Rhs::Cell::kOrder)); 45 return buf; 46 } 47 RunReferenceKernel48 void Run(std::int32_t* dst_ptr, std::size_t dst_row_stride, 49 std::size_t dst_col_stride, const std::uint8_t* lhs_ptr, 50 const std::uint8_t* rhs_ptr, std::size_t start_depth, 51 std::size_t run_depth) const override { 52 std::int32_t accumulator[Format::kRows * Format::kCols]; 53 memset(accumulator, 0, sizeof(accumulator)); 54 55 const int run_depth_cells = static_cast<int>(run_depth / Format::kDepth); 56 57 // The outer loop is over the depth dimension. 58 for (int dc = 0; dc < run_depth_cells; dc++) { 59 // The next two loops are over cells of the Lhs (stacked vertically), 60 // and over cells of the Rhs (stacked horizontally). 61 for (int rc = 0; rc < Format::Lhs::kCells; rc++) { 62 const std::uint8_t* lhs_cell_ptr = 63 lhs_ptr + (dc * Format::Lhs::kCells + rc) * 64 Format::Lhs::Cell::kWidth * Format::kDepth; 65 for (int cc = 0; cc < Format::Rhs::kCells; cc++) { 66 const std::uint8_t* rhs_cell_ptr = 67 rhs_ptr + (dc * Format::Rhs::kCells + cc) * 68 Format::Rhs::Cell::kWidth * Format::kDepth; 69 70 // Now we are inside one cell of the Lhs and inside one cell 71 // of the Rhs, so the remaining inner loops are just 72 // traditional three loops of matrix multiplication. 73 for (int di = 0; di < Format::kDepth; di++) { 74 for (int ri = 0; ri < Format::Lhs::Cell::kWidth; ri++) { 75 for (int ci = 0; ci < Format::Rhs::Cell::kWidth; ci++) { 76 const std::uint8_t* lhs_coeff_ptr = 77 lhs_cell_ptr + 78 OffsetIntoCell<typename Format::Lhs::Cell>(ri, di); 79 const std::uint8_t* rhs_coeff_ptr = 80 rhs_cell_ptr + 81 OffsetIntoCell<typename Format::Rhs::Cell>(ci, di); 82 std::int32_t* accumulator_coeff_ptr = 83 accumulator + (ri + rc * Format::Lhs::Cell::kWidth) + 84 (ci + cc * Format::Rhs::Cell::kWidth) * Format::kRows; 85 *accumulator_coeff_ptr += 86 std::int32_t(*lhs_coeff_ptr) * std::int32_t(*rhs_coeff_ptr); 87 } 88 } 89 } 90 } 91 } 92 } 93 94 if (start_depth == 0) { 95 // start_depth == 0 means we haven't accumulated anything yet, so we need 96 // to overwrite the accumulator, as it hasn't been initialized to zero. 97 for (int r = 0; r < Format::kRows; r++) { 98 for (int c = 0; c < Format::kCols; c++) { 99 dst_ptr[r * dst_row_stride + c * dst_col_stride] = 100 accumulator[r + c * Format::kRows]; 101 } 102 } 103 } else { 104 // We have already accumulated stuff, so we need to continue accumulating 105 // instead of just overwriting. 106 for (int r = 0; r < Format::kRows; r++) { 107 for (int c = 0; c < Format::kCols; c++) { 108 dst_ptr[r * dst_row_stride + c * dst_col_stride] += 109 accumulator[r + c * Format::kRows]; 110 } 111 } 112 } 113 } 114 }; 115 116 } // namespace gemmlowp 117 118 #endif // GEMMLOWP_INTERNAL_KERNEL_REFERENCE_H_ 119