1 /*
 * Copyright (C) 2017 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 package androidx.room.rxjava3
17 
18 import androidx.arch.core.executor.ArchTaskExecutor
19 import androidx.arch.core.executor.testing.CountingTaskExecutorRule
20 import androidx.kruth.assertThat
21 import androidx.room.InvalidationTracker
22 import androidx.room.RoomDatabase
23 import io.reactivex.rxjava3.functions.Consumer
24 import io.reactivex.rxjava3.observers.TestObserver
25 import io.reactivex.rxjava3.subscribers.TestSubscriber
26 import java.util.concurrent.Callable
27 import java.util.concurrent.TimeUnit
28 import java.util.concurrent.atomic.AtomicReference
29 import org.junit.Before
30 import org.junit.Rule
31 import org.junit.Test
32 import org.junit.runner.RunWith
33 import org.junit.runners.JUnit4
34 import org.mockito.invocation.InvocationOnMock
35 import org.mockito.kotlin.any
36 import org.mockito.kotlin.doAnswer
37 import org.mockito.kotlin.mock
38 import org.mockito.kotlin.never
39 import org.mockito.kotlin.times
40 import org.mockito.kotlin.verify
41 import org.mockito.kotlin.whenever
42 
/**
 * Unit tests for the RxJava 3 `createFlowable` / `createObservable` factories in this package.
 *
 * The [RoomDatabase] and its [InvalidationTracker] are Mockito mocks. Every observer passed to
 * [InvalidationTracker.addObserver] is captured in [mAddedObservers], so a test can fire
 * invalidations by hand and verify that the very same observer is removed on dispose.
 */
@RunWith(JUnit4::class)
class RxRoomTest {
    /** Counting task executor used by [drain] to wait for work queued on the IO executor. */
    @get:Rule val mExecutor = CountingTaskExecutorRule()
    private lateinit var mDatabase: RoomDatabase
    private lateinit var mInvalidationTracker: InvalidationTracker

    /** Every observer registered through [InvalidationTracker.addObserver], in call order. */
    private val mAddedObservers: MutableList<InvalidationTracker.Observer> = ArrayList()

    @Before
    fun init() {
        mDatabase = mock()
        mInvalidationTracker = mock()
        whenever(mDatabase.invalidationTracker).thenReturn(mInvalidationTracker)
        // Query work runs on the ArchTaskExecutor IO executor so that drain() can wait for it.
        whenever(mDatabase.queryExecutor).thenReturn(ArchTaskExecutor.getIOThreadExecutor())
        // Record each registered observer so tests can invoke onInvalidated() directly.
        doAnswer { invocation: InvocationOnMock ->
                mAddedObservers.add(invocation.arguments[0] as InvalidationTracker.Observer)
                null
            }
            .whenever(mInvalidationTracker)
            .addObserver(any())
    }

    @Test
    fun basicAddRemove_Flowable() {
        val flowable = createFlowable(mDatabase, "a", "b")
        // Creating the stream alone must not register anything.
        verify(mInvalidationTracker, never()).addObserver(any())
        var disposable = flowable.subscribe()
        verify(mInvalidationTracker).addObserver(any())
        assertThat(mAddedObservers.size).isEqualTo(1)
        val observer = mAddedObservers[0]
        disposable.dispose()
        verify(mInvalidationTracker).removeObserver(observer)

        // A second subscription registers a fresh observer instance.
        disposable = flowable.subscribe()
        verify(mInvalidationTracker, times(2)).addObserver(any())
        assertThat(mAddedObservers.size).isEqualTo(2)
        assertThat(mAddedObservers[1]).isNotSameInstanceAs(observer)

        val observer2 = mAddedObservers[1]
        disposable.dispose()
        verify(mInvalidationTracker).removeObserver(observer2)
    }

    @Test
    fun basicAddRemove_Observable() {
        val observable = createObservable(mDatabase, "a", "b")
        // Creating the stream alone must not register anything.
        verify(mInvalidationTracker, never()).addObserver(any())
        var disposable = observable.subscribe()
        verify(mInvalidationTracker).addObserver(any())
        assertThat(mAddedObservers.size).isEqualTo(1)
        val observer = mAddedObservers[0]
        disposable.dispose()
        verify(mInvalidationTracker).removeObserver(observer)

        // A second subscription registers a fresh observer instance.
        disposable = observable.subscribe()
        verify(mInvalidationTracker, times(2)).addObserver(any())
        assertThat(mAddedObservers.size).isEqualTo(2)
        assertThat(mAddedObservers[1]).isNotSameInstanceAs(observer)

        val observer2 = mAddedObservers[1]
        disposable.dispose()
        verify(mInvalidationTracker).removeObserver(observer2)
    }

    @Test
    fun basicNotify_Flowable() {
        val tables = arrayOf("a", "b")
        val tableSet: Set<String> = tables.toSet()
        val flowable = createFlowable(mDatabase, *tables)
        val consumer = CountingConsumer()
        val disposable = flowable.subscribe(consumer)
        assertThat(mAddedObservers.size).isEqualTo(1)
        val observer = mAddedObservers[0]
        // One emission on subscribe, then one per invalidation.
        assertThat(consumer.mCount).isEqualTo(1)
        observer.onInvalidated(tableSet)
        assertThat(consumer.mCount).isEqualTo(2)
        observer.onInvalidated(tableSet)
        assertThat(consumer.mCount).isEqualTo(3)
        disposable.dispose()
        // Invalidations after dispose must not reach the consumer.
        observer.onInvalidated(tableSet)
        assertThat(consumer.mCount).isEqualTo(3)
    }

    @Test
    fun basicNotify_Observable() {
        val tables = arrayOf("a", "b")
        val tableSet: Set<String> = tables.toSet()
        val observable = createObservable(mDatabase, *tables)
        val consumer = CountingConsumer()
        val disposable = observable.subscribe(consumer)
        assertThat(mAddedObservers.size).isEqualTo(1)
        val observer = mAddedObservers[0]
        // One emission on subscribe, then one per invalidation.
        assertThat(consumer.mCount).isEqualTo(1)
        observer.onInvalidated(tableSet)
        assertThat(consumer.mCount).isEqualTo(2)
        observer.onInvalidated(tableSet)
        assertThat(consumer.mCount).isEqualTo(3)
        disposable.dispose()
        // Invalidations after dispose must not reach the consumer.
        observer.onInvalidated(tableSet)
        assertThat(consumer.mCount).isEqualTo(3)
    }

    @Test
    @Suppress("DEPRECATION")
    fun internalCallable_Flowable() {
        val value = AtomicReference<Any>(null)
        val tables = arrayOf("a", "b")
        val tableSet: Set<String> = tables.toSet()
        val flowable = createFlowable(mDatabase, false, tables, Callable { value.get() })
        val consumer = CountingConsumer()
        val disposable = flowable.subscribe(consumer)
        drain()
        val observer = mAddedObservers[0]
        // no value because it is null
        assertThat(consumer.mCount).isEqualTo(0)
        value.set("bla")
        observer.onInvalidated(tableSet)
        drain()
        // get value
        assertThat(consumer.mCount).isEqualTo(1)
        observer.onInvalidated(tableSet)
        drain()
        // get value
        assertThat(consumer.mCount).isEqualTo(2)
        value.set(null)
        observer.onInvalidated(tableSet)
        drain()
        // no value: a null result from the callable is skipped, not emitted
        assertThat(consumer.mCount).isEqualTo(2)
        disposable.dispose()
    }

    @Test
    @Suppress("DEPRECATION")
    fun internalCallable_Observable() {
        val value = AtomicReference<Any>(null)
        val tables = arrayOf("a", "b")
        val tableSet: Set<String> = tables.toSet()
        val observable = createObservable(mDatabase, false, tables, Callable { value.get() })
        val consumer = CountingConsumer()
        val disposable = observable.subscribe(consumer)
        drain()
        val observer = mAddedObservers[0]
        // no value because it is null
        assertThat(consumer.mCount).isEqualTo(0)
        value.set("bla")
        observer.onInvalidated(tableSet)
        drain()
        // get value
        assertThat(consumer.mCount).isEqualTo(1)
        observer.onInvalidated(tableSet)
        drain()
        // get value
        assertThat(consumer.mCount).isEqualTo(2)
        value.set(null)
        observer.onInvalidated(tableSet)
        drain()
        // no value: a null result from the callable is skipped, not emitted
        assertThat(consumer.mCount).isEqualTo(2)
        disposable.dispose()
    }

    @Test
    @Suppress("DEPRECATION")
    fun exception_Flowable() {
        val flowable =
            createFlowable<String>(
                mDatabase,
                false,
                arrayOf("a"),
                Callable { throw Exception("i want exception") }
            )
        val subscriber = TestSubscriber<String>()
        flowable.subscribe(subscriber)
        drain()
        // The callable's exception must surface as the stream's error.
        subscriber.assertError { throwable: Throwable -> throwable.message == "i want exception" }
    }

    @Test
    @Suppress("DEPRECATION")
    fun exception_Observable() {
        val observable =
            createObservable<String>(
                mDatabase,
                false,
                arrayOf("a"),
                Callable { throw Exception("i want exception") }
            )
        val observer = TestObserver<String>()
        observable.subscribe(observer)
        drain()
        // The callable's exception must surface as the stream's error.
        observer.assertError { throwable: Throwable -> throwable.message == "i want exception" }
    }

    /** Blocks until the tasks tracked by [mExecutor] finish, waiting at most 10 seconds. */
    @Throws(Exception::class)
    private fun drain() {
        mExecutor.drainTasks(10, TimeUnit.SECONDS)
    }

    /** [Consumer] that only records how many items it has received. */
    private class CountingConsumer : Consumer<Any> {
        var mCount = 0
            private set

        override fun accept(o: Any) {
            mCount++
        }
    }
}
247