• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2015 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package com.example.android.rs.blasbenchmark;
18 
19 import android.renderscript.*;
20 import android.util.Log;
21 import java.util.Random;
22 import java.lang.Math;
23 
24 public class SGEMMTest extends TestBase {
25 
26     static {
27         System.loadLibrary("gemmdata");
28     }
29 
getData(byte[] a, byte[] b, byte[] c)30     native void getData(byte[] a, byte[] b, byte[] c);
31 
32     ScriptIntrinsicBLAS mBLAS;
33     private Allocation matA;
34     private Allocation matB;
35     private Allocation matC;
36 
37     private int m;
38     private int n;
39     private int k;
40 
41     private int a_offset;
42     private int b_offset;
43     private int mTestSize;
44     private final float allowedError = 0.000001f;
45 
SGEMMTest(int testSize)46     SGEMMTest(int testSize) {
47         mTestSize = testSize;
48     }
49 
createTest()50     public void createTest() {
51         mBLAS = ScriptIntrinsicBLAS.create(mRS);
52         setTest();
53     }
54 
setTest()55     private void setTest() {
56         switch (mTestSize) {
57             case 1:
58                 setTestSmall();
59                 break;
60             case 2:
61                 setTestMedium();
62                 break;
63             case 3:
64                 setTestLarge();
65                 break;
66             default:
67                 break;
68         }
69     }
70 
71     // Calculate the square of the L2 norm of a matrix.
calcL2Norm(float[] input)72     private float calcL2Norm(float[] input) {
73         float l2Norm = 0.f;
74         for (int i = 0; i < input.length; ++i) {
75             l2Norm += input[i] * input[i];
76         }
77         return l2Norm;
78     }
79 
80     // Test whether the error of each element is samller the allowed error range.
testWithTolerance(float[] out, float[] ref)81     private boolean testWithTolerance(float[] out, float[] ref) {
82         float l2NormOut = calcL2Norm(out);
83         float l2NormRef = calcL2Norm(ref);
84         float tolerance = allowedError * (l2NormOut < l2NormRef ? l2NormOut : l2NormRef);
85         tolerance /= m * n;
86         for (int i = 0; i < out.length; ++i) {
87             float err = out[i] - ref[i];
88             float absErr = err * err;
89             if (absErr > tolerance) {
90                 return false;
91             }
92         }
93         return true;
94     }
95 
96     // Transform byte data into float, given a offset.
byteToFloat(byte[] input, int offset)97     private float[] byteToFloat(byte[] input, int offset) {
98         float[] output = new float[input.length];
99         for (int i = 0; i < input.length; ++i) {
100             output[i] = (float)(input[i] - offset);
101         }
102         return output;
103     }
104 
105     // Calculate the reference result for C = A*B
getGEMMResult(int m, int n, int k, float[] a_float, float[] b_float)106     private float[] getGEMMResult(int m, int n, int k, float[] a_float, float[] b_float) {
107         float[] c_float = new float[m * n];
108         for (int j = 0; j < n; j++) {
109             for (int i = 0; i < m; i++) {
110                 float total = 0.f;
111                 for (int l = 0; l < k; l++) {
112                     int a_index = ((i * k) + l);
113                     int b_index = ((l * n) + j);
114                     float mult = a_float[a_index] * b_float[b_index];
115                     total += mult;
116                 }
117                 int c_index = ((i * n) + j);
118                 c_float[c_index] = total;
119             }
120         }
121         return c_float;
122     }
123 
124     // This test multiplies a couple of small float matrices, and compares the
125     // results with java-calculated expectations. The data here is arbitrary.
setTestSmall()126     public void setTestSmall() {
127         m = 2;
128         n = 4;
129         k = 3;
130         a_offset = 0;
131         b_offset = 12;
132 
133         float[] a_float = byteToFloat(new byte[] {
134                 1, 2, 3,
135                 4, 5, 6,
136             }, a_offset);
137 
138         float[] b_float = byteToFloat(new byte[] {
139                 11, 7, 3,
140                 10, 6, 2,
141                 9, 5, 1,
142                 8, 4, 0,
143             }, b_offset);
144 
145         Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
146         Type a_type = builder.setX(k).setY(m).create();
147         Type b_type = builder.setX(n).setY(k).create();
148         Type c_type = builder.setX(n).setY(m).create();
149 
150         matA = Allocation.createTyped(mRS, a_type);
151         matB = Allocation.createTyped(mRS, b_type);
152         matC = Allocation.createTyped(mRS, c_type);
153 
154         matA.copyFrom(a_float);
155         matB.copyFrom(b_float);
156 
157         //During setup, do a sample run to see if the result is correct.
158         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
159                     1.0f, matA, matB, 0.f, matC);
160         float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
161         float[] c_float_out = new float[m * n];
162         matC.copyTo(c_float_out);
163         if (!testWithTolerance(c_float_ref, c_float_out)) {
164             Log.e(TAG, "Result is not correct!");
165             throw new AssertionError("Result is not correct.");
166         }
167     }
168 
169     // This test multiplies another two medium matrices, and compares the
170     // results with the expected values. The data here is arbitrary.
setTestMedium()171     public void setTestMedium() {
172         m = 7;
173         n = 9;
174         k = 23;
175         a_offset = 13;
176         b_offset = 23;
177 
178         float[] a_float = byteToFloat(new byte[] {
179                 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
180                 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1,
181                 1, 23, 2, 22, 3, 21, 4, 20, 5, 19, 6, 18, 7, 17, 8, 16, 9, 15, 10, 14, 11, 13, 12,
182                 23, 1, 22, 2, 21, 3, 20, 4, 19, 5, 18, 6, 17, 7, 16, 8, 15, 9, 14, 10, 13, 11, 12,
183                 1, 1, 1, 1, 1, 1, 1, 1, 1, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12,
184                 3, 1, 4, 1, 5, 8, 2, 3, 1, 14, 11, 15, 18, 12, 13, 11, 14, 11, 15, 18, 12, 13, 11,
185                 8, 0, 5, 8, 1, 3, 7, 5, 7, 13, 10, 23, 13, 11, 17, 23, 12, 19, 17, 13, 14, 10, 19,
186             }, a_offset);
187 
188         float[] b_float = byteToFloat(new byte[] {
189                 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9, 11, 0, 2, 4, 6, 8, 10, 1, 3, 5, 7, 9,
190                 0, 20, 40, 60, 80, 10, 11, 13, 15, 17, 19, 21, 10, 12, 14, 6, 8, 10, 1, 3, 5, 7, 9,
191                 1, 21, 41, 61, 81, 11, 12, 14, 16, 18, 20, 22, 11, 13, 15, 7, 9, 11, 2, 4, 6, 8, 9,
192                 0, 19, 39, 59, 79, 9, 10, 12, 14, 16, 18, 20, 9, 11, 13, 5, 7, 9, 0, 2, 4, 6, 8,
193                 2, 22, 42, 62, 82, 12, 13, 15, 17, 19, 21, 23, 12, 14, 16, 8, 9, 12, 3, 5, 7, 9, 9,
194                 0, 18, 38, 58, 78, 8, 9, 11, 13, 15, 17, 19, 8, 10, 12, 4, 6, 8, 0, 1, 3, 5, 7,
195                 3, 23, 43, 63, 83, 13, 14, 16, 18, 20, 22, 24, 13, 15, 17, 9, 9, 13, 4, 6, 8, 9, 9,
196                 0, 17, 37, 57, 77, 7, 8, 10, 12, 14, 16, 18, 7, 9, 11, 3, 5, 7, 0, 0, 2, 4, 6,
197                 10, 20, 30, 40, 50, 1, 2, 3, 4, 5, 11, 12, 13, 14, 15, 21, 22, 23, 24, 25, 1, 2, 3,
198             }, b_offset);
199 
200         Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
201         Type a_type = builder.setX(k).setY(m).create();
202         Type b_type = builder.setX(n).setY(k).create();
203         Type c_type = builder.setX(n).setY(m).create();
204 
205         matA = Allocation.createTyped(mRS, a_type);
206         matB = Allocation.createTyped(mRS, b_type);
207         matC = Allocation.createTyped(mRS, c_type);
208 
209         matA.copyFrom(a_float);
210         matB.copyFrom(b_float);
211 
212         //During setup, do a sample run to see if the result is correct.
213         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
214                     1.0f, matA, matB, 0.f, matC);
215         float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
216         float[] c_float_out = new float[m * n];
217         matC.copyTo(c_float_out);
218         if (!testWithTolerance(c_float_ref, c_float_out)) {
219             Log.e(TAG, "Result is not correct!");
220             throw new AssertionError("Result is not correct.");
221         }
222     }
223 
224 
225     // This test takes a large set of real data captured from a convolutional
226     // neural network solving a computer vision problem, and runs it through SGEMM.
setTestLarge()227     public void setTestLarge() {
228 
229         m = 256;
230         n = 192;
231         k = 1152;
232         a_offset = 0;
233         b_offset = 84;
234 
235         int a_count = (m * k);
236         int b_count = (n * k);
237         int c_count = (m * n);
238 
239         byte[] a_byte = new byte[a_count];
240         byte[] b_byte = new byte[b_count];
241         byte[] c_byte = new byte[c_count];
242 
243         getData(a_byte, b_byte, c_byte);
244 
245         float[] a_float = byteToFloat(a_byte, a_offset);
246         float[] b_float = byteToFloat(b_byte, b_offset);
247 
248         Type.Builder builder = new Type.Builder(mRS, Element.F32(mRS));
249         Type a_type = builder.setX(k).setY(m).create();
250         Type b_type = builder.setX(n).setY(k).create();
251         Type c_type = builder.setX(n).setY(m).create();
252 
253         matA = Allocation.createTyped(mRS, a_type);
254         matB = Allocation.createTyped(mRS, b_type);
255         matC = Allocation.createTyped(mRS, c_type);
256 
257         matA.copyFrom(a_float);
258         matB.copyFrom(b_float);
259 
260         //During setup, do a sample run to see if the result is correct.
261         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
262                     1.0f, matA, matB, 0.f, matC);
263         float[] c_float_ref = getGEMMResult(m, n, k, a_float, b_float);
264         float[] c_float_out = new float[c_count];
265         matC.copyTo(c_float_out);
266         if (!testWithTolerance(c_float_ref, c_float_out)) {
267             Log.e(TAG, "Result is not correct!");
268             throw new AssertionError("Result is not correct.");
269         }
270     }
271 
runTest()272     public void runTest() {
273         mBLAS.SGEMM(ScriptIntrinsicBLAS.NO_TRANSPOSE, ScriptIntrinsicBLAS.NO_TRANSPOSE,
274                     1.0f, matA, matB, 0.f, matC);
275     }
276 
getTestInfo()277     public String getTestInfo() {
278         return "SGEMM Test: m=" + m + ", n=" + n + ", k=" + k;
279     }
280 }
281