• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_BFLOAT16_H
25 #define ARM_COMPUTE_BFLOAT16_H
26 
#include <cstdint>
#include <cstring>
28 
29 namespace arm_compute
30 {
// NOTE: these helpers are plain (external-linkage) inline functions rather than
// members of an anonymous namespace. An anonymous namespace in a header gives
// every translation unit its own private copy. The inline members of bfloat16
// below (which have external linkage) would then refer to a different entity
// in each TU, which violates the One Definition Rule.

/** Convert float to bfloat16
 *
 * The software path rounds to nearest, ties to even. NaN inputs are not
 * rounded; they are truncated to their top 16 bits with the quiet bit
 * (0x0040) forced on. This keeps them NaN. Rounding them could otherwise
 * produce an infinity (payload only in the low 16 bits) or wrap around to a
 * signed zero (e.g. 0x7FFFFFFF -> 0x8000).
 *
 * @param[in] v Floating-point value to convert to bfloat16
 *
 * @return Raw bfloat16 bit pattern of the converted value
 */
inline uint16_t float_to_bf16(const float v)
{
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
    uint16_t res;

    // The asm reads v and writes res through memory, hence the "memory" clobber.
    __asm __volatile(
        "ldr    s0, [%[fromptr]]\n"
        ".inst    0x1e634000\n" // BFCVT h0, s0
        "str    h0, [%[toptr]]\n"
        :
        : [fromptr] "r"(&v), [toptr] "r"(&res)
        : "v0", "memory");
    return res;
#else  /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
    // Copy the bits out instead of dereferencing a reinterpret_cast'ed pointer:
    // reading a float object through a uint32_t lvalue breaks strict aliasing.
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));

    // NaN (all-ones exponent, non-zero mantissa): truncate and set the quiet bit.
    if((bits & 0x7FFFFFFFu) > 0x7F800000u)
    {
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }

    uint16_t       res   = static_cast<uint16_t>(bits >> 16);
    const uint16_t error = static_cast<uint16_t>(bits & 0x0000FFFFu); // discarded low half
    const uint16_t bf_l  = res & 0x0001;                              // LSB kept, for ties-to-even
    if((error > 0x8000) || ((error == 0x8000) && (bf_l != 0)))
    {
        // For finite inputs this can at most carry into the exponent, which
        // correctly yields +/-infinity on overflow (0x7F7F -> 0x7F80).
        res += 1;
    }
    return res;
#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
}

/** Convert bfloat16 to float
 *
 * The bfloat16 bits become the upper half of a single-precision value whose
 * lower 16 bits are zero. This conversion is exact.
 *
 * @param[in] v Bfloat16 value (raw bit pattern) to convert to float
 *
 * @return Converted value
 */
inline float bf16_to_float(const uint16_t &v)
{
    // Widen before shifting so the shift is done on an unsigned 32-bit value
    // rather than a promoted (signed) int.
    const uint32_t lv = static_cast<uint32_t>(v) << 16;
    float          fp;
    std::memcpy(&fp, &lv, sizeof(fp)); // well-defined type punning
    return fp;
}
78 
79 /** Brain floating point representation class */
80 class bfloat16 final
81 {
82 public:
83     /** Default Constructor */
bfloat16()84     bfloat16()
85         : value(0)
86     {
87     }
88     /** Constructor
89      *
90      * @param[in] v Floating-point value
91      */
bfloat16(float v)92     explicit bfloat16(float v)
93         : value(float_to_bf16(v))
94     {
95     }
96     /** Assignment operator
97      *
98      * @param[in] v Floating point value to assign
99      *
100      * @return The updated object
101      */
102     bfloat16 &operator=(float v)
103     {
104         value = float_to_bf16(v);
105         return *this;
106     }
107     /** Floating point conversion operator
108      *
109      * @return Floating point representation of the value
110      */
111     operator float() const
112     {
113         return bf16_to_float(value);
114     }
115     /** Lowest representative value
116      *
117      * @return Returns the lowest finite value representable by bfloat16
118      */
lowest()119     static bfloat16 lowest()
120     {
121         bfloat16 val;
122         val.value = 0xFF7F;
123         return val;
124     }
125     /** Largest representative value
126      *
127      * @return Returns the largest finite value representable by bfloat16
128      */
max()129     static bfloat16 max()
130     {
131         bfloat16 val;
132         val.value = 0x7F7F;
133         return val;
134     }
135 
136 private:
137     uint16_t value;
138 };
139 } // namespace arm_compute
140 #endif /* ARM_COMPUTE_BFLOAT16_H */