/* * Copyright (C) 2010 The Guava Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package com.google.common.collect.testing; import com.google.common.annotations.GwtIncompatible; import java.io.Serializable; import java.util.AbstractSet; import java.util.Collection; import java.util.Comparator; import java.util.Iterator; import java.util.Map; import java.util.NavigableMap; import java.util.NavigableSet; import java.util.Set; import java.util.SortedMap; import java.util.TreeMap; /** * A wrapper around {@code TreeMap} that aggressively checks to see if keys are mutually comparable. * This implementation passes the navigable map test suites. * * @author Louis Wasserman */ @GwtIncompatible public final class SafeTreeMap implements Serializable, NavigableMap { @SuppressWarnings("unchecked") private static final Comparator NATURAL_ORDER = new Comparator() { @Override public int compare(Object o1, Object o2) { return ((Comparable) o1).compareTo(o2); } }; private final NavigableMap delegate; public SafeTreeMap() { this(new TreeMap()); } public SafeTreeMap(Comparator comparator) { this(new TreeMap(comparator)); } public SafeTreeMap(Map map) { this(new TreeMap(map)); } public SafeTreeMap(SortedMap map) { this(new TreeMap(map)); } private SafeTreeMap(NavigableMap delegate) { this.delegate = delegate; if (delegate == null) { throw new NullPointerException(); } for (K k : keySet()) { checkValid(k); } } @Override public Entry ceilingEntry(K key) { return delegate.ceilingEntry(checkValid(key)); } @Override public K ceilingKey(K key) { return delegate.ceilingKey(checkValid(key)); } @Override public void clear() { delegate.clear(); } @SuppressWarnings("unchecked") @Override public Comparator comparator() { Comparator comparator = delegate.comparator(); if (comparator == null) { comparator = (Comparator) NATURAL_ORDER; } return comparator; } @Override public boolean containsKey(Object key) { try { return delegate.containsKey(checkValid(key)); } catch (NullPointerException | ClassCastException e) { return false; } } @Override public boolean containsValue(Object value) { return delegate.containsValue(value); } @Override public NavigableSet descendingKeySet() { return delegate.descendingKeySet(); } @Override public NavigableMap descendingMap() { return new SafeTreeMap<>(delegate.descendingMap()); } @Override public Set> entrySet() { return new AbstractSet>() { private Set> delegate() { return delegate.entrySet(); } @Override public boolean contains(Object object) { try { return delegate().contains(object); } catch (NullPointerException | ClassCastException e) { return false; } } @Override public Iterator> iterator() { return delegate().iterator(); } @Override public int size() { return delegate().size(); } @Override public boolean remove(Object o) { return delegate().remove(o); } @Override public void clear() { delegate().clear(); } }; } @Override public Entry firstEntry() { return delegate.firstEntry(); } @Override public K firstKey() { return delegate.firstKey(); } @Override public Entry floorEntry(K key) { return delegate.floorEntry(checkValid(key)); } @Override public K floorKey(K key) { return delegate.floorKey(checkValid(key)); } @Override public V get(Object key) { return delegate.get(checkValid(key)); } @Override public SortedMap headMap(K toKey) { return headMap(toKey, false); } @Override public NavigableMap headMap(K toKey, boolean inclusive) { return new SafeTreeMap<>(delegate.headMap(checkValid(toKey), inclusive)); } @Override public Entry higherEntry(K key) { return delegate.higherEntry(checkValid(key)); } @Override public K higherKey(K key) { return delegate.higherKey(checkValid(key)); } @Override public boolean isEmpty() { return delegate.isEmpty(); } @Override public NavigableSet keySet() { return navigableKeySet(); } @Override public Entry lastEntry() { return delegate.lastEntry(); } @Override public K lastKey() { return delegate.lastKey(); } @Override public Entry lowerEntry(K key) { return delegate.lowerEntry(checkValid(key)); } @Override public K lowerKey(K key) { return delegate.lowerKey(checkValid(key)); } @Override public NavigableSet navigableKeySet() { return delegate.navigableKeySet(); } @Override public Entry pollFirstEntry() { return delegate.pollFirstEntry(); } @Override public Entry pollLastEntry() { return delegate.pollLastEntry(); } @Override public V put(K key, V value) { return delegate.put(checkValid(key), value); } @Override public void putAll(Map map) { for (K key : map.keySet()) { checkValid(key); } delegate.putAll(map); } @Override public V remove(Object key) { return delegate.remove(checkValid(key)); } @Override public int size() { return delegate.size(); } @Override public NavigableMap subMap(K fromKey, boolean fromInclusive, K toKey, boolean toInclusive) { return new SafeTreeMap<>( delegate.subMap(checkValid(fromKey), fromInclusive, checkValid(toKey), toInclusive)); } @Override public SortedMap subMap(K fromKey, K toKey) { return subMap(fromKey, true, toKey, false); } @Override public SortedMap tailMap(K fromKey) { return tailMap(fromKey, true); } @Override public NavigableMap tailMap(K fromKey, boolean inclusive) { return new SafeTreeMap<>(delegate.tailMap(checkValid(fromKey), inclusive)); } @Override public Collection values() { return delegate.values(); } private T checkValid(T t) { // a ClassCastException is what's supposed to happen! @SuppressWarnings("unchecked") K k = (K) t; comparator().compare(k, k); return t; } @Override public boolean equals(Object obj) { return delegate.equals(obj); } @Override public int hashCode() { return delegate.hashCode(); } @Override public String toString() { return delegate.toString(); } private static final long serialVersionUID = 0L; }