• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2016 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 import com.android.jack.annotations.CalledByInvokeCustom;
18 import com.android.jack.annotations.Constant;
19 import com.android.jack.annotations.LinkerMethodHandle;
20 import com.android.jack.annotations.MethodHandleKind;
21 
22 import java.lang.invoke.CallSite;
23 import java.lang.invoke.ConstantCallSite;
24 import java.lang.invoke.MethodHandle;
25 import java.lang.invoke.MethodHandles;
26 import java.lang.invoke.MethodType;
27 
28 import java.lang.Thread;
29 import java.lang.ThreadLocal;
30 import java.util.concurrent.atomic.AtomicInteger;
31 import java.util.concurrent.CyclicBarrier;
32 
33 public class TestInvokeCustomWithConcurrentThreads extends Thread {
34   private static final int NUMBER_OF_THREADS = 16;
35 
36   private static final AtomicInteger nextIndex = new AtomicInteger(0);
37 
38   private static final ThreadLocal<Integer> threadIndex =
39       new ThreadLocal<Integer>() {
40         @Override
41         protected Integer initialValue() {
42           return nextIndex.getAndIncrement();
43         }
44       };
45 
46   // Array of call sites instantiated, one per thread
47   private static final CallSite[] instantiated = new CallSite[NUMBER_OF_THREADS];
48 
49   // Array of counters for how many times each instantiated call site is called
50   private static final AtomicInteger[] called = new AtomicInteger[NUMBER_OF_THREADS];
51 
52   // Array of call site indicies of which call site a thread invoked
53   private static final AtomicInteger[] targetted = new AtomicInteger[NUMBER_OF_THREADS];
54 
55   // Synchronization barrier all threads will wait on in the bootstrap method.
56   private static final CyclicBarrier barrier = new CyclicBarrier(NUMBER_OF_THREADS);
57 
TestInvokeCustomWithConcurrentThreads()58   private TestInvokeCustomWithConcurrentThreads() {}
59 
getThreadIndex()60   private static int getThreadIndex() {
61     return threadIndex.get().intValue();
62   }
63 
notUsed(int x)64   public static int notUsed(int x) {
65     return x;
66   }
67 
68   @Override
run()69   public void run() {
70     int x = setCalled(-1 /* argument dropped */);
71     notUsed(x);
72   }
73 
74   @CalledByInvokeCustom(
75       invokeMethodHandle = @LinkerMethodHandle(kind = MethodHandleKind.INVOKE_STATIC,
76           enclosingType = TestInvokeCustomWithConcurrentThreads.class,
77           name = "linkerMethod",
78           argumentTypes = {MethodHandles.Lookup.class, String.class, MethodType.class}),
79       name = "setCalled",
80       returnType = int.class,
81       argumentTypes = {int.class})
setCalled(int index)82   private static int setCalled(int index) {
83     called[index].getAndIncrement();
84     targetted[getThreadIndex()].set(index);
85     return 0;
86   }
87 
88   @SuppressWarnings("unused")
linkerMethod(MethodHandles.Lookup caller, String name, MethodType methodType)89   private static CallSite linkerMethod(MethodHandles.Lookup caller,
90                                        String name,
91                                        MethodType methodType) throws Throwable {
92     int threadIndex = getThreadIndex();
93     MethodHandle mh =
94         caller.findStatic(TestInvokeCustomWithConcurrentThreads.class, name, methodType);
95     assertEquals(methodType, mh.type());
96     assertEquals(mh.type().parameterCount(), 1);
97     mh = MethodHandles.insertArguments(mh, 0, threadIndex);
98     mh = MethodHandles.dropArguments(mh, 0, int.class);
99     assertEquals(mh.type().parameterCount(), 1);
100     assertEquals(methodType, mh.type());
101 
102     // Wait for all threads to be in this method.
103     // Multiple call sites should be created, but only one
104     // invoked.
105     barrier.await();
106 
107     instantiated[getThreadIndex()] = new ConstantCallSite(mh);
108     return instantiated[getThreadIndex()];
109   }
110 
test()111   public static void test() throws Throwable {
112     // Initialize counters for which call site gets invoked
113     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
114       called[i] = new AtomicInteger(0);
115       targetted[i] = new AtomicInteger(0);
116     }
117 
118     // Run threads that each invoke-custom the call site
119     Thread [] threads = new Thread[NUMBER_OF_THREADS];
120     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
121       threads[i] = new TestInvokeCustomWithConcurrentThreads();
122       threads[i].start();
123     }
124 
125     // Wait for all threads to complete
126     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
127       threads[i].join();
128     }
129 
130     // Check one call site instance won
131     int winners = 0;
132     int votes = 0;
133     for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
134       assertNotEquals(instantiated[i], null);
135       if (called[i].get() != 0) {
136         winners++;
137         votes += called[i].get();
138       }
139     }
140 
141     System.out.println("Winners " + winners + " Votes " + votes);
142 
143     // We assert this below but output details when there's an error as
144     // it's non-deterministic.
145     if (winners != 1) {
146       System.out.println("Threads did not the same call-sites:");
147       for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
148         System.out.format(" Thread % 2d invoked call site instance #%02d\n",
149                           i, targetted[i].get());
150       }
151     }
152 
153     // We assert this below but output details when there's an error as
154     // it's non-deterministic.
155     if (votes != NUMBER_OF_THREADS) {
156       System.out.println("Call-sites invocations :");
157       for (int i = 0; i < NUMBER_OF_THREADS; ++i) {
158         System.out.format(" Call site instance #%02d was invoked % 2d times\n",
159                           i, called[i].get());
160       }
161     }
162 
163     assertEquals(winners, 1);
164     assertEquals(votes, NUMBER_OF_THREADS);
165   }
166 
assertTrue(boolean value)167   public static void assertTrue(boolean value) {
168     if (!value) {
169       throw new AssertionError("assertTrue value: " + value);
170     }
171   }
172 
assertEquals(byte b1, byte b2)173   public static void assertEquals(byte b1, byte b2) {
174     if (b1 == b2) { return; }
175     throw new AssertionError("assertEquals b1: " + b1 + ", b2: " + b2);
176   }
177 
assertEquals(char c1, char c2)178   public static void assertEquals(char c1, char c2) {
179     if (c1 == c2) { return; }
180     throw new AssertionError("assertEquals c1: " + c1 + ", c2: " + c2);
181   }
182 
assertEquals(short s1, short s2)183   public static void assertEquals(short s1, short s2) {
184     if (s1 == s2) { return; }
185     throw new AssertionError("assertEquals s1: " + s1 + ", s2: " + s2);
186   }
187 
assertEquals(int i1, int i2)188   public static void assertEquals(int i1, int i2) {
189     if (i1 == i2) { return; }
190     throw new AssertionError("assertEquals i1: " + i1 + ", i2: " + i2);
191   }
192 
assertEquals(long l1, long l2)193   public static void assertEquals(long l1, long l2) {
194     if (l1 == l2) { return; }
195     throw new AssertionError("assertEquals l1: " + l1 + ", l2: " + l2);
196   }
197 
assertEquals(float f1, float f2)198   public static void assertEquals(float f1, float f2) {
199     if (f1 == f2) { return; }
200     throw new AssertionError("assertEquals f1: " + f1 + ", f2: " + f2);
201   }
202 
assertEquals(double d1, double d2)203   public static void assertEquals(double d1, double d2) {
204     if (d1 == d2) { return; }
205     throw new AssertionError("assertEquals d1: " + d1 + ", d2: " + d2);
206   }
207 
assertEquals(Object o, Object p)208   public static void assertEquals(Object o, Object p) {
209     if (o == p) { return; }
210     if (o != null && p != null && o.equals(p)) { return; }
211     throw new AssertionError("assertEquals: o1: " + o + ", o2: " + p);
212   }
213 
assertNotEquals(Object o, Object p)214   public static void assertNotEquals(Object o, Object p) {
215     if (o != p) { return; }
216     if (o != null && p != null && !o.equals(p)) { return; }
217     throw new AssertionError("assertNotEquals: o1: " + o + ", o2: " + p);
218   }
219 
assertEquals(String s1, String s2)220   public static void assertEquals(String s1, String s2) {
221     if (s1 == s2) {
222       return;
223     }
224 
225     if (s1 != null && s2 != null && s1.equals(s2)) {
226       return;
227     }
228 
229     throw new AssertionError("assertEquals s1: " + s1 + ", s2: " + s2);
230   }
231 }
232