• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016 The gRPC Authors
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package io.grpc.util;
18 
19 import static com.google.common.truth.Truth.assertThat;
20 import static io.grpc.ConnectivityState.CONNECTING;
21 import static io.grpc.ConnectivityState.IDLE;
22 import static io.grpc.ConnectivityState.READY;
23 import static io.grpc.ConnectivityState.SHUTDOWN;
24 import static io.grpc.ConnectivityState.TRANSIENT_FAILURE;
25 import static io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer.STATE_INFO;
26 import static org.junit.Assert.assertEquals;
27 import static org.junit.Assert.assertFalse;
28 import static org.junit.Assert.assertNotNull;
29 import static org.junit.Assert.assertNotSame;
30 import static org.junit.Assert.assertNull;
31 import static org.junit.Assert.assertSame;
32 import static org.junit.Assert.assertTrue;
33 import static org.mockito.Matchers.any;
34 import static org.mockito.Matchers.eq;
35 import static org.mockito.Matchers.isA;
36 import static org.mockito.Mockito.atLeast;
37 import static org.mockito.Mockito.doAnswer;
38 import static org.mockito.Mockito.doReturn;
39 import static org.mockito.Mockito.inOrder;
40 import static org.mockito.Mockito.mock;
41 import static org.mockito.Mockito.never;
42 import static org.mockito.Mockito.times;
43 import static org.mockito.Mockito.verify;
44 import static org.mockito.Mockito.verifyNoMoreInteractions;
45 import static org.mockito.Mockito.when;
46 
47 import com.google.common.collect.Lists;
48 import com.google.common.collect.Maps;
49 import io.grpc.Attributes;
50 import io.grpc.ConnectivityState;
51 import io.grpc.ConnectivityStateInfo;
52 import io.grpc.EquivalentAddressGroup;
53 import io.grpc.LoadBalancer;
54 import io.grpc.LoadBalancer.Helper;
55 import io.grpc.LoadBalancer.PickResult;
56 import io.grpc.LoadBalancer.PickSubchannelArgs;
57 import io.grpc.LoadBalancer.Subchannel;
58 import io.grpc.LoadBalancer.SubchannelPicker;
59 import io.grpc.Metadata;
60 import io.grpc.Metadata.Key;
61 import io.grpc.Status;
62 import io.grpc.internal.GrpcAttributes;
63 import io.grpc.util.RoundRobinLoadBalancerFactory.EmptyPicker;
64 import io.grpc.util.RoundRobinLoadBalancerFactory.ReadyPicker;
65 import io.grpc.util.RoundRobinLoadBalancerFactory.Ref;
66 import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer;
67 import io.grpc.util.RoundRobinLoadBalancerFactory.RoundRobinLoadBalancer.StickinessState;
68 import java.net.SocketAddress;
69 import java.util.ArrayList;
70 import java.util.Arrays;
71 import java.util.Collections;
72 import java.util.HashMap;
73 import java.util.Iterator;
74 import java.util.List;
75 import java.util.Map;
76 import org.junit.After;
77 import org.junit.Before;
78 import org.junit.Test;
79 import org.junit.runner.RunWith;
80 import org.junit.runners.JUnit4;
81 import org.mockito.ArgumentCaptor;
82 import org.mockito.Captor;
83 import org.mockito.InOrder;
84 import org.mockito.Mock;
85 import org.mockito.MockitoAnnotations;
86 import org.mockito.invocation.InvocationOnMock;
87 import org.mockito.stubbing.Answer;
88 
89 /** Unit test for {@link RoundRobinLoadBalancerFactory}. */
90 @RunWith(JUnit4.class)
91 public class RoundRobinLoadBalancerTest {
92   private RoundRobinLoadBalancer loadBalancer;
93   private List<EquivalentAddressGroup> servers = Lists.newArrayList();
94   private Map<EquivalentAddressGroup, Subchannel> subchannels = Maps.newLinkedHashMap();
95   private static final Attributes.Key<String> MAJOR_KEY = Attributes.Key.create("major-key");
96   private Attributes affinity = Attributes.newBuilder().set(MAJOR_KEY, "I got the keys").build();
97 
98   @Captor
99   private ArgumentCaptor<SubchannelPicker> pickerCaptor;
100   @Captor
101   private ArgumentCaptor<ConnectivityState> stateCaptor;
102   @Captor
103   private ArgumentCaptor<EquivalentAddressGroup> eagCaptor;
104   @Mock
105   private Helper mockHelper;
106 
107   @Mock // This LoadBalancer doesn't use any of the arg fields, as verified in tearDown().
108   private PickSubchannelArgs mockArgs;
109 
110   @Before
setUp()111   public void setUp() {
112     MockitoAnnotations.initMocks(this);
113 
114     for (int i = 0; i < 3; i++) {
115       SocketAddress addr = new FakeSocketAddress("server" + i);
116       EquivalentAddressGroup eag = new EquivalentAddressGroup(addr);
117       servers.add(eag);
118       Subchannel sc = mock(Subchannel.class);
119       when(sc.getAddresses()).thenReturn(eag);
120       subchannels.put(eag, sc);
121     }
122 
123     when(mockHelper.createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class)))
124         .then(new Answer<Subchannel>() {
125           @Override
126           public Subchannel answer(InvocationOnMock invocation) throws Throwable {
127             Object[] args = invocation.getArguments();
128             Subchannel subchannel = subchannels.get(args[0]);
129             when(subchannel.getAttributes()).thenReturn((Attributes) args[1]);
130             return subchannel;
131           }
132         });
133 
134     loadBalancer = (RoundRobinLoadBalancer) RoundRobinLoadBalancerFactory.getInstance()
135         .newLoadBalancer(mockHelper);
136   }
137 
138   @After
tearDown()139   public void tearDown() throws Exception {
140     verifyNoMoreInteractions(mockArgs);
141   }
142 
143   @Test
pickAfterResolved()144   public void pickAfterResolved() throws Exception {
145     final Subchannel readySubchannel = subchannels.values().iterator().next();
146     loadBalancer.handleResolvedAddressGroups(servers, affinity);
147     loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
148 
149     verify(mockHelper, times(3)).createSubchannel(eagCaptor.capture(),
150         any(Attributes.class));
151 
152     assertThat(eagCaptor.getAllValues()).containsAllIn(subchannels.keySet());
153     for (Subchannel subchannel : subchannels.values()) {
154       verify(subchannel).requestConnection();
155       verify(subchannel, never()).shutdown();
156     }
157 
158     verify(mockHelper, times(2))
159         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
160 
161     assertEquals(CONNECTING, stateCaptor.getAllValues().get(0));
162     assertEquals(READY, stateCaptor.getAllValues().get(1));
163     assertThat(getList(pickerCaptor.getValue())).containsExactly(readySubchannel);
164 
165     verifyNoMoreInteractions(mockHelper);
166   }
167 
168   @Test
pickAfterResolvedUpdatedHosts()169   public void pickAfterResolvedUpdatedHosts() throws Exception {
170     Subchannel removedSubchannel = mock(Subchannel.class);
171     Subchannel oldSubchannel = mock(Subchannel.class);
172     Subchannel newSubchannel = mock(Subchannel.class);
173 
174     for (Subchannel subchannel : Lists.newArrayList(removedSubchannel, oldSubchannel,
175         newSubchannel)) {
176       when(subchannel.getAttributes()).thenReturn(Attributes.newBuilder().set(STATE_INFO,
177           new Ref<ConnectivityStateInfo>(
178               ConnectivityStateInfo.forNonError(READY))).build());
179     }
180 
181     FakeSocketAddress removedAddr = new FakeSocketAddress("removed");
182     FakeSocketAddress oldAddr = new FakeSocketAddress("old");
183     FakeSocketAddress newAddr = new FakeSocketAddress("new");
184 
185     final Map<EquivalentAddressGroup, Subchannel> subchannels2 = Maps.newHashMap();
186     subchannels2.put(new EquivalentAddressGroup(removedAddr), removedSubchannel);
187     subchannels2.put(new EquivalentAddressGroup(oldAddr), oldSubchannel);
188 
189     List<EquivalentAddressGroup> currentServers =
190         Lists.newArrayList(
191             new EquivalentAddressGroup(removedAddr),
192             new EquivalentAddressGroup(oldAddr));
193 
194     doAnswer(new Answer<Subchannel>() {
195       @Override
196       public Subchannel answer(InvocationOnMock invocation) throws Throwable {
197         Object[] args = invocation.getArguments();
198         return subchannels2.get(args[0]);
199       }
200     }).when(mockHelper).createSubchannel(any(EquivalentAddressGroup.class), any(Attributes.class));
201 
202     loadBalancer.handleResolvedAddressGroups(currentServers, affinity);
203 
204     InOrder inOrder = inOrder(mockHelper);
205 
206     inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
207     SubchannelPicker picker = pickerCaptor.getValue();
208     assertThat(getList(picker)).containsExactly(removedSubchannel, oldSubchannel);
209 
210     verify(removedSubchannel, times(1)).requestConnection();
211     verify(oldSubchannel, times(1)).requestConnection();
212 
213     assertThat(loadBalancer.getSubchannels()).containsExactly(removedSubchannel,
214         oldSubchannel);
215 
216     subchannels2.clear();
217     subchannels2.put(new EquivalentAddressGroup(oldAddr), oldSubchannel);
218     subchannels2.put(new EquivalentAddressGroup(newAddr), newSubchannel);
219 
220     List<EquivalentAddressGroup> latestServers =
221         Lists.newArrayList(
222             new EquivalentAddressGroup(oldAddr),
223             new EquivalentAddressGroup(newAddr));
224 
225     loadBalancer.handleResolvedAddressGroups(latestServers, affinity);
226 
227     verify(newSubchannel, times(1)).requestConnection();
228     verify(removedSubchannel, times(1)).shutdown();
229 
230     loadBalancer.handleSubchannelState(removedSubchannel,
231             ConnectivityStateInfo.forNonError(SHUTDOWN));
232 
233     assertThat(loadBalancer.getSubchannels()).containsExactly(oldSubchannel,
234         newSubchannel);
235 
236     verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
237         any(Attributes.class));
238     inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
239 
240     picker = pickerCaptor.getValue();
241     assertThat(getList(picker)).containsExactly(oldSubchannel, newSubchannel);
242 
243     // test going from non-empty to empty
244     loadBalancer.handleResolvedAddressGroups(Collections.<EquivalentAddressGroup>emptyList(),
245             affinity);
246 
247     inOrder.verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
248     assertEquals(PickResult.withNoResult(), pickerCaptor.getValue().pickSubchannel(mockArgs));
249 
250     verifyNoMoreInteractions(mockHelper);
251   }
252 
253   @Test
pickAfterStateChange()254   public void pickAfterStateChange() throws Exception {
255     InOrder inOrder = inOrder(mockHelper);
256     loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY);
257     Subchannel subchannel = loadBalancer.getSubchannels().iterator().next();
258     Ref<ConnectivityStateInfo> subchannelStateInfo = subchannel.getAttributes().get(
259         STATE_INFO);
260 
261     inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), isA(EmptyPicker.class));
262     assertThat(subchannelStateInfo.value).isEqualTo(ConnectivityStateInfo.forNonError(IDLE));
263 
264     loadBalancer.handleSubchannelState(subchannel,
265         ConnectivityStateInfo.forNonError(READY));
266     inOrder.verify(mockHelper).updateBalancingState(eq(READY), pickerCaptor.capture());
267     assertThat(pickerCaptor.getValue()).isInstanceOf(ReadyPicker.class);
268     assertThat(subchannelStateInfo.value).isEqualTo(
269         ConnectivityStateInfo.forNonError(READY));
270 
271     Status error = Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯");
272     loadBalancer.handleSubchannelState(subchannel,
273         ConnectivityStateInfo.forTransientFailure(error));
274     assertThat(subchannelStateInfo.value).isEqualTo(
275         ConnectivityStateInfo.forTransientFailure(error));
276     inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture());
277     assertThat(pickerCaptor.getValue()).isInstanceOf(EmptyPicker.class);
278 
279     loadBalancer.handleSubchannelState(subchannel,
280         ConnectivityStateInfo.forNonError(IDLE));
281     assertThat(subchannelStateInfo.value).isEqualTo(
282         ConnectivityStateInfo.forNonError(IDLE));
283 
284     verify(subchannel, times(2)).requestConnection();
285     verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
286         any(Attributes.class));
287     verifyNoMoreInteractions(mockHelper);
288   }
289 
nextSubchannel(Subchannel current, List<Subchannel> allSubChannels)290   private Subchannel nextSubchannel(Subchannel current, List<Subchannel> allSubChannels) {
291     return allSubChannels.get((allSubChannels.indexOf(current) + 1) % allSubChannels.size());
292   }
293 
294   @Test
pickerRoundRobin()295   public void pickerRoundRobin() throws Exception {
296     Subchannel subchannel = mock(Subchannel.class);
297     Subchannel subchannel1 = mock(Subchannel.class);
298     Subchannel subchannel2 = mock(Subchannel.class);
299 
300     ReadyPicker picker = new ReadyPicker(Collections.unmodifiableList(
301         Lists.<Subchannel>newArrayList(subchannel, subchannel1, subchannel2)),
302         0 /* startIndex */, null /* stickinessState */);
303 
304     assertThat(picker.getList()).containsExactly(subchannel, subchannel1, subchannel2);
305 
306     assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel());
307     assertEquals(subchannel1, picker.pickSubchannel(mockArgs).getSubchannel());
308     assertEquals(subchannel2, picker.pickSubchannel(mockArgs).getSubchannel());
309     assertEquals(subchannel, picker.pickSubchannel(mockArgs).getSubchannel());
310   }
311 
312   @Test
pickerEmptyList()313   public void pickerEmptyList() throws Exception {
314     SubchannelPicker picker = new EmptyPicker(Status.UNKNOWN);
315 
316     assertEquals(null, picker.pickSubchannel(mockArgs).getSubchannel());
317     assertEquals(Status.UNKNOWN,
318         picker.pickSubchannel(mockArgs).getStatus());
319   }
320 
321   @Test
nameResolutionErrorWithNoChannels()322   public void nameResolutionErrorWithNoChannels() throws Exception {
323     Status error = Status.NOT_FOUND.withDescription("nameResolutionError");
324     loadBalancer.handleNameResolutionError(error);
325     verify(mockHelper).updateBalancingState(eq(TRANSIENT_FAILURE), pickerCaptor.capture());
326     LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
327     assertNull(pickResult.getSubchannel());
328     assertEquals(error, pickResult.getStatus());
329     verifyNoMoreInteractions(mockHelper);
330   }
331 
332   @Test
nameResolutionErrorWithActiveChannels()333   public void nameResolutionErrorWithActiveChannels() throws Exception {
334     final Subchannel readySubchannel = subchannels.values().iterator().next();
335     loadBalancer.handleResolvedAddressGroups(servers, affinity);
336     loadBalancer.handleSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY));
337     loadBalancer.handleNameResolutionError(Status.NOT_FOUND.withDescription("nameResolutionError"));
338 
339     verify(mockHelper, times(3)).createSubchannel(any(EquivalentAddressGroup.class),
340         any(Attributes.class));
341     verify(mockHelper, times(3))
342         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
343 
344     Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
345     assertEquals(CONNECTING, stateIterator.next());
346     assertEquals(READY, stateIterator.next());
347     assertEquals(TRANSIENT_FAILURE, stateIterator.next());
348 
349     LoadBalancer.PickResult pickResult = pickerCaptor.getValue().pickSubchannel(mockArgs);
350     assertEquals(readySubchannel, pickResult.getSubchannel());
351     assertEquals(Status.OK.getCode(), pickResult.getStatus().getCode());
352 
353     LoadBalancer.PickResult pickResult2 = pickerCaptor.getValue().pickSubchannel(mockArgs);
354     assertEquals(readySubchannel, pickResult2.getSubchannel());
355     verifyNoMoreInteractions(mockHelper);
356   }
357 
358   @Test
subchannelStateIsolation()359   public void subchannelStateIsolation() throws Exception {
360     Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
361     Subchannel sc1 = subchannelIterator.next();
362     Subchannel sc2 = subchannelIterator.next();
363     Subchannel sc3 = subchannelIterator.next();
364 
365     loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY);
366     verify(sc1, times(1)).requestConnection();
367     verify(sc2, times(1)).requestConnection();
368     verify(sc3, times(1)).requestConnection();
369 
370     loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
371     loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(READY));
372     loadBalancer.handleSubchannelState(sc3, ConnectivityStateInfo.forNonError(READY));
373     loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(IDLE));
374     loadBalancer
375         .handleSubchannelState(sc3, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
376 
377     verify(mockHelper, times(6))
378         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
379     Iterator<ConnectivityState> stateIterator = stateCaptor.getAllValues().iterator();
380     Iterator<SubchannelPicker> pickers = pickerCaptor.getAllValues().iterator();
381     // The picker is incrementally updated as subchannels become READY
382     assertEquals(CONNECTING, stateIterator.next());
383     assertThat(pickers.next()).isInstanceOf(EmptyPicker.class);
384     assertEquals(READY, stateIterator.next());
385     assertThat(getList(pickers.next())).containsExactly(sc1);
386     assertEquals(READY, stateIterator.next());
387     assertThat(getList(pickers.next())).containsExactly(sc1, sc2);
388     assertEquals(READY, stateIterator.next());
389     assertThat(getList(pickers.next())).containsExactly(sc1, sc2, sc3);
390     // The IDLE subchannel is dropped from the picker, but a reconnection is requested
391     assertEquals(READY, stateIterator.next());
392     assertThat(getList(pickers.next())).containsExactly(sc1, sc3);
393     verify(sc2, times(2)).requestConnection();
394     // The failing subchannel is dropped from the picker, with no requested reconnect
395     assertEquals(READY, stateIterator.next());
396     assertThat(getList(pickers.next())).containsExactly(sc1);
397     verify(sc3, times(1)).requestConnection();
398     assertThat(stateIterator.hasNext()).isFalse();
399     assertThat(pickers.hasNext()).isFalse();
400   }
401 
402   @Test
noStickinessEnabled_withStickyHeader()403   public void noStickinessEnabled_withStickyHeader() {
404     loadBalancer.handleResolvedAddressGroups(servers, Attributes.EMPTY);
405     for (Subchannel subchannel : subchannels.values()) {
406       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
407     }
408     verify(mockHelper, times(4))
409         .updateBalancingState(any(ConnectivityState.class), pickerCaptor.capture());
410     SubchannelPicker picker = pickerCaptor.getValue();
411 
412     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
413     Metadata headerWithStickinessValue = new Metadata();
414     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
415     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
416 
417     List<Subchannel> allSubchannels = getList(picker);
418     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
419     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
420     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
421     Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel();
422 
423     assertEquals(nextSubchannel(sc1, allSubchannels), sc2);
424     assertEquals(nextSubchannel(sc2, allSubchannels), sc3);
425     assertEquals(nextSubchannel(sc3, allSubchannels), sc1);
426     assertEquals(sc4, sc1);
427 
428     assertNull(loadBalancer.getStickinessMapForTest());
429   }
430 
431   @Test
stickinessEnabled_withoutStickyHeader()432   public void stickinessEnabled_withoutStickyHeader() {
433     Map<String, Object> serviceConfig = new HashMap<String, Object>();
434     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
435     Attributes attributes = Attributes.newBuilder()
436         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
437     loadBalancer.handleResolvedAddressGroups(servers, attributes);
438     for (Subchannel subchannel : subchannels.values()) {
439       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
440     }
441     verify(mockHelper, times(4))
442         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
443     SubchannelPicker picker = pickerCaptor.getValue();
444 
445     doReturn(new Metadata()).when(mockArgs).getHeaders();
446 
447     List<Subchannel> allSubchannels = getList(picker);
448 
449     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
450     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
451     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
452     Subchannel sc4 = picker.pickSubchannel(mockArgs).getSubchannel();
453 
454     assertEquals(nextSubchannel(sc1, allSubchannels), sc2);
455     assertEquals(nextSubchannel(sc2, allSubchannels), sc3);
456     assertEquals(nextSubchannel(sc3, allSubchannels), sc1);
457     assertEquals(sc4, sc1);
458     verify(mockArgs, times(4)).getHeaders();
459     assertNotNull(loadBalancer.getStickinessMapForTest());
460     assertThat(loadBalancer.getStickinessMapForTest()).isEmpty();
461   }
462 
463   @Test
stickinessEnabled_withStickyHeader()464   public void stickinessEnabled_withStickyHeader() {
465     Map<String, Object> serviceConfig = new HashMap<String, Object>();
466     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
467     Attributes attributes = Attributes.newBuilder()
468         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
469     loadBalancer.handleResolvedAddressGroups(servers, attributes);
470     for (Subchannel subchannel : subchannels.values()) {
471       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
472     }
473     verify(mockHelper, times(4))
474         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
475     SubchannelPicker picker = pickerCaptor.getValue();
476 
477     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
478     Metadata headerWithStickinessValue = new Metadata();
479     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
480     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
481 
482     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
483     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
484     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
485     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
486     assertEquals(sc1, picker.pickSubchannel(mockArgs).getSubchannel());
487 
488     verify(mockArgs, atLeast(4)).getHeaders();
489     assertNotNull(loadBalancer.getStickinessMapForTest());
490     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
491   }
492 
493   @Test
stickinessEnabled_withDifferentStickyHeaders()494   public void stickinessEnabled_withDifferentStickyHeaders() {
495     Map<String, Object> serviceConfig = new HashMap<String, Object>();
496     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
497     Attributes attributes = Attributes.newBuilder()
498         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
499     loadBalancer.handleResolvedAddressGroups(servers, attributes);
500     for (Subchannel subchannel : subchannels.values()) {
501       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
502     }
503     verify(mockHelper, times(4))
504         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
505     SubchannelPicker picker = pickerCaptor.getValue();
506 
507     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
508     Metadata headerWithStickinessValue1 = new Metadata();
509     headerWithStickinessValue1.put(stickinessKey, "my-sticky-value");
510 
511     Metadata headerWithStickinessValue2 = new Metadata();
512     headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2");
513 
514     List<Subchannel> allSubchannels = getList(picker);
515 
516     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
517     Subchannel sc1a = picker.pickSubchannel(mockArgs).getSubchannel();
518 
519     doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
520     Subchannel sc2a = picker.pickSubchannel(mockArgs).getSubchannel();
521 
522     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
523     Subchannel sc1b = picker.pickSubchannel(mockArgs).getSubchannel();
524 
525     doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
526     Subchannel sc2b = picker.pickSubchannel(mockArgs).getSubchannel();
527 
528     assertEquals(sc1a, sc1b);
529     assertEquals(sc2a, sc2b);
530     assertEquals(nextSubchannel(sc1a, allSubchannels), sc2a);
531     assertEquals(nextSubchannel(sc1b, allSubchannels), sc2b);
532 
533     verify(mockArgs, atLeast(4)).getHeaders();
534     assertNotNull(loadBalancer.getStickinessMapForTest());
535     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(2);
536   }
537 
538   @Test
stickiness_goToTransientFailure_pick_backToReady()539   public void stickiness_goToTransientFailure_pick_backToReady() {
540     Map<String, Object> serviceConfig = new HashMap<String, Object>();
541     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
542     Attributes attributes = Attributes.newBuilder()
543         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
544     loadBalancer.handleResolvedAddressGroups(servers, attributes);
545     for (Subchannel subchannel : subchannels.values()) {
546       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
547     }
548     verify(mockHelper, times(4))
549         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
550     SubchannelPicker picker = pickerCaptor.getValue();
551 
552     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
553     Metadata headerWithStickinessValue = new Metadata();
554     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
555     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
556 
557     // first pick
558     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
559 
560     // go to transient failure
561     loadBalancer
562         .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
563 
564     verify(mockHelper, times(5))
565         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
566     picker = pickerCaptor.getValue();
567 
568     // second pick
569     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
570 
571     // go back to ready
572     loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
573 
574     verify(mockHelper, times(6))
575         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
576     picker = pickerCaptor.getValue();
577 
578     // third pick
579     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
580     assertEquals(sc2, sc3);
581     verify(mockArgs, atLeast(3)).getHeaders();
582     assertNotNull(loadBalancer.getStickinessMapForTest());
583     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
584   }
585 
586   @Test
stickiness_goToTransientFailure_backToReady_pick()587   public void stickiness_goToTransientFailure_backToReady_pick() {
588     Map<String, Object> serviceConfig = new HashMap<String, Object>();
589     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
590     Attributes attributes = Attributes.newBuilder()
591         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
592     loadBalancer.handleResolvedAddressGroups(servers, attributes);
593     for (Subchannel subchannel : subchannels.values()) {
594       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
595     }
596     verify(mockHelper, times(4))
597         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
598     SubchannelPicker picker = pickerCaptor.getValue();
599 
600     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
601     Metadata headerWithStickinessValue1 = new Metadata();
602     headerWithStickinessValue1.put(stickinessKey, "my-sticky-value");
603     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
604 
605     // first pick
606     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
607 
608     // go to transient failure
609     loadBalancer
610         .handleSubchannelState(sc1, ConnectivityStateInfo.forTransientFailure(Status.UNAVAILABLE));
611 
612     Metadata headerWithStickinessValue2 = new Metadata();
613     headerWithStickinessValue2.put(stickinessKey, "my-sticky-value2");
614     doReturn(headerWithStickinessValue2).when(mockArgs).getHeaders();
615     verify(mockHelper, times(5))
616         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
617     picker = pickerCaptor.getValue();
618 
619     // second pick with a different stickiness value
620     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
621 
622     // go back to ready
623     loadBalancer.handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(READY));
624 
625     doReturn(headerWithStickinessValue1).when(mockArgs).getHeaders();
626     verify(mockHelper, times(6))
627         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
628     picker = pickerCaptor.getValue();
629 
630     // third pick with my-sticky-value1
631     Subchannel sc3 = picker.pickSubchannel(mockArgs).getSubchannel();
632     assertEquals(sc1, sc3);
633 
634     verify(mockArgs, atLeast(3)).getHeaders();
635     assertNotNull(loadBalancer.getStickinessMapForTest());
636     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(2);
637   }
638 
639   @Test
stickiness_oneSubchannelShutdown()640   public void stickiness_oneSubchannelShutdown() {
641     Map<String, Object> serviceConfig = new HashMap<String, Object>();
642     serviceConfig.put("stickinessMetadataKey", "my-sticky-key");
643     Attributes attributes = Attributes.newBuilder()
644         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig).build();
645     loadBalancer.handleResolvedAddressGroups(servers, attributes);
646     for (Subchannel subchannel : subchannels.values()) {
647       loadBalancer.handleSubchannelState(subchannel, ConnectivityStateInfo.forNonError(READY));
648     }
649     verify(mockHelper, times(4))
650         .updateBalancingState(stateCaptor.capture(), pickerCaptor.capture());
651     SubchannelPicker picker = pickerCaptor.getValue();
652 
653     Key<String> stickinessKey = Key.of("my-sticky-key", Metadata.ASCII_STRING_MARSHALLER);
654     Metadata headerWithStickinessValue = new Metadata();
655     headerWithStickinessValue.put(stickinessKey, "my-sticky-value");
656     doReturn(headerWithStickinessValue).when(mockArgs).getHeaders();
657 
658     List<Subchannel> allSubchannels = Lists.newArrayList(getList(picker));
659 
660     Subchannel sc1 = picker.pickSubchannel(mockArgs).getSubchannel();
661 
662     // shutdown channel directly
663     loadBalancer
664         .handleSubchannelState(sc1, ConnectivityStateInfo.forNonError(ConnectivityState.SHUTDOWN));
665 
666     assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value);
667 
668     assertEquals(nextSubchannel(sc1, allSubchannels),
669                  picker.pickSubchannel(mockArgs).getSubchannel());
670     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
671     verify(mockArgs, atLeast(2)).getHeaders();
672 
673     Subchannel sc2 = picker.pickSubchannel(mockArgs).getSubchannel();
674 
675     assertEquals(sc2, loadBalancer.getStickinessMapForTest().get("my-sticky-value").value);
676 
677     // shutdown channel via name resolver change
678     List<EquivalentAddressGroup> newServers = new ArrayList<>(servers);
679     newServers.remove(sc2.getAddresses());
680 
681     loadBalancer.handleResolvedAddressGroups(newServers, attributes);
682 
683     verify(sc2, times(1)).shutdown();
684 
685     loadBalancer.handleSubchannelState(sc2, ConnectivityStateInfo.forNonError(SHUTDOWN));
686 
687     assertNull(loadBalancer.getStickinessMapForTest().get("my-sticky-value").value);
688 
689     assertEquals(nextSubchannel(sc2, allSubchannels),
690             picker.pickSubchannel(mockArgs).getSubchannel());
691     assertThat(loadBalancer.getStickinessMapForTest()).hasSize(1);
692     verify(mockArgs, atLeast(2)).getHeaders();
693   }
694 
695   @Test
stickiness_resolveTwice_metadataKeyChanged()696   public void stickiness_resolveTwice_metadataKeyChanged() {
697     Map<String, Object> serviceConfig1 = new HashMap<String, Object>();
698     serviceConfig1.put("stickinessMetadataKey", "my-sticky-key1");
699     Attributes attributes1 = Attributes.newBuilder()
700         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig1).build();
701     loadBalancer.handleResolvedAddressGroups(servers, attributes1);
702     Map<String, ?> stickinessMap1 = loadBalancer.getStickinessMapForTest();
703 
704     Map<String, Object> serviceConfig2 = new HashMap<String, Object>();
705     serviceConfig2.put("stickinessMetadataKey", "my-sticky-key2");
706     Attributes attributes2 = Attributes.newBuilder()
707         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig2).build();
708     loadBalancer.handleResolvedAddressGroups(servers, attributes2);
709     Map<String, ?> stickinessMap2 = loadBalancer.getStickinessMapForTest();
710 
711     assertNotSame(stickinessMap1, stickinessMap2);
712   }
713 
714   @Test
stickiness_resolveTwice_metadataKeyUnChanged()715   public void stickiness_resolveTwice_metadataKeyUnChanged() {
716     Map<String, Object> serviceConfig1 = new HashMap<String, Object>();
717     serviceConfig1.put("stickinessMetadataKey", "my-sticky-key1");
718     Attributes attributes1 = Attributes.newBuilder()
719         .set(GrpcAttributes.NAME_RESOLVER_SERVICE_CONFIG, serviceConfig1).build();
720     loadBalancer.handleResolvedAddressGroups(servers, attributes1);
721     Map<String, ?> stickinessMap1 = loadBalancer.getStickinessMapForTest();
722 
723     loadBalancer.handleResolvedAddressGroups(servers, attributes1);
724     Map<String, ?> stickinessMap2 = loadBalancer.getStickinessMapForTest();
725 
726     assertSame(stickinessMap1, stickinessMap2);
727   }
728 
729   @Test(expected = IllegalArgumentException.class)
readyPicker_emptyList()730   public void readyPicker_emptyList() {
731     // ready picker list must be non-empty
732     new ReadyPicker(Collections.<Subchannel>emptyList(), 0, null);
733   }
734 
735   @Test
internalPickerComparisons()736   public void internalPickerComparisons() {
737     EmptyPicker emptyOk1 = new EmptyPicker(Status.OK);
738     EmptyPicker emptyOk2 = new EmptyPicker(Status.OK.withDescription("different OK"));
739     EmptyPicker emptyErr = new EmptyPicker(Status.UNKNOWN.withDescription("¯\\_(ツ)_//¯"));
740 
741     Iterator<Subchannel> subchannelIterator = subchannels.values().iterator();
742     Subchannel sc1 = subchannelIterator.next();
743     Subchannel sc2 = subchannelIterator.next();
744     StickinessState stickinessState = new StickinessState("stick-key");
745     ReadyPicker ready1 = new ReadyPicker(Arrays.asList(sc1, sc2), 0, null);
746     ReadyPicker ready2 = new ReadyPicker(Arrays.asList(sc1), 0, null);
747     ReadyPicker ready3 = new ReadyPicker(Arrays.asList(sc2, sc1), 1, null);
748     ReadyPicker ready4 = new ReadyPicker(Arrays.asList(sc1, sc2), 1, stickinessState);
749     ReadyPicker ready5 = new ReadyPicker(Arrays.asList(sc2, sc1), 0, stickinessState);
750 
751     assertTrue(emptyOk1.isEquivalentTo(emptyOk2));
752     assertFalse(emptyOk1.isEquivalentTo(emptyErr));
753     assertFalse(ready1.isEquivalentTo(ready2));
754     assertTrue(ready1.isEquivalentTo(ready3));
755     assertFalse(ready3.isEquivalentTo(ready4));
756     assertTrue(ready4.isEquivalentTo(ready5));
757     assertFalse(emptyOk1.isEquivalentTo(ready1));
758     assertFalse(ready1.isEquivalentTo(emptyOk1));
759   }
760 
761 
getList(SubchannelPicker picker)762   private static List<Subchannel> getList(SubchannelPicker picker) {
763     return picker instanceof ReadyPicker ? ((ReadyPicker) picker).getList() :
764         Collections.<Subchannel>emptyList();
765   }
766 
767   private static class FakeSocketAddress extends SocketAddress {
768     final String name;
769 
FakeSocketAddress(String name)770     FakeSocketAddress(String name) {
771       this.name = name;
772     }
773 
774     @Override
toString()775     public String toString() {
776       return "FakeSocketAddress-" + name;
777     }
778   }
779 }
780