1 /* 2 * Copyright (C) 2020 The Dagger Authors. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package dagger.hilt.android; 18 19 import static com.google.common.truth.Truth.assertThat; 20 21 import android.os.Build; 22 import androidx.fragment.app.Fragment; 23 import androidx.fragment.app.FragmentActivity; 24 import androidx.lifecycle.ViewModel; 25 import androidx.lifecycle.ViewModelProvider; 26 import androidx.test.core.app.ActivityScenario; 27 import androidx.test.ext.junit.runners.AndroidJUnit4; 28 import dagger.hilt.android.lifecycle.HiltViewModel; 29 import dagger.hilt.android.testing.BindValue; 30 import dagger.hilt.android.testing.HiltAndroidRule; 31 import dagger.hilt.android.testing.HiltAndroidTest; 32 import dagger.hilt.android.testing.HiltTestApplication; 33 import javax.inject.Inject; 34 import org.junit.Rule; 35 import org.junit.Test; 36 import org.junit.runner.RunWith; 37 import org.robolectric.annotation.Config; 38 39 @HiltAndroidTest 40 @RunWith(AndroidJUnit4.class) 41 // Robolectric requires Java9 to run API 29 and above, so use API 28 instead 42 @Config(sdk = Build.VERSION_CODES.P, application = HiltTestApplication.class) 43 public class DefaultViewModelFactoryTest { 44 45 @Rule public final HiltAndroidRule rule = new HiltAndroidRule(this); 46 47 @BindValue String hiltStringValue = "hilt"; 48 49 @Test activityFactoryFallsBackToBase()50 public void activityFactoryFallsBackToBase() { 51 try (ActivityScenario<TestActivity> scenario = ActivityScenario.launch(TestActivity.class)) { 52 scenario.onActivity( 53 activity -> { 54 assertThat(new ViewModelProvider(activity).get(TestHiltViewModel.class).value) 55 .isEqualTo("hilt"); 56 assertThat(new ViewModelProvider(activity).get(TestViewModel.class).value) 57 .isEqualTo("non-hilt"); 58 }); 59 } 60 } 61 62 @Test fragmentFactoryFallbsBackToBase()63 public void fragmentFactoryFallbsBackToBase() { 64 // TODO(danysantiago): Use FragmentScenario when it becomes available. 65 try (ActivityScenario<TestActivity> scenario = ActivityScenario.launch(TestActivity.class)) { 66 scenario.onActivity( 67 activity -> { 68 TestFragment fragment = new TestFragment(); 69 activity.getSupportFragmentManager().beginTransaction().add(fragment, "").commitNow(); 70 assertThat(new ViewModelProvider(fragment).get(TestHiltViewModel.class).value) 71 .isEqualTo("hilt"); 72 assertThat(new ViewModelProvider(fragment).get(TestViewModel.class).value) 73 .isEqualTo("non-hilt"); 74 }); 75 } 76 } 77 78 @HiltViewModel 79 public static final class TestHiltViewModel extends ViewModel { 80 final String value; 81 82 @Inject TestHiltViewModel(String value)83 TestHiltViewModel(String value) { 84 this.value = value; 85 } 86 } 87 88 public static final class TestViewModel extends ViewModel { 89 final String value; 90 // Take in a string so it cannot be constructed by the default view model factory TestViewModel(String value)91 public TestViewModel(String value) { 92 this.value = value; 93 } 94 } 95 96 @AndroidEntryPoint(BaseActivity.class) 97 public static final class TestActivity extends Hilt_DefaultViewModelFactoryTest_TestActivity {} 98 99 public static class BaseActivity extends FragmentActivity { 100 @SuppressWarnings("unchecked") getDefaultViewModelProviderFactory()101 @Override public ViewModelProvider.Factory getDefaultViewModelProviderFactory() { 102 return new ViewModelProvider.Factory() { 103 @Override public <T extends ViewModel> T create(Class<T> clazz) { 104 assertThat(clazz).isEqualTo(TestViewModel.class); 105 return (T) new TestViewModel("non-hilt"); 106 } 107 }; 108 } 109 } 110 111 @AndroidEntryPoint(BaseFragment.class) 112 public static final class TestFragment extends Hilt_DefaultViewModelFactoryTest_TestFragment {} 113 114 public static class BaseFragment extends Fragment { 115 @SuppressWarnings("unchecked") 116 @Override public ViewModelProvider.Factory getDefaultViewModelProviderFactory() { 117 return new ViewModelProvider.Factory() { 118 @Override public <T extends ViewModel> T create(Class<T> clazz) { 119 assertThat(clazz).isEqualTo(TestViewModel.class); 120 return (T) new TestViewModel("non-hilt"); 121 } 122 }; 123 } 124 } 125 } 126