• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright (C) 2020 The Dagger Authors.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package dagger.hilt.android.internal.lifecycle;
18 
19 import static androidx.lifecycle.SavedStateHandleSupport.createSavedStateHandle;
20 
21 import android.app.Activity;
22 import android.os.Bundle;
23 import androidx.annotation.NonNull;
24 import androidx.annotation.Nullable;
25 import androidx.lifecycle.ViewModel;
26 import androidx.lifecycle.ViewModelProvider;
27 import androidx.lifecycle.viewmodel.CreationExtras;
28 import androidx.savedstate.SavedStateRegistryOwner;
29 import dagger.Module;
30 import dagger.hilt.EntryPoint;
31 import dagger.hilt.EntryPoints;
32 import dagger.hilt.InstallIn;
33 import dagger.hilt.android.components.ActivityComponent;
34 import dagger.hilt.android.components.ViewModelComponent;
35 import dagger.hilt.android.internal.builders.ViewModelComponentBuilder;
36 import dagger.multibindings.Multibinds;
37 import java.util.Map;
38 import javax.inject.Provider;
39 import kotlin.jvm.functions.Function1;
40 
41 /**
42  * View Model Provider Factory for the Hilt Extension.
43  *
44  * <p>A provider for this factory will be installed in the {@link
45  * dagger.hilt.android.components.ActivityComponent} and {@link
46  * dagger.hilt.android.components.FragmentComponent}. An instance of this factory will also be the
47  * default factory by activities and fragments annotated with {@link
48  * dagger.hilt.android.AndroidEntryPoint}.
49  */
50 public final class HiltViewModelFactory implements ViewModelProvider.Factory {
51 
52   /** Hilt entry point for getting the multi-binding map of ViewModels. */
53   @EntryPoint
54   @InstallIn(ViewModelComponent.class)
55   public interface ViewModelFactoriesEntryPoint {
56     @HiltViewModelMap
getHiltViewModelMap()57     Map<Class<?>, Provider<ViewModel>> getHiltViewModelMap();
58 
59     // From ViewModel class names to user defined @AssistedFactory-annotated implementations.
60     @HiltViewModelAssistedMap
getHiltViewModelAssistedMap()61     Map<Class<?>, Object> getHiltViewModelAssistedMap();
62   }
63 
64   /** Creation extra key for the callbacks that create @AssistedInject-annotated ViewModels. */
65   public static final CreationExtras.Key<Function1<Object, ViewModel>> CREATION_CALLBACK_KEY =
66       new CreationExtras.Key<Function1<Object, ViewModel>>() {};
67 
68   /** Hilt module for providing the empty multi-binding map of ViewModels. */
69   @Module
70   @InstallIn(ViewModelComponent.class)
71   interface ViewModelModule {
72     @Multibinds
73     @HiltViewModelMap
hiltViewModelMap()74     Map<Class<?>, ViewModel> hiltViewModelMap();
75 
76     @Multibinds
77     @HiltViewModelAssistedMap
hiltViewModelAssistedMap()78     Map<Class<?>, Object> hiltViewModelAssistedMap();
79   }
80 
81   private final Map<Class<?>, Boolean> hiltViewModelKeys;
82   private final ViewModelProvider.Factory delegateFactory;
83   private final ViewModelProvider.Factory hiltViewModelFactory;
84 
HiltViewModelFactory( @onNull Map<Class<?>, Boolean> hiltViewModelKeys, @NonNull ViewModelProvider.Factory delegateFactory, @NonNull ViewModelComponentBuilder viewModelComponentBuilder)85   public HiltViewModelFactory(
86       @NonNull Map<Class<?>, Boolean> hiltViewModelKeys,
87       @NonNull ViewModelProvider.Factory delegateFactory,
88       @NonNull ViewModelComponentBuilder viewModelComponentBuilder) {
89     this.hiltViewModelKeys = hiltViewModelKeys;
90     this.delegateFactory = delegateFactory;
91     this.hiltViewModelFactory =
92         new ViewModelProvider.Factory() {
93           @NonNull
94           @Override
95           public <T extends ViewModel> T create(
96               @NonNull Class<T> modelClass, @NonNull CreationExtras extras) {
97             RetainedLifecycleImpl lifecycle = new RetainedLifecycleImpl();
98             ViewModelComponent component =
99                 viewModelComponentBuilder
100                     .savedStateHandle(createSavedStateHandle(extras))
101                     .viewModelLifecycle(lifecycle)
102                     .build();
103             T viewModel = createViewModel(component, modelClass, extras);
104             viewModel.addCloseable(lifecycle::dispatchOnCleared);
105             return viewModel;
106           }
107 
108           private <T extends ViewModel> T createViewModel(
109               @NonNull ViewModelComponent component,
110               @NonNull Class<T> modelClass,
111               @NonNull CreationExtras extras) {
112             Provider<? extends ViewModel> provider =
113                 EntryPoints.get(component, ViewModelFactoriesEntryPoint.class)
114                     .getHiltViewModelMap()
115                     .get(modelClass);
116             Function1<Object, ViewModel> creationCallback = extras.get(CREATION_CALLBACK_KEY);
117             Object assistedFactory =
118                 EntryPoints.get(component, ViewModelFactoriesEntryPoint.class)
119                     .getHiltViewModelAssistedMap()
120                     .get(modelClass);
121 
122             if (assistedFactory == null) {
123               if (creationCallback == null) {
124                 if (provider == null) {
125                   throw new IllegalStateException(
126                       "Expected the @HiltViewModel-annotated class "
127                           + modelClass.getName()
128                           + " to be available in the multi-binding of "
129                           + "@HiltViewModelMap"
130                           + " but none was found.");
131                 } else {
132                   return (T) provider.get();
133                 }
134               } else {
135                 // Provider could be null or non-null.
136                 throw new IllegalStateException(
137                     "Found creation callback but class "
138                         + modelClass.getName()
139                         + " does not have an assisted factory specified in @HiltViewModel.");
140               }
141             } else {
142               if (provider == null) {
143                 if (creationCallback == null) {
144                   throw new IllegalStateException(
145                       "Found @HiltViewModel-annotated class "
146                           + modelClass.getName()
147                           + " using @AssistedInject but no creation callback"
148                           + " was provided in CreationExtras.");
149                 } else {
150                   return (T) creationCallback.invoke(assistedFactory);
151                 }
152               } else {
153                 // Creation callback could be null or non-null.
154                 throw new AssertionError(
155                     "Found the @HiltViewModel-annotated class "
156                         + modelClass.getName()
157                         + " in both the multi-bindings of "
158                         + "@HiltViewModelMap and @HiltViewModelAssistedMap.");
159               }
160             }
161           }
162         };
163   }
164 
165   @NonNull
166   @Override
create( @onNull Class<T> modelClass, @NonNull CreationExtras extras)167   public <T extends ViewModel> T create(
168       @NonNull Class<T> modelClass, @NonNull CreationExtras extras) {
169     if (hiltViewModelKeys.containsKey(modelClass)) {
170       return hiltViewModelFactory.create(modelClass, extras);
171     } else {
172       return delegateFactory.create(modelClass, extras);
173     }
174   }
175 
176   @NonNull
177   @Override
create(@onNull Class<T> modelClass)178   public <T extends ViewModel> T create(@NonNull Class<T> modelClass) {
179     if (hiltViewModelKeys.containsKey(modelClass)) {
180       return hiltViewModelFactory.create(modelClass);
181     } else {
182       return delegateFactory.create(modelClass);
183     }
184   }
185 
186   @EntryPoint
187   @InstallIn(ActivityComponent.class)
188   interface ActivityCreatorEntryPoint {
189     @HiltViewModelMap.KeySet
getViewModelKeys()190     Map<Class<?>, Boolean> getViewModelKeys();
191 
getViewModelComponentBuilder()192     ViewModelComponentBuilder getViewModelComponentBuilder();
193   }
194 
createInternal( @onNull Activity activity, @NonNull SavedStateRegistryOwner owner, @Nullable Bundle defaultArgs, @NonNull ViewModelProvider.Factory delegateFactory)195   public static ViewModelProvider.Factory createInternal(
196       @NonNull Activity activity,
197       @NonNull SavedStateRegistryOwner owner,
198       @Nullable Bundle defaultArgs,
199       @NonNull ViewModelProvider.Factory delegateFactory) {
200     return createInternal(activity, delegateFactory);
201   }
202 
createInternal( @onNull Activity activity, @NonNull ViewModelProvider.Factory delegateFactory)203   public static ViewModelProvider.Factory createInternal(
204       @NonNull Activity activity, @NonNull ViewModelProvider.Factory delegateFactory) {
205     ActivityCreatorEntryPoint entryPoint =
206         EntryPoints.get(activity, ActivityCreatorEntryPoint.class);
207     return new HiltViewModelFactory(
208         entryPoint.getViewModelKeys(),
209         delegateFactory,
210         entryPoint.getViewModelComponentBuilder()
211     );
212   }
213 }
214