1 /* 2 * Copyright (C) 2020 The Dagger Authors. 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 package dagger.hilt.android.internal.lifecycle; 18 19 import static androidx.lifecycle.SavedStateHandleSupport.createSavedStateHandle; 20 21 import android.app.Activity; 22 import android.os.Bundle; 23 import androidx.annotation.NonNull; 24 import androidx.annotation.Nullable; 25 import androidx.lifecycle.ViewModel; 26 import androidx.lifecycle.ViewModelProvider; 27 import androidx.lifecycle.viewmodel.CreationExtras; 28 import androidx.savedstate.SavedStateRegistryOwner; 29 import dagger.Module; 30 import dagger.hilt.EntryPoint; 31 import dagger.hilt.EntryPoints; 32 import dagger.hilt.InstallIn; 33 import dagger.hilt.android.components.ActivityComponent; 34 import dagger.hilt.android.components.ViewModelComponent; 35 import dagger.hilt.android.internal.builders.ViewModelComponentBuilder; 36 import dagger.multibindings.Multibinds; 37 import java.util.Map; 38 import javax.inject.Provider; 39 import kotlin.jvm.functions.Function1; 40 41 /** 42 * View Model Provider Factory for the Hilt Extension. 43 * 44 * <p>A provider for this factory will be installed in the {@link 45 * dagger.hilt.android.components.ActivityComponent} and {@link 46 * dagger.hilt.android.components.FragmentComponent}. An instance of this factory will also be the 47 * default factory by activities and fragments annotated with {@link 48 * dagger.hilt.android.AndroidEntryPoint}. 49 */ 50 public final class HiltViewModelFactory implements ViewModelProvider.Factory { 51 52 /** Hilt entry point for getting the multi-binding map of ViewModels. */ 53 @EntryPoint 54 @InstallIn(ViewModelComponent.class) 55 public interface ViewModelFactoriesEntryPoint { 56 @HiltViewModelMap getHiltViewModelMap()57 Map<Class<?>, Provider<ViewModel>> getHiltViewModelMap(); 58 59 // From ViewModel class names to user defined @AssistedFactory-annotated implementations. 60 @HiltViewModelAssistedMap getHiltViewModelAssistedMap()61 Map<Class<?>, Object> getHiltViewModelAssistedMap(); 62 } 63 64 /** Creation extra key for the callbacks that create @AssistedInject-annotated ViewModels. */ 65 public static final CreationExtras.Key<Function1<Object, ViewModel>> CREATION_CALLBACK_KEY = 66 new CreationExtras.Key<Function1<Object, ViewModel>>() {}; 67 68 /** Hilt module for providing the empty multi-binding map of ViewModels. */ 69 @Module 70 @InstallIn(ViewModelComponent.class) 71 interface ViewModelModule { 72 @Multibinds 73 @HiltViewModelMap hiltViewModelMap()74 Map<Class<?>, ViewModel> hiltViewModelMap(); 75 76 @Multibinds 77 @HiltViewModelAssistedMap hiltViewModelAssistedMap()78 Map<Class<?>, Object> hiltViewModelAssistedMap(); 79 } 80 81 private final Map<Class<?>, Boolean> hiltViewModelKeys; 82 private final ViewModelProvider.Factory delegateFactory; 83 private final ViewModelProvider.Factory hiltViewModelFactory; 84 HiltViewModelFactory( @onNull Map<Class<?>, Boolean> hiltViewModelKeys, @NonNull ViewModelProvider.Factory delegateFactory, @NonNull ViewModelComponentBuilder viewModelComponentBuilder)85 public HiltViewModelFactory( 86 @NonNull Map<Class<?>, Boolean> hiltViewModelKeys, 87 @NonNull ViewModelProvider.Factory delegateFactory, 88 @NonNull ViewModelComponentBuilder viewModelComponentBuilder) { 89 this.hiltViewModelKeys = hiltViewModelKeys; 90 this.delegateFactory = delegateFactory; 91 this.hiltViewModelFactory = 92 new ViewModelProvider.Factory() { 93 @NonNull 94 @Override 95 public <T extends ViewModel> T create( 96 @NonNull Class<T> modelClass, @NonNull CreationExtras extras) { 97 RetainedLifecycleImpl lifecycle = new RetainedLifecycleImpl(); 98 ViewModelComponent component = 99 viewModelComponentBuilder 100 .savedStateHandle(createSavedStateHandle(extras)) 101 .viewModelLifecycle(lifecycle) 102 .build(); 103 T viewModel = createViewModel(component, modelClass, extras); 104 viewModel.addCloseable(lifecycle::dispatchOnCleared); 105 return viewModel; 106 } 107 108 private <T extends ViewModel> T createViewModel( 109 @NonNull ViewModelComponent component, 110 @NonNull Class<T> modelClass, 111 @NonNull CreationExtras extras) { 112 Provider<? extends ViewModel> provider = 113 EntryPoints.get(component, ViewModelFactoriesEntryPoint.class) 114 .getHiltViewModelMap() 115 .get(modelClass); 116 Function1<Object, ViewModel> creationCallback = extras.get(CREATION_CALLBACK_KEY); 117 Object assistedFactory = 118 EntryPoints.get(component, ViewModelFactoriesEntryPoint.class) 119 .getHiltViewModelAssistedMap() 120 .get(modelClass); 121 122 if (assistedFactory == null) { 123 if (creationCallback == null) { 124 if (provider == null) { 125 throw new IllegalStateException( 126 "Expected the @HiltViewModel-annotated class " 127 + modelClass.getName() 128 + " to be available in the multi-binding of " 129 + "@HiltViewModelMap" 130 + " but none was found."); 131 } else { 132 return (T) provider.get(); 133 } 134 } else { 135 // Provider could be null or non-null. 136 throw new IllegalStateException( 137 "Found creation callback but class " 138 + modelClass.getName() 139 + " does not have an assisted factory specified in @HiltViewModel."); 140 } 141 } else { 142 if (provider == null) { 143 if (creationCallback == null) { 144 throw new IllegalStateException( 145 "Found @HiltViewModel-annotated class " 146 + modelClass.getName() 147 + " using @AssistedInject but no creation callback" 148 + " was provided in CreationExtras."); 149 } else { 150 return (T) creationCallback.invoke(assistedFactory); 151 } 152 } else { 153 // Creation callback could be null or non-null. 154 throw new AssertionError( 155 "Found the @HiltViewModel-annotated class " 156 + modelClass.getName() 157 + " in both the multi-bindings of " 158 + "@HiltViewModelMap and @HiltViewModelAssistedMap."); 159 } 160 } 161 } 162 }; 163 } 164 165 @NonNull 166 @Override create( @onNull Class<T> modelClass, @NonNull CreationExtras extras)167 public <T extends ViewModel> T create( 168 @NonNull Class<T> modelClass, @NonNull CreationExtras extras) { 169 if (hiltViewModelKeys.containsKey(modelClass)) { 170 return hiltViewModelFactory.create(modelClass, extras); 171 } else { 172 return delegateFactory.create(modelClass, extras); 173 } 174 } 175 176 @NonNull 177 @Override create(@onNull Class<T> modelClass)178 public <T extends ViewModel> T create(@NonNull Class<T> modelClass) { 179 if (hiltViewModelKeys.containsKey(modelClass)) { 180 return hiltViewModelFactory.create(modelClass); 181 } else { 182 return delegateFactory.create(modelClass); 183 } 184 } 185 186 @EntryPoint 187 @InstallIn(ActivityComponent.class) 188 interface ActivityCreatorEntryPoint { 189 @HiltViewModelMap.KeySet getViewModelKeys()190 Map<Class<?>, Boolean> getViewModelKeys(); 191 getViewModelComponentBuilder()192 ViewModelComponentBuilder getViewModelComponentBuilder(); 193 } 194 createInternal( @onNull Activity activity, @NonNull SavedStateRegistryOwner owner, @Nullable Bundle defaultArgs, @NonNull ViewModelProvider.Factory delegateFactory)195 public static ViewModelProvider.Factory createInternal( 196 @NonNull Activity activity, 197 @NonNull SavedStateRegistryOwner owner, 198 @Nullable Bundle defaultArgs, 199 @NonNull ViewModelProvider.Factory delegateFactory) { 200 return createInternal(activity, delegateFactory); 201 } 202 createInternal( @onNull Activity activity, @NonNull ViewModelProvider.Factory delegateFactory)203 public static ViewModelProvider.Factory createInternal( 204 @NonNull Activity activity, @NonNull ViewModelProvider.Factory delegateFactory) { 205 ActivityCreatorEntryPoint entryPoint = 206 EntryPoints.get(activity, ActivityCreatorEntryPoint.class); 207 return new HiltViewModelFactory( 208 entryPoint.getViewModelKeys(), 209 delegateFactory, 210 entryPoint.getViewModelComponentBuilder() 211 ); 212 } 213 } 214