1 /* 2 * Copyright 2019 The gRPC Authors 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package io.grpc.xds; 18 19 import static com.google.common.base.Preconditions.checkArgument; 20 import static com.google.common.base.Preconditions.checkNotNull; 21 22 import com.google.common.annotations.VisibleForTesting; 23 import com.google.common.base.MoreObjects; 24 import com.google.common.primitives.UnsignedInteger; 25 import io.grpc.LoadBalancer.PickResult; 26 import io.grpc.LoadBalancer.PickSubchannelArgs; 27 import io.grpc.LoadBalancer.SubchannelPicker; 28 import java.util.Collections; 29 import java.util.List; 30 import java.util.Objects; 31 32 final class WeightedRandomPicker extends SubchannelPicker { 33 34 @VisibleForTesting 35 final List<WeightedChildPicker> weightedChildPickers; 36 37 private final ThreadSafeRandom random; 38 private final long totalWeight; 39 40 static final class WeightedChildPicker { 41 private final long weight; 42 private final SubchannelPicker childPicker; 43 WeightedChildPicker(long weight, SubchannelPicker childPicker)44 WeightedChildPicker(long weight, SubchannelPicker childPicker) { 45 checkArgument(weight >= 0, "weight is negative"); 46 checkArgument(weight <= UnsignedInteger.MAX_VALUE.longValue(), "weight is too large"); 47 checkNotNull(childPicker, "childPicker is null"); 48 49 this.weight = weight; 50 this.childPicker = childPicker; 51 } 52 getWeight()53 long getWeight() { 54 return weight; 55 } 56 getPicker()57 SubchannelPicker getPicker() { 58 return childPicker; 59 } 60 61 @Override equals(Object o)62 public boolean equals(Object o) { 63 if (this == o) { 64 return true; 65 } 66 if (o == null || getClass() != o.getClass()) { 67 return false; 68 } 69 WeightedChildPicker that = (WeightedChildPicker) o; 70 return weight == that.weight && Objects.equals(childPicker, that.childPicker); 71 } 72 73 @Override hashCode()74 public int hashCode() { 75 return Objects.hash(weight, childPicker); 76 } 77 78 @Override toString()79 public String toString() { 80 return MoreObjects.toStringHelper(this) 81 .add("weight", weight) 82 .add("childPicker", childPicker) 83 .toString(); 84 } 85 } 86 WeightedRandomPicker(List<WeightedChildPicker> weightedChildPickers)87 WeightedRandomPicker(List<WeightedChildPicker> weightedChildPickers) { 88 this(weightedChildPickers, ThreadSafeRandom.ThreadSafeRandomImpl.instance); 89 } 90 91 @VisibleForTesting WeightedRandomPicker(List<WeightedChildPicker> weightedChildPickers, ThreadSafeRandom random)92 WeightedRandomPicker(List<WeightedChildPicker> weightedChildPickers, ThreadSafeRandom random) { 93 checkNotNull(weightedChildPickers, "weightedChildPickers in null"); 94 checkArgument(!weightedChildPickers.isEmpty(), "weightedChildPickers is empty"); 95 96 this.weightedChildPickers = Collections.unmodifiableList(weightedChildPickers); 97 98 long totalWeight = 0; 99 for (WeightedChildPicker weightedChildPicker : weightedChildPickers) { 100 long weight = weightedChildPicker.getWeight(); 101 checkArgument(weight >= 0, "weight is negative"); 102 checkNotNull(weightedChildPicker.getPicker(), "childPicker is null"); 103 totalWeight += weight; 104 } 105 this.totalWeight = totalWeight; 106 checkArgument(totalWeight <= UnsignedInteger.MAX_VALUE.longValue(), 107 "total weight greater than unsigned int can hold"); 108 109 this.random = random; 110 } 111 112 @Override pickSubchannel(PickSubchannelArgs args)113 public final PickResult pickSubchannel(PickSubchannelArgs args) { 114 SubchannelPicker childPicker = null; 115 116 if (totalWeight == 0) { 117 childPicker = 118 weightedChildPickers.get(random.nextInt(weightedChildPickers.size())).getPicker(); 119 } else { 120 long rand = random.nextLong(totalWeight); 121 122 // Find the first idx such that rand < accumulatedWeights[idx] 123 // Not using Arrays.binarySearch for better readability. 124 long accumulatedWeight = 0; 125 for (WeightedChildPicker weightedChildPicker : weightedChildPickers) { 126 accumulatedWeight += weightedChildPicker.getWeight(); 127 if (rand < accumulatedWeight) { 128 childPicker = weightedChildPicker.getPicker(); 129 break; 130 } 131 } 132 checkNotNull(childPicker, "childPicker not found"); 133 } 134 135 return childPicker.pickSubchannel(args); 136 } 137 138 @Override toString()139 public String toString() { 140 return MoreObjects.toStringHelper(this) 141 .add("weightedChildPickers", weightedChildPickers) 142 .add("totalWeight", totalWeight) 143 .toString(); 144 } 145 } 146