1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.savedstate.serialization
18 
19 import androidx.savedstate.SavedState
20 import androidx.savedstate.read
21 import androidx.savedstate.savedState
22 import androidx.savedstate.write
23 import kotlin.jvm.JvmOverloads
24 import kotlinx.serialization.ExperimentalSerializationApi
25 import kotlinx.serialization.SerializationException
26 import kotlinx.serialization.SerializationStrategy
27 import kotlinx.serialization.descriptors.SerialDescriptor
28 import kotlinx.serialization.descriptors.StructureKind
29 import kotlinx.serialization.encoding.AbstractEncoder
30 import kotlinx.serialization.encoding.CompositeEncoder
31 import kotlinx.serialization.serializer
32 
/**
 * Encodes [value] into an equivalent [SavedState]. The serializer for the reified type [T] is
 * looked up in the `serializersModule` of the given [configuration].
 *
 * @sample androidx.savedstate.encode
 * @param value The serializable object to encode.
 * @param configuration The [SavedStateConfiguration] to use. Defaults to
 *   [SavedStateConfiguration.DEFAULT].
 * @return The encoded [SavedState].
 * @throws SerializationException in case of any encoding-specific error.
 */
public inline fun <reified T : Any> encodeToSavedState(
    value: T,
    configuration: SavedStateConfiguration = SavedStateConfiguration.DEFAULT,
): SavedState {
    // Resolve the serializer for T once, then delegate to the explicit-serializer overload.
    val strategy = configuration.serializersModule.serializer<T>()
    return encodeToSavedState(
        serializer = strategy,
        value = value,
        configuration = configuration,
    )
}
49 
/**
 * Encodes [value] into a freshly created [SavedState] using the supplied [serializer].
 *
 * @sample androidx.savedstate.encodeWithExplicitSerializerAndConfig
 * @param serializer The serializer to use.
 * @param value The serializable object to encode.
 * @param configuration The [SavedStateConfiguration] to use. Defaults to
 *   [SavedStateConfiguration.DEFAULT].
 * @return The encoded [SavedState].
 * @throws SerializationException in case of any encoding-specific error.
 */
@JvmOverloads
public fun <T : Any> encodeToSavedState(
    serializer: SerializationStrategy<T>,
    value: T,
    configuration: SavedStateConfiguration = SavedStateConfiguration.DEFAULT,
): SavedState =
    savedState().also { target ->
        // The encoder writes straight into `target`, which is then handed back to the caller.
        SavedStateEncoder(target, configuration).encodeSerializableValue(serializer, value)
    }
71 
/**
 * A [kotlinx.serialization.encoding.Encoder] that can encode a serializable object to a
 * [SavedState]. The instance should not be reused after encoding.
 *
 * Every value is stored under [key], which [encodeElement] sets to the serial name of the element
 * about to be encoded. A structure that begins while [key] is still empty is written directly
 * into [savedState]; any other structure is written into a child [SavedState] that is stored
 * under [key] and filled by a new [SavedStateEncoder].
 *
 * @property savedState The [SavedState] to encode to. Has to be empty before encoding.
 */
@OptIn(ExperimentalSerializationApi::class)
internal class SavedStateEncoder(
    internal val savedState: SavedState,
    private val configuration: SavedStateConfiguration
) : AbstractEncoder() {

    /** The key the next value is written under; empty until [encodeElement] has been called. */
    internal var key: String = ""
        private set

    override val serializersModule = configuration.serializersModule

    /** Default-valued properties are encoded only when the configuration asks for it. */
    override fun shouldEncodeElementDefault(descriptor: SerialDescriptor, index: Int): Boolean {
        return configuration.encodeDefaults
    }

    /**
     * Records the name of the element at [index] as the current [key].
     *
     * @return always `true`, so the element is encoded.
     * @throws SerializationException if the element name collides with an already written class
     *   discriminator.
     */
    override fun encodeElement(descriptor: SerialDescriptor, index: Int): Boolean {
        // The key will be property names for classes by default and can be modified with
        // `@SerialName`. The key for collections will be decimal integer Strings ("0",
        // "1", "2", ...).
        key = descriptor.getElementName(index)
        checkDiscriminatorCollisions(key)

        return true
    }

    /**
     * Fails when [elementName] would overwrite the class discriminator already stored in
     * [savedState]. Only relevant in [ClassDiscriminatorMode.ALL_OBJECTS] mode.
     *
     * @throws SerializationException (a subclass of [IllegalArgumentException]) on a collision.
     */
    private fun checkDiscriminatorCollisions(elementName: String) {
        if (configuration.classDiscriminatorMode != ClassDiscriminatorMode.ALL_OBJECTS) return
        // This runs for every encoded element, so do the cheap name comparison first and only
        // read the SavedState when a collision is actually possible.
        if (elementName != CLASS_DISCRIMINATOR_KEY) return
        if (!savedState.read { contains(CLASS_DISCRIMINATOR_KEY) }) return

        val classDiscriminator = savedState.read { getString(CLASS_DISCRIMINATOR_KEY) }
        // SerializationException matches the documented `@throws` contract of
        // `encodeToSavedState` and still is an IllegalArgumentException for existing callers.
        throw SerializationException(
            "SavedStateEncoder for $classDiscriminator has property '$elementName' that " +
                "conflicts with the class discriminator. You can rename a property with " +
                "@SerialName annotation."
        )
    }

    override fun encodeBoolean(value: Boolean) {
        savedState.write { putBoolean(key, value) }
    }

    /** Bytes are widened and stored as an `Int`. */
    override fun encodeByte(value: Byte) {
        savedState.write { putInt(key, value.toInt()) }
    }

    /** Shorts are widened and stored as an `Int`. */
    override fun encodeShort(value: Short) {
        savedState.write { putInt(key, value.toInt()) }
    }

    override fun encodeInt(value: Int) {
        savedState.write { putInt(key, value) }
    }

    override fun encodeLong(value: Long) {
        savedState.write { putLong(key, value) }
    }

    override fun encodeFloat(value: Float) {
        savedState.write { putFloat(key, value) }
    }

    override fun encodeDouble(value: Double) {
        savedState.write { putDouble(key, value) }
    }

    override fun encodeChar(value: Char) {
        savedState.write { putChar(key, value) }
    }

    override fun encodeString(value: String) {
        savedState.write { putString(key, value) }
    }

    /** Enum entries are stored as their [index] within [enumDescriptor]. */
    override fun encodeEnum(enumDescriptor: SerialDescriptor, index: Int) {
        savedState.write { putInt(key, index) }
    }

    override fun encodeNull() {
        savedState.write { putNull(key) }
    }

    private fun encodeIntList(value: List<Int>) {
        savedState.write { putIntList(key, value) }
    }

    private fun encodeStringList(value: List<String>) {
        savedState.write { putStringList(key, value) }
    }

    private fun encodeBooleanArray(value: BooleanArray) {
        savedState.write { putBooleanArray(key, value) }
    }

    private fun encodeCharArray(value: CharArray) {
        savedState.write { putCharArray(key, value) }
    }

    private fun encodeDoubleArray(value: DoubleArray) {
        savedState.write { putDoubleArray(key, value) }
    }

    private fun encodeFloatArray(value: FloatArray) {
        savedState.write { putFloatArray(key, value) }
    }

    private fun encodeIntArray(value: IntArray) {
        savedState.write { putIntArray(key, value) }
    }

    private fun encodeLongArray(value: LongArray) {
        savedState.write { putLongArray(key, value) }
    }

    private fun encodeStringArray(value: Array<String>) {
        savedState.write { putStringArray(key, value) }
    }

    /**
     * Starts encoding a structure described by [descriptor].
     *
     * @return this encoder when no element key has been set yet, otherwise a new encoder that
     *   writes into a child [SavedState] stored under the current [key].
     */
    override fun beginStructure(descriptor: SerialDescriptor): CompositeEncoder {
        // We flatten single structured object at root to prevent encoding to a
        // SavedState containing only one SavedState inside. For example, a
        // `Pair(3, 5)` would become `{"first" = 3, "second" = 5}` instead of
        // `{{"first" = 3, "second" = 5}}`, which is more consistent but less
        // efficient.
        return if (key == "") {
            putClassDiscriminatorIfRequired(descriptor, savedState)
            this
        } else {
            val childState = savedState()
            savedState.write { putSavedState(key, childState) } // Link child to parent.
            putClassDiscriminatorIfRequired(descriptor, childState)
            SavedStateEncoder(childState, configuration)
        }
    }

    /**
     * Writes the serial name of [descriptor] into [target] under the class discriminator key.
     *
     * Nothing is written unless the mode is [ClassDiscriminatorMode.ALL_OBJECTS], [target] has no
     * discriminator yet, and [descriptor] is of kind [StructureKind.CLASS] or
     * [StructureKind.OBJECT].
     */
    private fun putClassDiscriminatorIfRequired(
        descriptor: SerialDescriptor,
        target: SavedState,
    ) {
        // POLYMORPHIC is handled by kotlinx.serialization.PolymorphicSerializer.
        if (configuration.classDiscriminatorMode != ClassDiscriminatorMode.ALL_OBJECTS) {
            return
        }

        // Keep a discriminator that has already been written.
        if (target.read { contains(CLASS_DISCRIMINATOR_KEY) }) {
            return
        }

        if (descriptor.kind == StructureKind.CLASS || descriptor.kind == StructureKind.OBJECT) {
            target.write { putString(CLASS_DISCRIMINATOR_KEY, descriptor.serialName) }
        }
    }

    /**
     * Encodes [value] with SavedState's special representation when one exists for the
     * [serializer]'s descriptor, and falls back to the regular encoding otherwise.
     */
    override fun <T> encodeSerializableValue(serializer: SerializationStrategy<T>, value: T) {
        val encoded = encodeFormatSpecificTypes(serializer, value)
        if (!encoded) {
            super.encodeSerializableValue(serializer, value)
        }
    }

    /**
     * Tries the platform-specific representations first, then the common list and array ones.
     *
     * @return `true` if [value] was encoded with SavedState's special representation, `false`
     *   otherwise.
     */
    @Suppress("UNCHECKED_CAST")
    private fun <T> encodeFormatSpecificTypes(
        serializer: SerializationStrategy<T>,
        value: T
    ): Boolean {
        if (encodeFormatSpecificTypesOnPlatform(serializer, value)) return true

        // The descriptor identifies the value's type, which is what makes the casts below valid.
        when (serializer.descriptor) {
            intListDescriptor -> encodeIntList(value as List<Int>)
            stringListDescriptor -> encodeStringList(value as List<String>)
            booleanArrayDescriptor -> encodeBooleanArray(value as BooleanArray)
            charArrayDescriptor -> encodeCharArray(value as CharArray)
            doubleArrayDescriptor -> encodeDoubleArray(value as DoubleArray)
            floatArrayDescriptor -> encodeFloatArray(value as FloatArray)
            intArrayDescriptor -> encodeIntArray(value as IntArray)
            longArrayDescriptor -> encodeLongArray(value as LongArray)
            stringArrayDescriptor -> encodeStringArray(value as Array<String>)
            else -> return false
        }
        return true
    }
}
272 
/**
 * Platform-specific hook for encoding [value] with a SavedState representation that is only
 * available on the current platform. Declared with `expect`, so each platform supplies its own
 * implementation. [SavedStateEncoder] calls it before trying the common list and array
 * representations.
 *
 * @param strategy The serializer describing [value].
 * @param value The value to encode.
 * @return `true` if [value] was encoded with SavedState's special representation, `false`
 *   otherwise.
 */
internal expect fun <T> SavedStateEncoder.encodeFormatSpecificTypesOnPlatform(
    strategy: SerializationStrategy<T>,
    value: T
): Boolean
281