// Copyright 2015 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fixedpoint_neon.h: optimized NEON specializations of the templates
// in fixedpoint.h.

#ifndef GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_
#define GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_

#include "fixedpoint.h"

#include <arm_neon.h>

namespace gemmlowp {

template <>
inline int32x4_t BitAnd(int32x4_t a, int32x4_t b) {
  return vandq_s32(a, b);
}

template <>
inline int32x4_t BitOr(int32x4_t a, int32x4_t b) {
  return vorrq_s32(a, b);
}

template <>
inline int32x4_t BitXor(int32x4_t a, int32x4_t b) {
  return veorq_s32(a, b);
}

template <>
inline int32x4_t BitNot(int32x4_t a) {
  return veorq_s32(a, vdupq_n_s32(-1));
}

template <>
inline int32x4_t Add(int32x4_t a, int32x4_t b) {
  return vaddq_s32(a, b);
}

template <>
inline int32x4_t Sub(int32x4_t a, int32x4_t b) {
  return vsubq_s32(a, b);
}

template <>
inline int32x4_t Neg(int32x4_t a) {
  return vnegq_s32(a);
}

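// Both shifts are expressed with vshlq_s32, which shifts each lane left for
// a positive count and right (arithmetically, since the lanes are signed) for
// a negative count; NEON has no variable-count right-shift intrinsic, so
// ShiftRight passes -offset.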
template <>
inline int32x4_t ShiftLeft(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(offset));
}

template <>
inline int32x4_t ShiftRight(int32x4_t a, int offset) {
  return vshlq_s32(a, vdupq_n_s32(-offset));
}

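// The masks used below are all-zeros or all-ones in each lane, as produced by
// the MaskIf* functions, so vbslq_s32 (bitwise select) picks whole lanes from
// then_val where the mask is set and from else_val elsewhere.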
template <>
inline int32x4_t SelectUsingMask(int32x4_t if_mask, int32x4_t then_val,
                                 int32x4_t else_val) {
  return vbslq_s32(vreinterpretq_u32_s32(if_mask), then_val, else_val);
}

template <>
inline int32x4_t MaskIfEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vceqq_s32(a, b));
}

template <>
inline int32x4_t MaskIfNotEqual(int32x4_t a, int32x4_t b) {
  return BitNot(MaskIfEqual(a, b));
}

template <>
inline int32x4_t MaskIfZero(int32x4_t a) {
  return MaskIfEqual(a, vdupq_n_s32(0));
}

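// vtstq_s32(a, a) sets a lane to all-ones when (a & a) != 0, i.e. when the
// lane is nonzero, which is exactly the mask wanted here.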
template <>
inline int32x4_t MaskIfNonZero(int32x4_t a) {
  return vreinterpretq_s32_u32(vtstq_s32(a, a));
}

template <>
inline int32x4_t MaskIfGreaterThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgtq_s32(a, b));
}

template <>
inline int32x4_t MaskIfGreaterThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcgeq_s32(a, b));
}

template <>
inline int32x4_t MaskIfLessThan(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcltq_s32(a, b));
}

template <>
inline int32x4_t MaskIfLessThanOrEqual(int32x4_t a, int32x4_t b) {
  return vreinterpretq_s32_u32(vcleq_s32(a, b));
}

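// All and Any reduce a lane-wise mask to a single boolean. vextq_s32(a, a, n)
// rotates the vector by n lanes, so combining a with its rotation by 1 and
// then by 2 folds all four lanes into lane 0, which is then extracted.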
template <>
inline bool All(int32x4_t a) {
  a = vandq_s32(a, vextq_s32(a, a, 1));
  a = vandq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

template <>
inline bool Any(int32x4_t a) {
  a = vorrq_s32(a, vextq_s32(a, a, 1));
  a = vorrq_s32(a, vextq_s32(a, a, 2));
  return vgetq_lane_s32(a, 0);
}

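// vrhaddq_s32 computes (a + b + 1) >> 1 in each lane using an internally
// widened intermediate, matching the generic RoundingHalfSum without risking
// overflow.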
template <>
inline int32x4_t RoundingHalfSum(int32x4_t a, int32x4_t b) {
  return vrhaddq_s32(a, b);
}

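// vqrdmulhq_s32 is a direct match for SaturatingRoundingDoublingHighMul: it
// returns the high 32 bits of 2*a*b with rounding, saturating the single
// overflow case INT32_MIN * INT32_MIN to INT32_MAX.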
template <>
inline int32x4_t SaturatingRoundingDoublingHighMul(int32x4_t a, int32x4_t b) {
  return vqrdmulhq_s32(a, b);
}

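// Compile-time multiplication by a power of two: positive exponents map to a
// saturating left shift (vqshlq_n_s32), negative exponents to a rounding
// arithmetic right shift (vrshrq_n_s32).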
template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, 1> {
  static int32x4_t eval(int32x4_t x) { return vqshlq_n_s32(x, Exponent); }
};

template <int Exponent>
struct ImplSaturatingRoundingMultiplyByPOT<Exponent, int32x4_t, -1> {
  static int32x4_t eval(int32x4_t x) { return vrshrq_n_s32(x, -Exponent); }
};

template <>
struct FixedPointRawTypeTraits<int32x4_t> {
  typedef int32_t ScalarRawType;
  static const int kLanes = 4;
};

template <>
inline int32x4_t Dup<int32x4_t>(int32_t x) {
  return vdupq_n_s32(x);
}

}  // end namespace gemmlowp

#endif  // GEMMLOWP_INTERNAL_FIXEDPOINT_NEON_H_