/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package androidx.room.rxjava3

import androidx.arch.core.executor.ArchTaskExecutor
import androidx.arch.core.executor.testing.CountingTaskExecutorRule
import androidx.kruth.assertThat
import androidx.room.InvalidationTracker
import androidx.room.RoomDatabase
import io.reactivex.rxjava3.disposables.Disposable
import io.reactivex.rxjava3.functions.Consumer
import io.reactivex.rxjava3.observers.TestObserver
import io.reactivex.rxjava3.subscribers.TestSubscriber
import java.util.concurrent.Callable
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicReference
import org.junit.Before
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.mockito.invocation.InvocationOnMock
import org.mockito.kotlin.any
import org.mockito.kotlin.doAnswer
import org.mockito.kotlin.mock
import org.mockito.kotlin.never
import org.mockito.kotlin.times
import org.mockito.kotlin.verify
import org.mockito.kotlin.whenever

/**
 * Tests for the RxJava3 `createFlowable` / `createObservable` helpers against a mocked
 * [RoomDatabase] and [InvalidationTracker].
 *
 * Every Flowable test has an Observable twin; shared scenarios live in private helpers so both
 * stream types are held to exactly the same assertions.
 */
@RunWith(JUnit4::class)
class RxRoomTest {
    /** Rule whose [CountingTaskExecutorRule.drainTasks] is used by [drain] to await background work. */
    @get:Rule val mExecutor = CountingTaskExecutorRule()

    private lateinit var mDatabase: RoomDatabase
    private lateinit var mInvalidationTracker: InvalidationTracker

    /** Every observer passed to the mocked [InvalidationTracker.addObserver], in call order. */
    private val mAddedObservers = mutableListOf<InvalidationTracker.Observer>()

    @Before
    fun init() {
        mDatabase = mock()
        mInvalidationTracker = mock()
        whenever(mDatabase.invalidationTracker).thenReturn(mInvalidationTracker)
        whenever(mDatabase.queryExecutor).thenReturn(ArchTaskExecutor.getIOThreadExecutor())
        // Capture registered observers so tests can fire invalidations on them directly.
        doAnswer { invocation: InvocationOnMock ->
                mAddedObservers.add(invocation.arguments[0] as InvalidationTracker.Observer)
                null
            }
            .whenever(mInvalidationTracker)
            .addObserver(any())
    }

    @Test
    fun basicAddRemove_Flowable() {
        val flowable = createFlowable(mDatabase, "a", "b")
        assertObserverAddedAndRemovedPerSubscription { flowable.subscribe() }
    }

    @Test
    fun basicAddRemove_Observable() {
        val observable = createObservable(mDatabase, "a", "b")
        assertObserverAddedAndRemovedPerSubscription { observable.subscribe() }
    }

    @Test
    fun basicNotify_Flowable() {
        val tables = arrayOf("a", "b")
        val flowable = createFlowable(mDatabase, *tables)
        assertInvalidationsNotifySubscriber(tables.toSet()) { consumer ->
            flowable.subscribe(consumer)
        }
    }

    @Test
    fun basicNotify_Observable() {
        val tables = arrayOf("a", "b")
        val observable = createObservable(mDatabase, *tables)
        assertInvalidationsNotifySubscriber(tables.toSet()) { consumer ->
            observable.subscribe(consumer)
        }
    }

    @Test
    @Suppress("DEPRECATION")
    fun internalCallable_Flowable() {
        val value = AtomicReference<Any>(null)
        val tables = arrayOf("a", "b")
        val flowable = createFlowable(mDatabase, false, tables, Callable { value.get() })
        assertCallableResultsForwarded(value, tables.toSet()) { consumer ->
            flowable.subscribe(consumer)
        }
    }

    @Test
    @Suppress("DEPRECATION")
    fun internalCallable_Observable() {
        val value = AtomicReference<Any>(null)
        val tables = arrayOf("a", "b")
        val observable = createObservable(mDatabase, false, tables, Callable { value.get() })
        assertCallableResultsForwarded(value, tables.toSet()) { consumer ->
            observable.subscribe(consumer)
        }
    }

    @Test
    @Suppress("DEPRECATION")
    fun exception_Flowable() {
        val flowable =
            createFlowable<String>(
                mDatabase,
                false,
                arrayOf("a"),
                Callable { throw Exception("i want exception") }
            )
        val subscriber = TestSubscriber<String>()
        flowable.subscribe(subscriber)
        drain()
        subscriber.assertError { throwable: Throwable -> throwable.message == "i want exception" }
    }

    @Test
    @Suppress("DEPRECATION")
    fun exception_Observable() {
        val observable =
            createObservable<String>(
                mDatabase,
                false,
                arrayOf("a"),
                Callable { throw Exception("i want exception") }
            )
        val observer = TestObserver<String>()
        observable.subscribe(observer)
        drain()
        observer.assertError { throwable: Throwable -> throwable.message == "i want exception" }
    }

    /**
     * Subscribes twice through [subscribe] and checks the observer lifecycle.
     * - No observer is registered before the first subscription.
     * - Each subscription registers a distinct [InvalidationTracker.Observer].
     * - Disposing a subscription removes exactly the observer it registered.
     */
    private fun assertObserverAddedAndRemovedPerSubscription(subscribe: () -> Disposable) {
        verify(mInvalidationTracker, never()).addObserver(any())

        val first = subscribe()
        verify(mInvalidationTracker).addObserver(any())
        assertThat(mAddedObservers.size).isEqualTo(1)
        val firstObserver = mAddedObservers[0]
        first.dispose()
        verify(mInvalidationTracker).removeObserver(firstObserver)

        val second = subscribe()
        verify(mInvalidationTracker, times(2)).addObserver(any())
        assertThat(mAddedObservers.size).isEqualTo(2)
        val secondObserver = mAddedObservers[1]
        assertThat(secondObserver).isNotSameInstanceAs(firstObserver)
        second.dispose()
        verify(mInvalidationTracker).removeObserver(secondObserver)
    }

    /**
     * Subscribes a fresh [CountingConsumer] through [subscribe] and checks how invalidations of
     * [tableSet] are delivered.
     * - Subscribing emits one item immediately.
     * - Each invalidation on the registered observer emits one more item.
     * - After disposal, invalidations are no longer delivered.
     */
    private fun assertInvalidationsNotifySubscriber(
        tableSet: Set<String>,
        subscribe: (CountingConsumer) -> Disposable
    ) {
        val consumer = CountingConsumer()
        val disposable = subscribe(consumer)
        assertThat(mAddedObservers.size).isEqualTo(1)
        val observer = mAddedObservers[0]
        // Subscribing emits once immediately, without draining any executor.
        assertThat(consumer.count).isEqualTo(1)
        observer.onInvalidated(tableSet)
        assertThat(consumer.count).isEqualTo(2)
        observer.onInvalidated(tableSet)
        assertThat(consumer.count).isEqualTo(3)
        disposable.dispose()
        // Invalidations after disposal must not reach the consumer.
        observer.onInvalidated(tableSet)
        assertThat(consumer.count).isEqualTo(3)
    }

    /**
     * Subscribes a fresh [CountingConsumer] through [subscribe] to a stream backed by a callable
     * that returns [value]. Checks that only non-null callable results are emitted when
     * [tableSet] is invalidated.
     */
    private fun assertCallableResultsForwarded(
        value: AtomicReference<Any>,
        tableSet: Set<String>,
        subscribe: (CountingConsumer) -> Disposable
    ) {
        val consumer = CountingConsumer()
        val disposable = subscribe(consumer)
        drain()
        val observer = mAddedObservers[0]
        // No value yet, because the callable returned null.
        assertThat(consumer.count).isEqualTo(0)
        value.set("bla")
        observer.onInvalidated(tableSet)
        drain()
        // Non-null result is emitted.
        assertThat(consumer.count).isEqualTo(1)
        observer.onInvalidated(tableSet)
        drain()
        assertThat(consumer.count).isEqualTo(2)
        value.set(null)
        observer.onInvalidated(tableSet)
        drain()
        // A null result is skipped, so the count is unchanged.
        assertThat(consumer.count).isEqualTo(2)
        disposable.dispose()
    }

    /** Waits (up to 10 seconds) for the tasks tracked by [mExecutor] to finish. */
    private fun drain() {
        mExecutor.drainTasks(10, TimeUnit.SECONDS)
    }

    /** [Consumer] that only records how many items it has received. */
    private class CountingConsumer : Consumer<Any> {
        var count = 0
            private set

        override fun accept(o: Any) {
            count++
        }
    }
}