1 // Copyright 2015 The Gemmlowp Authors. All Rights Reserved. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // output.h: processing the 32-bit accumulators output by the unpack 16 // stage, obtaining the final result matrix entries and storing them into 17 // the destination matrix. 18 19 #ifndef GEMMLOWP_INTERNAL_OUTPUT_H_ 20 #define GEMMLOWP_INTERNAL_OUTPUT_H_ 21 22 #include <cmath> 23 #include <tuple> 24 #include <type_traits> 25 #include <typeinfo> 26 27 #include "../fixedpoint/fixedpoint.h" 28 #include "../public/output_stages.h" 29 #include "simd_wrappers.h" 30 31 namespace gemmlowp { 32 33 template <typename OutputStage, typename InputBufferType> 34 struct OutputStageEvalBufferImpl { 35 // This generic template body should never be hit. 36 static_assert( 37 std::is_same<InputBufferType, void>::value, 38 "Unimplemented: missing implementation of this output pipeline stage " 39 "for this data type. This would happen if some architecture-specific " 40 "SIMD back-end (output_$arch.h) were incomplete."); 41 }; 42 43 template <typename OutputStage, typename InputType> 44 struct OutputStageEvalImpl { 45 static constexpr int kRows = InputType::kRows; 46 static constexpr int kCols = InputType::kCols; 47 using InputBufferType = typename InputType::BufferType; 48 using BufferEvalImplType = 49 OutputStageEvalBufferImpl<OutputStage, InputBufferType>; 50 using OutputBufferType = typename BufferEvalImplType::OutputType; 51 using OutputScalarType = typename OutputBufferType::ScalarType; 52 using OutputType = RegisterBlock<OutputScalarType, kRows, kCols>; 53 OutputStageEvalImplOutputStageEvalImpl54 OutputStageEvalImpl(const OutputStage& s) : buffer_eval_impl(s) {} 55 EvalOutputStageEvalImpl56 OutputType Eval(InputType input, int, int) const { 57 OutputType output; 58 output.buf = buffer_eval_impl.Eval(input.buf); 59 return output; 60 } 61 62 const BufferEvalImplType buffer_eval_impl; 63 }; 64 65 template <int Size> 66 struct OutputStageEvalBufferImpl<OutputStageQuantizeDownInt32ToUint8Scale, 67 RegisterBuffer<std::int32_t, Size>> { 68 using InputType = RegisterBuffer<std::int32_t, Size>; 69 using OutputType = RegisterBuffer<std::int32_t, Size>; 70 71 typedef OutputStageQuantizeDownInt32ToUint8Scale OutputStage; 72 73 OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} 74 75 OutputType Eval(InputType input) const { 76 const int result_shift = output_stage.result_shift; 77 const std::int32_t result_mult_int = output_stage.result_mult_int; 78 using RegisterType = typename InputType::RegisterType; 79 const RegisterType result_offset = 80 Dup<RegisterType>(output_stage.result_offset); 81 OutputType output; 82 for (int i = 0; i < InputType::kRegisterCount; i++) { 83 output.reg[i] = RoundingDivideByPOT( 84 Mul(Add(input.reg[i], result_offset), result_mult_int), result_shift); 85 } 86 return output; 87 } 88 89 const OutputStage& output_stage; 90 }; 91 92 template <int Rows, int Cols, VectorShape Shape> 93 struct OutputStageEvalImpl<OutputStageQuantizeDownInt32ToUint8ScalePC<Shape>, 94 RegisterBlock<std::int32_t, Rows, Cols>> { 95 typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; 96 typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; 97 typedef OutputStageQuantizeDownInt32ToUint8ScalePC<Shape> OutputStage; 98 99 OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} 100 101 OutputType Eval(InputType input, int row, int col) const { 102 OutputType output; 103 const int result_shift = output_stage.result_shift; 104 const int pos = Shape == VectorShape::Col ? row : col; 105 const auto result_mult_int = 106 LoadForBroadcasting<InputType>(output_stage.result_mult_int, pos); 107 const auto result_offset = 108 LoadForBroadcasting<InputType>(output_stage.result_offset, pos); 109 const auto dividend = BroadcastMul<InputType>( 110 BroadcastAdd<InputType>(input, result_offset), result_mult_int); 111 for (int i = 0; i < InputType::kRegisterCount; i++) { 112 output.buf.reg[i] = 113 RoundingDivideByPOT(dividend.buf.reg[i], result_shift); 114 } 115 return output; 116 } 117 118 const OutputStage& output_stage; 119 }; 120 121 template <int Size> 122 struct OutputStageEvalBufferImpl< 123 OutputStageQuantizeDownInt32ByFixedPoint, 124 RegisterBuffer<std::int32_t, Size>> { 125 typedef RegisterBuffer<std::int32_t, Size> InputType; 126 typedef RegisterBuffer<std::int32_t, Size> OutputType; 127 128 typedef OutputStageQuantizeDownInt32ByFixedPoint OutputStage; 129 130 OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} 131 132 OutputType Eval(InputType input) const { 133 OutputType output; 134 using RegisterType = typename InputType::RegisterType; 135 const RegisterType result_offset_after_shift = 136 Dup<RegisterType>(output_stage.result_offset_after_shift); 137 for (int i = 0; i < InputType::kRegisterCount; i++) { 138 const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul( 139 input.reg[i], output_stage.result_fixedpoint_multiplier); 140 output.reg[i] = 141 Add(RoundingDivideByPOT(mulhigh_val, output_stage.result_shift), 142 result_offset_after_shift); 143 } 144 return output; 145 } 146 147 const OutputStage& output_stage; 148 }; 149 150 template <int Size> 151 struct OutputStageEvalBufferImpl<OutputStageScaleInt32ByFixedPointAndExponent, 152 RegisterBuffer<std::int32_t, Size>> { 153 typedef RegisterBuffer<std::int32_t, Size> InputType; 154 typedef RegisterBuffer<std::int32_t, Size> OutputType; 155 156 typedef OutputStageScaleInt32ByFixedPointAndExponent OutputStage; 157 158 OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) { 159 left_shift = std::max(0, output_stage.result_exponent); 160 right_shift = std::max(0, -output_stage.result_exponent); 161 } 162 163 OutputType Eval(InputType input) const { 164 OutputType output; 165 using RegisterType = typename InputType::RegisterType; 166 const RegisterType result_offset_after_shift = 167 Dup<RegisterType>(output_stage.result_offset_after_shift); 168 for (int i = 0; i < InputType::kRegisterCount; i++) { 169 const RegisterType mulhigh_val = SaturatingRoundingDoublingHighMul( 170 ShiftLeft(input.reg[i], left_shift), 171 output_stage.result_fixedpoint_multiplier); 172 output.reg[i] = Add(RoundingDivideByPOT(mulhigh_val, right_shift), 173 result_offset_after_shift); 174 } 175 return output; 176 } 177 178 const OutputStage& output_stage; 179 int left_shift; 180 int right_shift; 181 }; 182 183 template <int Rows, int Cols, VectorShape Shape> 184 struct OutputStageEvalImpl< 185 OutputStageScaleInt32ByFixedPointAndExponentPC<Shape>, 186 RegisterBlock<std::int32_t, Rows, Cols>> { 187 typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; 188 typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; 189 190 typedef OutputStageScaleInt32ByFixedPointAndExponentPC<Shape> OutputStage; 191 192 OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} 193 194 OutputType Eval(InputType input, int row, int col) const { 195 OutputType output; 196 const int pos = Shape == VectorShape::Row ? col : row; 197 using RegisterType = typename InputType::RegisterType; 198 const RegisterType result_offset_after_shift = 199 Dup<RegisterType>(output_stage.result_offset_after_shift); 200 auto left_shift = 201 LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); 202 auto right_shift = 203 LoadForBroadcasting<InputType>(output_stage.result_exponent, pos); 204 const auto result_fixedpoint_multiplier = LoadForBroadcasting<InputType>( 205 output_stage.result_fixedpoint_multiplier, pos); 206 for (int i = 0; i < decltype(left_shift)::kRegisterCount; i++) { 207 left_shift.buf.reg[i] = Max(left_shift.buf.reg[i], 0); 208 right_shift.buf.reg[i] = Max(-right_shift.buf.reg[i], 0); 209 } 210 const auto mulhigh_val = BroadcastSaturatingRoundingDoublingHighMul( 211 BroadcastShiftLeft(input, left_shift), result_fixedpoint_multiplier); 212 const auto rdpot_val = 213 BroadcastRoundingDivideByPOT(mulhigh_val, right_shift); 214 for (int i = 0; i < InputType::kRegisterCount; i++) { 215 output.buf.reg[i] = Add(rdpot_val.buf.reg[i], result_offset_after_shift); 216 } 217 return output; 218 } 219 220 const OutputStage& output_stage; 221 }; 222 223 // Implementation of OutputStageSaturatingCastToUint8 for scalar data. 224 template <int Size> 225 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToUint8, 226 RegisterBuffer<std::int32_t, Size>> { 227 typedef RegisterBuffer<std::int32_t, Size> InputType; 228 typedef RegisterBuffer<std::uint8_t, Size> OutputType; 229 static_assert(InputType::kRegisterLanes == 1, 230 "This path is only for scalar values"); 231 232 typedef OutputStageSaturatingCastToUint8 OutputStage; 233 234 OutputStageEvalBufferImpl(const OutputStage&) {} 235 236 OutputType Eval(InputType input) const { 237 OutputType output; 238 for (int i = 0; i < InputType::kRegisterCount; i++) { 239 std::int32_t data = input.reg[i]; 240 output.reg[i] = data > 255 ? 255 : data < 0 ? 0 : data; 241 } 242 return output; 243 } 244 }; 245 246 // Implementation of OutputStageSaturatingCastToInt8 for scalar data. 247 template <int Size> 248 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt8, 249 RegisterBuffer<std::int32_t, Size>> { 250 typedef RegisterBuffer<std::int32_t, Size> InputType; 251 typedef RegisterBuffer<std::int8_t, Size> OutputType; 252 static_assert(InputType::kRegisterLanes == 1, 253 "This path is only for scalar values"); 254 255 typedef OutputStageSaturatingCastToInt8 OutputStage; 256 257 OutputStageEvalBufferImpl(const OutputStage&) {} 258 259 OutputType Eval(InputType input) const { 260 OutputType output; 261 for (int i = 0; i < InputType::kRegisterCount; i++) { 262 std::int32_t data = input.reg[i]; 263 output.reg[i] = data > 127 ? 127 : data < -128 ? -128 : data; 264 } 265 return output; 266 } 267 }; 268 269 // Implementation of OutputStageSaturatingCastToInt16 for scalar data. 270 template <int Size> 271 struct OutputStageEvalBufferImpl<OutputStageSaturatingCastToInt16, 272 RegisterBuffer<std::int32_t, Size>> { 273 typedef RegisterBuffer<std::int32_t, Size> InputType; 274 typedef RegisterBuffer<std::int16_t, Size> OutputType; 275 static_assert(InputType::kRegisterLanes == 1, 276 "This path is only for scalar values"); 277 278 typedef OutputStageSaturatingCastToInt16 OutputStage; 279 280 OutputStageEvalBufferImpl(const OutputStage&) {} 281 282 OutputType Eval(InputType input) const { 283 OutputType output; 284 for (int i = 0; i < InputType::kRegisterCount; i++) { 285 std::int32_t data = input.reg[i]; 286 output.reg[i] = data > 32767 ? 32767 : data < -32768 ? -32768 : data; 287 } 288 return output; 289 } 290 }; 291 292 // Implementation of OutputStageTruncatingCastToUint8 for scalar data 293 template <int Size> 294 struct OutputStageEvalBufferImpl<OutputStageTruncatingCastToUint8, 295 RegisterBuffer<std::int32_t, Size>> { 296 typedef RegisterBuffer<std::int32_t, Size> InputType; 297 typedef RegisterBuffer<std::uint8_t, Size> OutputType; 298 static_assert(InputType::kRegisterLanes == 1, 299 "This path is only for scalar values"); 300 301 typedef OutputStageTruncatingCastToUint8 OutputStage; 302 303 OutputStageEvalBufferImpl(const OutputStage&) {} 304 305 OutputType Eval(InputType input) const { 306 OutputType output; 307 for (int i = 0; i < InputType::kRegisterCount; i++) { 308 output.reg[i] = input.reg[i]; 309 } 310 return output; 311 } 312 }; 313 314 template <int Rows, int Cols, typename VectorType> 315 struct OutputStageEvalImpl<OutputStageBiasAddition<VectorType>, 316 RegisterBlock<std::int32_t, Rows, Cols>> { 317 typedef RegisterBlock<std::int32_t, Rows, Cols> InputType; 318 typedef RegisterBlock<std::int32_t, Rows, Cols> OutputType; 319 typedef OutputStageBiasAddition<VectorType> OutputStage; 320 321 OutputStageEvalImpl(const OutputStage& s) : output_stage(s) {} 322 323 OutputType Eval(InputType input, int row, int col) const { 324 const int pos = VectorType::kShape == VectorShape::Row ? col : row; 325 return BroadcastAdd<InputType>( 326 input, LoadForBroadcasting<InputType>(output_stage.bias_vector, pos)); 327 } 328 329 const OutputStage& output_stage; 330 }; 331 332 template <int Size> 333 struct OutputStageEvalBufferImpl<OutputStageClamp, 334 RegisterBuffer<std::int32_t, Size>> { 335 typedef RegisterBuffer<std::int32_t, Size> InputType; 336 typedef RegisterBuffer<std::int32_t, Size> OutputType; 337 338 typedef OutputStageClamp OutputStage; 339 340 OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) {} 341 342 OutputType Eval(InputType input) const { 343 using RegisterType = typename InputType::RegisterType; 344 const RegisterType min = Dup<RegisterType>(output_stage.min); 345 const RegisterType max = Dup<RegisterType>(output_stage.max); 346 OutputType output; 347 for (int i = 0; i < InputType::kRegisterCount; i++) { 348 output.reg[i] = Min(Max(input.reg[i], min), max); 349 } 350 return output; 351 } 352 353 const OutputStage& output_stage; 354 }; 355 356 template <int Size> 357 struct OutputStageEvalBufferImpl<OutputStageTanh, 358 RegisterBuffer<std::int32_t, Size>> { 359 typedef RegisterBuffer<std::int32_t, Size> InputType; 360 typedef RegisterBuffer<std::int32_t, Size> OutputType; 361 using RegisterType = typename InputType::RegisterType; 362 typedef RegisterType DataType; 363 typedef OutputStageTanh OutputStage; 364 365 OutputStageEvalBufferImpl(const OutputStage& s) : output_stage(s) { 366 const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; 367 const std::int32_t real_amplitude_as_int32 = 368 output_stage.real_amplitude_as_int32; 369 370 input_cutoff_min = real_zero_as_int32 - 8 * real_amplitude_as_int32; 371 input_cutoff_max = real_zero_as_int32 + 8 * real_amplitude_as_int32; 372 output_min = real_zero_as_int32 - real_amplitude_as_int32; 373 output_max = real_zero_as_int32 + real_amplitude_as_int32; 374 375 double inverse_amplitude_normalized_double = 1.0 / real_amplitude_as_int32; 376 inverse_amplitude_neg_exponent = 0; 377 while (inverse_amplitude_normalized_double < 0.5) { 378 inverse_amplitude_normalized_double *= 2; 379 inverse_amplitude_neg_exponent++; 380 } 381 inverse_amplitude_normalized = FixedPoint<DataType, 0>::FromDouble( 382 inverse_amplitude_normalized_double); 383 384 double amplitude_normalized_double = real_amplitude_as_int32; 385 amplitude_exponent = 0; 386 while (amplitude_normalized_double >= 1.0) { 387 amplitude_normalized_double *= 0.5; 388 amplitude_exponent++; 389 } 390 amplitude_normalized = 391 FixedPoint<DataType, 0>::FromDouble(amplitude_normalized_double); 392 } 393 394 OutputType Eval(InputType input) const { 395 const std::int32_t real_zero_as_int32 = output_stage.real_zero_as_int32; 396 397 typedef FixedPoint<DataType, 3> F3; 398 typedef FixedPoint<DataType, 0> F0; 399 400 OutputType output; 401 402 for (int i = 0; i < OutputType::kRegisterCount; i++) { 403 // fixed-point affine transformation 404 DataType input_centered = 405 Sub(input.reg[i], Dup<DataType>(real_zero_as_int32)); 406 F3 fixedpoint_input = 407 F3::FromRaw(input_centered) * inverse_amplitude_normalized; 408 // left shift 409 fixedpoint_input.raw() = ShiftLeft(fixedpoint_input.raw(), 410 28 - inverse_amplitude_neg_exponent); 411 // fixed-point tanh and multiplication 412 F0 fixedpoint_output = tanh(fixedpoint_input) * amplitude_normalized; 413 // right shift 414 DataType int32_output = 415 Add(Dup<DataType>(real_zero_as_int32), 416 ShiftRight(fixedpoint_output.raw(), 31 - amplitude_exponent)); 417 418 DataType mask_if_below_cutoff_min = 419 MaskIfLessThanOrEqual(input.reg[i], Dup<DataType>(input_cutoff_min)); 420 DataType mask_if_above_cutoff_max = MaskIfGreaterThanOrEqual( 421 input.reg[i], Dup<DataType>(input_cutoff_max)); 422 423 output.reg[i] = SelectUsingMask( 424 mask_if_below_cutoff_min, Dup<DataType>(output_min), 425 SelectUsingMask(mask_if_above_cutoff_max, Dup<DataType>(output_max), 426 int32_output)); 427 } 428 return output; 429 } 430 431 const OutputStage& output_stage; 432 std::int32_t input_cutoff_min, input_cutoff_max; 433 std::int32_t output_min, output_max; 434 FixedPoint<DataType, 0> inverse_amplitude_normalized; 435 int inverse_amplitude_neg_exponent; 436 FixedPoint<DataType, 0> amplitude_normalized; 437 int amplitude_exponent; 438 }; 439 440 // OutputPipelineOutputType is a helper to determine the output data type of a 441 // pipeline, for a 442 // given input data type. It is a recursive template; see the explanation on 443 // OutputPipelineEvalImpl below. 444 template <typename OutputPipelineType, int FirstStage, typename InputType, 445 bool StopRecursion = 446 FirstStage == std::tuple_size<OutputPipelineType>::value> 447 struct OutputPipelineOutputType { 448 typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type 449 FirstStageType; 450 typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType 451 FirstStageOutputType; 452 typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage + 1, 453 FirstStageOutputType>::Type Type; 454 }; 455 456 template <typename OutputPipelineType, int FirstStage, typename InputType> 457 struct OutputPipelineOutputType<OutputPipelineType, FirstStage, InputType, 458 true> { 459 typedef InputType Type; 460 }; 461 462 // OutputPipelineEvalImpl is a helper to implement the evaluation of 463 // the whole pipeline. It is a recursive template to implement compile-time 464 // unrolling of the loop over all pipeline stages. The 'FirstStage' parameter 465 // is how we implement recursion: each specialization implements only 466 // evaluation starting at 'FirstStage'. The StopRecursion parameter is just a 467 // helper to implement the termination of the recursion as a partial 468 // specialization below. 469 template <typename OutputPipelineType, int FirstStage, typename InputType, 470 bool StopRecursion = 471 FirstStage == std::tuple_size<OutputPipelineType>::value> 472 struct OutputPipelineEvalImpl { 473 typedef typename std::tuple_element<FirstStage, OutputPipelineType>::type 474 FirstStageType; 475 typedef typename OutputStageEvalImpl<FirstStageType, InputType>::OutputType 476 FirstStageOutputType; 477 typedef typename OutputPipelineOutputType<OutputPipelineType, FirstStage, 478 InputType>::Type OutputType; 479 480 OutputPipelineEvalImpl(const OutputPipelineType& output_pipeline) 481 : head_impl(std::get<FirstStage>(output_pipeline)), 482 tail_impl(output_pipeline) {} 483 484 OutputType Eval(InputType input, int row, int col) const { 485 // Evaluate the first stage. 486 FirstStageOutputType first_stage_output = head_impl.Eval(input, row, col); 487 // Recurse into the remaining stages. 488 return tail_impl.Eval(first_stage_output, row, col); 489 } 490 491 const OutputStageEvalImpl<FirstStageType, InputType> head_impl; 492 const OutputPipelineEvalImpl<OutputPipelineType, FirstStage + 1, 493 FirstStageOutputType> 494 tail_impl; 495 }; 496 497 // Specialization on 'StopRecursion' for terminating the recursion. 498 template <typename OutputPipelineType, int FirstStage, typename InputType> 499 struct OutputPipelineEvalImpl<OutputPipelineType, FirstStage, InputType, true> { 500 OutputPipelineEvalImpl(const OutputPipelineType&) {} 501 502 InputType Eval(InputType input, int, int) const { 503 // Terminating the recursion. 504 return input; 505 } 506 }; 507 508 template <typename RegisterBlockType, typename DstType> 509 struct StoreFinalOutputImpl { 510 static_assert(std::is_same<RegisterBlockType, void>::value, 511 "This generic impl should never be hit"); 512 }; 513 514 template <typename ScalarType, int Rows, int Cols, typename DstType> 515 struct StoreFinalOutputImpl<RegisterBlock<ScalarType, Rows, Cols>, DstType> { 516 using RegisterBlockType = RegisterBlock<ScalarType, Rows, Cols>; 517 static void Run(const RegisterBlockType& src, DstType* dst, int row, 518 int col) { 519 for (int r = 0; r < Rows; r++) { 520 for (int c = 0; c < Cols; c++) { 521 *dst->data(row + r, col + c) = src.buf.reg[r + c * Rows]; 522 } 523 } 524 } 525 }; 526 527 // StoreFinalOutput takes the final value at the end of the output pipeline and 528 // stores it into the destination matrix. It can be specialized for different 529 // data types; the generic implementation here is typically used only for plain 530 // old scalar (not SIMD) types. 531 template <typename RegisterBlockType, typename DstType> 532 void StoreFinalOutput(RegisterBlockType src, DstType* dst, int row, int col) { 533 StoreFinalOutputImpl<RegisterBlockType, DstType>::Run(src, dst, row, col); 534 } 535 536 template <typename OutputPipelineType, typename InputType> 537 struct OutputPipelineExecutor { 538 OutputPipelineExecutor(const OutputPipelineType& output_pipeline) 539 : output_pipeline_eval_impl_(output_pipeline) {} 540 541 // Execute is the entry point into the output pipeline evaluation 542 // code. It should be the only thing that unpack code calls. It takes the 543 // result 544 // of the unpack stage and stores it into the destination matrix. 545 template <typename DstType> 546 void Execute(InputType input, DstType* dst, int src_global_row, 547 int src_global_col, int dst_row, int dst_col) const { 548 // Statically assert that the output pipeline matches the given destination 549 // matrix's scalar type. 550 typedef typename OutputPipelineOutputType< 551 OutputPipelineType, 0, InputType>::Type::BufferType::ScalarType 552 553 ScalarOutputType; 554 typedef typename DstType::Scalar ScalarDstType; 555 static_assert(std::is_same<ScalarOutputType, ScalarDstType>::value, 556 "mismatched destination scalar type and output pipeline"); 557 558 // Evaluate the output pipeline. 559 auto output = 560 output_pipeline_eval_impl_.Eval(input, src_global_row, src_global_col); 561 // Store the result into the destination matrix. 562 StoreFinalOutput(output, dst, dst_row, dst_col); 563 } 564 565 const OutputPipelineEvalImpl<OutputPipelineType, 0, InputType> 566 output_pipeline_eval_impl_; 567 }; 568 569 } // namespace gemmlowp 570 571 #ifdef GEMMLOWP_NEON 572 #include "output_neon.h" 573 #elif defined(GEMMLOWP_SSE4) 574 #include "output_sse.h" 575 #elif defined(GEMMLOWP_MSA) 576 #include "output_msa.h" 577 #endif 578 579 #endif // GEMMLOWP_INTERNAL_OUTPUT_H_ 580