• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (c) 2025 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  * http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 package org.koalaui.interop;
16 
17 import java.nio.ByteBuffer;
18 import java.nio.ByteOrder;
19 import java.nio.IntBuffer;
20 import java.util.function.Function;
21 
22 import org.koalaui.arkoala.NativeModule;
23 
24 public class CallbackTests {
25     // Todo: where tests will be located?
26 
27     public class TestUtils {
28         // Todo: where test utils will be located?
assertEquals(String name, T expected, T actual)29         public static <T> void assertEquals(String name, T expected, T actual) {
30             if (!expected.equals(actual)) {
31                 System.out.printf("TEST %s FAIL:\n  EXPECTED \"%s\"\n  ACTUAL   \"%s\"\n", name, expected.toString(), actual.toString());
32             } else {
33                 System.out.printf("TEST %s PASS\n", name);
34             }
35         }
36 
assertThrows(String name, Function<Void, T> fn)37         public static <T> void assertThrows(String name, Function<Void, T> fn) {
38             boolean caught = false;
39             try {
40                 fn.apply(null);
41             } catch (Throwable e) {
42                 caught = true;
43             }
44             if (!caught) {
45                 System.out.printf("TEST %s FAIL:\n  No exception thrown\n", name);
46             } else {
47                 System.out.printf("TEST %s PASS\n", name);
48             }
49         }
50     }
51 
checkCallback()52     public static void checkCallback() {
53         Integer id1 = CallbackRegistry.wrap(new CallbackType() {
54             @Override
55             public int apply(byte[] args, int length) {
56                 return 2024;
57             }
58         });
59         Integer id2 = CallbackRegistry.wrap(new CallbackType() {
60             @Override
61             public int apply(byte[] args, int length) {
62                 return 2025;
63             }
64         });
65 
66         TestUtils.assertEquals("Call callback 1", 2024, CallbackRegistry.call(id1, new byte[] {}, 0));
67         TestUtils.assertEquals("Call callback 2", 2025, CallbackRegistry.call(id2, new byte[] {}, 0));
68         TestUtils.assertThrows("Call disposed callback 1", new Function<Void, Integer>() {
69             @Override
70             public Integer apply(Void v) {
71                 return CallbackRegistry.call(id1, new byte[] { }, 0);
72             }
73         });
74         TestUtils.assertThrows("Call callback 0", new Function<Void, Integer>() {
75             @Override
76             public Integer apply(Void v) {
77                 return CallbackRegistry.call(0, new byte[] { 2, 4, 6, 8 }, 4);
78             }
79         });
80     }
81 
checkNativeCallback()82     public static void checkNativeCallback() {
83         Integer id1 = CallbackRegistry.wrap(new CallbackType() {
84             @Override
85             public int apply(byte[] args, int length) {
86                 return 123456;
87             }
88         });
89         TestUtils.assertEquals("NativeCallback without args", 123456, NativeModule._TestCallIntNoArgs(id1));
90         TestUtils.assertThrows("NativeCallback without args called again", new Function<Void, Integer>() {
91             @Override
92             public Integer apply(Void v) {
93                 return CallbackRegistry.call(id1, new byte[] { }, 0);
94             }
95         });
96         TestUtils.assertThrows("NativeCallback without args called again from native", new Function<Void, Integer>() {
97             @Override
98             public Integer apply(Void v) {
99                 return NativeModule._TestCallIntNoArgs(id1);
100             }
101         });
102 
103         Integer id2 = CallbackRegistry.wrap(new CallbackType() {
104             @Override
105             public int apply(byte[] args, int length) {
106                 ByteBuffer buffer = ByteBuffer.wrap(args);
107                 buffer.order(ByteOrder.LITTLE_ENDIAN);
108                 IntBuffer intBuffer = buffer.asIntBuffer();
109                 int sum = 0;
110                 for (int i = 0; i < length / 4; i++) {
111                     sum += intBuffer.get(i);
112                 }
113                 return sum;
114             }
115         });
116         int[] arr2 = new int[] { 100, 200, 300, -1000 };
117         TestUtils.assertEquals("NativeCallback Int32Array sum", -400, NativeModule._TestCallIntIntArraySum(id2, arr2, arr2.length));
118 
119         Integer id3 = CallbackRegistry.wrap(new CallbackType() {
120             @Override
121             public int apply(byte[] args, int length) {
122                 ByteBuffer buffer = ByteBuffer.wrap(args);
123                 buffer.order(ByteOrder.LITTLE_ENDIAN);
124                 IntBuffer intBuffer = buffer.asIntBuffer();
125                 for (int i = 1; i < length / 4; i++) {
126                     intBuffer.put(i, intBuffer.get(i) + intBuffer.get(i - 1));
127                 }
128                 return 0;
129             }
130         });
131         int[] arr3 = new int[] { 100, 200, 300, -1000 };
132         NativeModule._TestCallVoidIntArrayPrefixSum(id3, arr3, arr3.length);
133         TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [0]", 100, arr3[0]);
134         TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [1]", 300, arr3[1]);
135         TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [2]", 600, arr3[2]);
136         TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [3]", -400, arr3[3]);
137 
138         long start = System.currentTimeMillis();
139         Integer id4 = CallbackRegistry.wrap(new CallbackType() {
140             @Override
141             public int apply(byte[] args, int length) {
142                 ByteBuffer buffer = ByteBuffer.wrap(args);
143                 buffer.order(ByteOrder.LITTLE_ENDIAN);
144                 IntBuffer intBuffer = buffer.asIntBuffer();
145                 intBuffer.put(1, intBuffer.get(1) + 1);
146                 if (intBuffer.get(0) + intBuffer.get(1) < intBuffer.get(2)) {
147                     return NativeModule._TestCallIntRecursiveCallback(id3 + 1, args, args.length);
148                 }
149                 return 1;
150             }
151         }, false);
152         TestUtils.assertEquals("NativeCallback prepare recursive callback test", id4, id3 + 1);
153         int depth = 500;
154         int count = 100;
155         for (int i = 0; i < count; i++) {
156             int length = 12;
157             byte[] args = new byte[length];
158             IntBuffer args32 = ByteBuffer.wrap(args).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
159             args32.put(2, depth);
160             NativeModule._TestCallIntRecursiveCallback(id4, args, args.length);
161             if (i == 0) {
162                 TestUtils.assertEquals("NativeCallback Recursive [0]", (depth + 1) / 2, args32.get(0));
163                 TestUtils.assertEquals("NativeCallback Recursive [1]", depth / 2, args32.get(1));
164             }
165         }
166         long passed = System.currentTimeMillis() - start;
167         System.out.println("recursive native callback: " + String.valueOf(passed) + "ms for " + depth * count + " callbacks, " + Math.round((double)passed / (depth * count) * 1000000) + "ms per 1M callbacks");
168 
169         Integer id5 = CallbackRegistry.wrap(new CallbackType() {
170             @Override
171             public int apply(byte[] args, int length) {
172                 int sum = 0;
173                 for (int i = 0; i < length; i++) {
174                     sum += args[i];
175                 }
176                 return sum;
177             }
178         }, false);
179         NativeModule._TestCallIntMemory(id5, 1000);
180     }
181 }
182