1 /* 2 * Copyright (c) 2025 Huawei Device Co., Ltd. 3 * Licensed under the Apache License, Version 2.0 (the "License"); 4 * you may not use this file except in compliance with the License. 5 * You may obtain a copy of the License at 6 * 7 * http://www.apache.org/licenses/LICENSE-2.0 8 * 9 * Unless required by applicable law or agreed to in writing, software 10 * distributed under the License is distributed on an "AS IS" BASIS, 11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 * See the License for the specific language governing permissions and 13 * limitations under the License. 14 */ 15 package org.koalaui.interop; 16 17 import java.nio.ByteBuffer; 18 import java.nio.ByteOrder; 19 import java.nio.IntBuffer; 20 import java.util.function.Function; 21 22 import org.koalaui.arkoala.NativeModule; 23 24 public class CallbackTests { 25 // Todo: where tests will be located? 26 27 public class TestUtils { 28 // Todo: where test utils will be located? assertEquals(String name, T expected, T actual)29 public static <T> void assertEquals(String name, T expected, T actual) { 30 if (!expected.equals(actual)) { 31 System.out.printf("TEST %s FAIL:\n EXPECTED \"%s\"\n ACTUAL \"%s\"\n", name, expected.toString(), actual.toString()); 32 } else { 33 System.out.printf("TEST %s PASS\n", name); 34 } 35 } 36 assertThrows(String name, Function<Void, T> fn)37 public static <T> void assertThrows(String name, Function<Void, T> fn) { 38 boolean caught = false; 39 try { 40 fn.apply(null); 41 } catch (Throwable e) { 42 caught = true; 43 } 44 if (!caught) { 45 System.out.printf("TEST %s FAIL:\n No exception thrown\n", name); 46 } else { 47 System.out.printf("TEST %s PASS\n", name); 48 } 49 } 50 } 51 checkCallback()52 public static void checkCallback() { 53 Integer id1 = CallbackRegistry.wrap(new CallbackType() { 54 @Override 55 public int apply(byte[] args, int length) { 56 return 2024; 57 } 58 }); 59 Integer id2 = CallbackRegistry.wrap(new CallbackType() { 60 @Override 61 public int apply(byte[] args, int length) { 62 return 2025; 63 } 64 }); 65 66 TestUtils.assertEquals("Call callback 1", 2024, CallbackRegistry.call(id1, new byte[] {}, 0)); 67 TestUtils.assertEquals("Call callback 2", 2025, CallbackRegistry.call(id2, new byte[] {}, 0)); 68 TestUtils.assertThrows("Call disposed callback 1", new Function<Void, Integer>() { 69 @Override 70 public Integer apply(Void v) { 71 return CallbackRegistry.call(id1, new byte[] { }, 0); 72 } 73 }); 74 TestUtils.assertThrows("Call callback 0", new Function<Void, Integer>() { 75 @Override 76 public Integer apply(Void v) { 77 return CallbackRegistry.call(0, new byte[] { 2, 4, 6, 8 }, 4); 78 } 79 }); 80 } 81 checkNativeCallback()82 public static void checkNativeCallback() { 83 Integer id1 = CallbackRegistry.wrap(new CallbackType() { 84 @Override 85 public int apply(byte[] args, int length) { 86 return 123456; 87 } 88 }); 89 TestUtils.assertEquals("NativeCallback without args", 123456, NativeModule._TestCallIntNoArgs(id1)); 90 TestUtils.assertThrows("NativeCallback without args called again", new Function<Void, Integer>() { 91 @Override 92 public Integer apply(Void v) { 93 return CallbackRegistry.call(id1, new byte[] { }, 0); 94 } 95 }); 96 TestUtils.assertThrows("NativeCallback without args called again from native", new Function<Void, Integer>() { 97 @Override 98 public Integer apply(Void v) { 99 return NativeModule._TestCallIntNoArgs(id1); 100 } 101 }); 102 103 Integer id2 = CallbackRegistry.wrap(new CallbackType() { 104 @Override 105 public int apply(byte[] args, int length) { 106 ByteBuffer buffer = ByteBuffer.wrap(args); 107 buffer.order(ByteOrder.LITTLE_ENDIAN); 108 IntBuffer intBuffer = buffer.asIntBuffer(); 109 int sum = 0; 110 for (int i = 0; i < length / 4; i++) { 111 sum += intBuffer.get(i); 112 } 113 return sum; 114 } 115 }); 116 int[] arr2 = new int[] { 100, 200, 300, -1000 }; 117 TestUtils.assertEquals("NativeCallback Int32Array sum", -400, NativeModule._TestCallIntIntArraySum(id2, arr2, arr2.length)); 118 119 Integer id3 = CallbackRegistry.wrap(new CallbackType() { 120 @Override 121 public int apply(byte[] args, int length) { 122 ByteBuffer buffer = ByteBuffer.wrap(args); 123 buffer.order(ByteOrder.LITTLE_ENDIAN); 124 IntBuffer intBuffer = buffer.asIntBuffer(); 125 for (int i = 1; i < length / 4; i++) { 126 intBuffer.put(i, intBuffer.get(i) + intBuffer.get(i - 1)); 127 } 128 return 0; 129 } 130 }); 131 int[] arr3 = new int[] { 100, 200, 300, -1000 }; 132 NativeModule._TestCallVoidIntArrayPrefixSum(id3, arr3, arr3.length); 133 TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [0]", 100, arr3[0]); 134 TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [1]", 300, arr3[1]); 135 TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [2]", 600, arr3[2]); 136 TestUtils.assertEquals("NativeCallback Int32Array PrefixSum [3]", -400, arr3[3]); 137 138 long start = System.currentTimeMillis(); 139 Integer id4 = CallbackRegistry.wrap(new CallbackType() { 140 @Override 141 public int apply(byte[] args, int length) { 142 ByteBuffer buffer = ByteBuffer.wrap(args); 143 buffer.order(ByteOrder.LITTLE_ENDIAN); 144 IntBuffer intBuffer = buffer.asIntBuffer(); 145 intBuffer.put(1, intBuffer.get(1) + 1); 146 if (intBuffer.get(0) + intBuffer.get(1) < intBuffer.get(2)) { 147 return NativeModule._TestCallIntRecursiveCallback(id3 + 1, args, args.length); 148 } 149 return 1; 150 } 151 }, false); 152 TestUtils.assertEquals("NativeCallback prepare recursive callback test", id4, id3 + 1); 153 int depth = 500; 154 int count = 100; 155 for (int i = 0; i < count; i++) { 156 int length = 12; 157 byte[] args = new byte[length]; 158 IntBuffer args32 = ByteBuffer.wrap(args).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer(); 159 args32.put(2, depth); 160 NativeModule._TestCallIntRecursiveCallback(id4, args, args.length); 161 if (i == 0) { 162 TestUtils.assertEquals("NativeCallback Recursive [0]", (depth + 1) / 2, args32.get(0)); 163 TestUtils.assertEquals("NativeCallback Recursive [1]", depth / 2, args32.get(1)); 164 } 165 } 166 long passed = System.currentTimeMillis() - start; 167 System.out.println("recursive native callback: " + String.valueOf(passed) + "ms for " + depth * count + " callbacks, " + Math.round((double)passed / (depth * count) * 1000000) + "ms per 1M callbacks"); 168 169 Integer id5 = CallbackRegistry.wrap(new CallbackType() { 170 @Override 171 public int apply(byte[] args, int length) { 172 int sum = 0; 173 for (int i = 0; i < length; i++) { 174 sum += args[i]; 175 } 176 return sum; 177 } 178 }, false); 179 NativeModule._TestCallIntMemory(id5, 1000); 180 } 181 } 182