1 /*
2  * Copyright 2024 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 @file:OptIn(ExperimentalSerializationApi::class)
18 
19 package androidx.navigation.serialization
20 
21 import androidx.navigation.CollectionNavType
22 import androidx.navigation.NavType
23 import androidx.navigation.NavUriUtils
24 import androidx.savedstate.SavedState
25 import androidx.savedstate.read
26 import androidx.savedstate.write
27 import kotlin.reflect.KType
28 import kotlinx.serialization.ExperimentalSerializationApi
29 import kotlinx.serialization.descriptors.SerialDescriptor
30 import kotlinx.serialization.descriptors.SerialKind
31 import kotlinx.serialization.serializerOrNull
32 
/**
 * Classification markers for the native Kotlin types that have full or partial built-in NavType
 * support. Anything that cannot be classified is tagged [UNKNOWN].
 */
private enum class InternalType {
    // primitives and strings, each paired with its nullable counterpart
    INT, INT_NULLABLE,
    BOOL, BOOL_NULLABLE,
    DOUBLE, DOUBLE_NULLABLE,
    FLOAT, FLOAT_NULLABLE,
    LONG, LONG_NULLABLE,
    STRING, STRING_NULLABLE,

    // primitive arrays
    INT_ARRAY,
    BOOL_ARRAY,
    DOUBLE_ARRAY,
    FLOAT_ARRAY,
    LONG_ARRAY,

    // generic containers
    ARRAY,
    LIST,

    // enums
    ENUM, ENUM_NULLABLE,

    // no marker applies
    UNKNOWN,
}
58 
/**
 * Converts an argument type to a built-in NavType.
 *
 * Built-in NavTypes include NavType objects declared within [NavType.Companion], such as
 * [NavType.IntType], [NavType.BoolArrayType] etc., as well as the internal ones declared in
 * [InternalNavType].
 *
 * For arrays and lists the result depends on the descriptor of the first type parameter
 * (`getElementDescriptor(0)`); element types without a dedicated NavType yield [UNKNOWN].
 *
 * Returns [UNKNOWN] type if the argument does not have built-in NavType support.
 */
internal fun SerialDescriptor.getNavType(): NavType<*> =
    // Exhaustive on purpose (no `else`): adding an InternalType entry must fail compilation here
    // instead of silently mapping to UNKNOWN.
    when (toInternalType()) {
        InternalType.INT -> NavType.IntType
        InternalType.INT_NULLABLE -> InternalNavType.IntNullableType
        InternalType.BOOL -> NavType.BoolType
        InternalType.BOOL_NULLABLE -> InternalNavType.BoolNullableType
        InternalType.DOUBLE -> InternalNavType.DoubleType
        InternalType.DOUBLE_NULLABLE -> InternalNavType.DoubleNullableType
        InternalType.FLOAT -> NavType.FloatType
        InternalType.FLOAT_NULLABLE -> InternalNavType.FloatNullableType
        InternalType.LONG -> NavType.LongType
        InternalType.LONG_NULLABLE -> InternalNavType.LongNullableType
        InternalType.STRING -> InternalNavType.StringNonNullableType
        InternalType.STRING_NULLABLE -> NavType.StringType
        InternalType.INT_ARRAY -> NavType.IntArrayType
        InternalType.BOOL_ARRAY -> NavType.BoolArrayType
        InternalType.DOUBLE_ARRAY -> InternalNavType.DoubleArrayType
        InternalType.FLOAT_ARRAY -> NavType.FloatArrayType
        InternalType.LONG_ARRAY -> NavType.LongArrayType
        InternalType.ARRAY ->
            // generic arrays are only supported for String / String? elements
            when (getElementDescriptor(0).toInternalType()) {
                InternalType.STRING -> NavType.StringArrayType
                InternalType.STRING_NULLABLE -> InternalNavType.StringNullableArrayType
                else -> UNKNOWN
            }
        InternalType.LIST ->
            when (getElementDescriptor(0).toInternalType()) {
                InternalType.INT -> NavType.IntListType
                InternalType.BOOL -> NavType.BoolListType
                InternalType.DOUBLE -> InternalNavType.DoubleListType
                InternalType.FLOAT -> NavType.FloatListType
                InternalType.LONG -> NavType.LongListType
                InternalType.STRING -> NavType.StringListType
                InternalType.STRING_NULLABLE -> InternalNavType.StringNullableListType
                InternalType.ENUM -> parseEnumList()
                else -> UNKNOWN
            }
        InternalType.ENUM -> parseEnum()
        InternalType.ENUM_NULLABLE -> parseNullableEnum()
        // the right-hand side is the top-level UNKNOWN NavType, not the enum entry
        InternalType.UNKNOWN -> UNKNOWN
    }
115 
/**
 * Platform-specific (`expect`) resolver that [getNavType] calls for a descriptor classified as a
 * non-nullable enum.
 *
 * NOTE(review): presumably returns a NavType bound to the concrete enum class; confirm against
 * the `actual` implementations.
 */
internal expect fun SerialDescriptor.parseEnum(): NavType<*>

/**
 * Platform-specific (`expect`) resolver that [getNavType] calls for a descriptor classified as a
 * nullable enum.
 *
 * NOTE(review): presumably the nullable counterpart of [parseEnum]; confirm against the `actual`
 * implementations.
 */
internal expect fun SerialDescriptor.parseNullableEnum(): NavType<*>

/**
 * Platform-specific (`expect`) resolver that [getNavType] calls for a list descriptor whose
 * element descriptor is classified as a non-nullable enum. It is invoked on the list descriptor
 * itself, not on the element descriptor.
 */
internal expect fun SerialDescriptor.parseEnumList(): NavType<*>
121 
/**
 * Classifies this [SerialDescriptor] as one of the [InternalType] markers.
 *
 * Enum descriptors are recognised through their [SerialDescriptor.kind]; every other descriptor
 * is matched on its serial name with all `?` characters removed. A descriptor that matches
 * nothing (custom class, object, ...) is reported as [InternalType.UNKNOWN].
 */
private fun SerialDescriptor.toInternalType(): InternalType {
    // Selects the nullable or the non-null marker depending on this descriptor's nullability.
    fun byNullability(nonNull: InternalType, nullable: InternalType): InternalType =
        if (isNullable) nullable else nonNull

    // enums are detected by kind and take precedence over any serial-name match
    if (kind == SerialKind.ENUM) {
        return byNullability(InternalType.ENUM, InternalType.ENUM_NULLABLE)
    }

    val plainName = serialName.replace("?", "")
    return when (plainName) {
        "kotlin.Int" -> byNullability(InternalType.INT, InternalType.INT_NULLABLE)
        "kotlin.Boolean" -> byNullability(InternalType.BOOL, InternalType.BOOL_NULLABLE)
        "kotlin.Double" -> byNullability(InternalType.DOUBLE, InternalType.DOUBLE_NULLABLE)
        "kotlin.Float" -> byNullability(InternalType.FLOAT, InternalType.FLOAT_NULLABLE)
        "kotlin.Long" -> byNullability(InternalType.LONG, InternalType.LONG_NULLABLE)
        "kotlin.String" -> byNullability(InternalType.STRING, InternalType.STRING_NULLABLE)
        "kotlin.IntArray" -> InternalType.INT_ARRAY
        "kotlin.DoubleArray" -> InternalType.DOUBLE_ARRAY
        "kotlin.BooleanArray" -> InternalType.BOOL_ARRAY
        "kotlin.FloatArray" -> InternalType.FLOAT_ARRAY
        "kotlin.LongArray" -> InternalType.LONG_ARRAY
        "kotlin.Array" -> InternalType.ARRAY
        else -> {
            // serial name for both List and ArrayList
            val isList = plainName.startsWith("kotlin.collections.ArrayList")
            // everything else: custom classes or other types without built-in NavTypes
            if (isList) InternalType.LIST else InternalType.UNKNOWN
        }
    }
}
156 
/**
 * Checks whether this [SerialDescriptor] describes the same type as [kType].
 *
 * The two match when their nullability agrees and this descriptor equals the descriptor of the
 * serializer resolved for [kType].
 *
 * @param kType the type to compare this descriptor against
 * @return true if match, false otherwise
 * @throws IllegalStateException if nullability agrees but no serializer can be resolved for
 *   [kType]
 */
internal fun SerialDescriptor.matchKType(kType: KType): Boolean {
    // differing nullability settles the answer before any serializer lookup
    val sameNullability = isNullable == kType.isMarkedNullable
    if (!sameNullability) return false

    val resolved =
        serializerOrNull(kType)
            ?: error(
                "Cannot find KSerializer for [${this.serialName}]. If applicable, custom KSerializers " +
                    "for custom and third-party KType is currently not supported when declared " +
                    "directly on a class field via @Serializable(with = ...). " +
                    "Please use @Serializable or @Serializable(with = ...) on the " +
                    "class or object declaration."
            )
    return this == resolved.descriptor
}
174 
/**
 * Placeholder [NavType] that [getNavType] returns for arguments without built-in NavType support.
 *
 * It stores nothing and reads nothing: [put] is a no-op, [get] always yields `null` and
 * [parseValue] always yields the literal string `"null"`.
 */
internal object UNKNOWN : NavType<String>(false) {
    override val name: String
        get() = "unknown"

    /** Does nothing; no value is written for this type. */
    override fun put(bundle: SavedState, key: String, value: String) = Unit

    /** Always `null`, whatever [bundle] contains. */
    override fun get(bundle: SavedState, key: String): String? {
        return null
    }

    /** Ignores [value] and returns the literal string `"null"`. */
    override fun parseValue(value: String): String {
        return "null"
    }
}
185 
/**
 * [NavType] implementations that [getNavType] maps to in addition to the ones declared in
 * [NavType.Companion]: nullable primitives, doubles, non-nullable strings, and collections of
 * doubles or nullable strings.
 *
 * Conventions shared by the nullable and collection types below:
 * - a `null` value is written with `putNull`, and both a missing key and a stored null read back
 *   as `null`;
 * - when parsing, the literal string `"null"` stands for a `null` value.
 */
internal object InternalNavType {
    /** [NavType] for `Int?`. */
    val IntNullableType: NavType<Int?> =
        object : NavType<Int?>(true) {
            override val name: String
                get() = "integer_nullable"

            override fun put(bundle: SavedState, key: String, value: Int?) {
                // store null as serializable inside bundle, so that decoder will use the null
                // instead of default value
                if (value == null) bundle.write { putNull(key) }
                else IntType.put(bundle, key, value)
            }

            override fun get(bundle: SavedState, key: String): Int? =
                bundle.read { if (contains(key) && !isNull(key)) getInt(key) else null }

            override fun parseValue(value: String): Int? {
                return if (value == "null") null else IntType.parseValue(value)
            }
        }

    /** [NavType] for `Boolean?`. */
    val BoolNullableType: NavType<Boolean?> =
        object : NavType<Boolean?>(true) {
            override val name: String
                get() = "boolean_nullable"

            override fun put(bundle: SavedState, key: String, value: Boolean?) {
                if (value == null) bundle.write { putNull(key) }
                else BoolType.put(bundle, key, value)
            }

            override fun get(bundle: SavedState, key: String): Boolean? =
                bundle.read { if (contains(key) && !isNull(key)) getBoolean(key) else null }

            override fun parseValue(value: String): Boolean? {
                return if (value == "null") null else BoolType.parseValue(value)
            }
        }

    /**
     * [NavType] for non-null `Double`, parsed with [String.toDouble].
     *
     * Unlike the nullable types, [NavType.get] here reads the key without first checking that it
     * is present.
     */
    val DoubleType: NavType<Double> =
        object : NavType<Double>(false) {
            override val name: String
                get() = "double"

            override fun put(bundle: SavedState, key: String, value: Double) {
                bundle.write { putDouble(key, value) }
            }

            override fun get(bundle: SavedState, key: String): Double =
                bundle.read { getDouble(key) }

            override fun parseValue(value: String): Double = value.toDouble()
        }

    /** [NavType] for `Double?`; non-null values are delegated to [DoubleType]. */
    val DoubleNullableType: NavType<Double?> =
        object : NavType<Double?>(true) {
            override val name: String
                get() = "double_nullable"

            override fun put(bundle: SavedState, key: String, value: Double?) {
                if (value == null) bundle.write { putNull(key) }
                else DoubleType.put(bundle, key, value)
            }

            override fun get(bundle: SavedState, key: String): Double? =
                bundle.read { if (contains(key) && !isNull(key)) getDouble(key) else null }

            override fun parseValue(value: String): Double? {
                return if (value == "null") null else DoubleType.parseValue(value)
            }
        }

    /** [NavType] for `Float?`. */
    val FloatNullableType: NavType<Float?> =
        object : NavType<Float?>(true) {
            override val name: String
                get() = "float_nullable"

            override fun put(bundle: SavedState, key: String, value: Float?) {
                if (value == null) bundle.write { putNull(key) }
                else FloatType.put(bundle, key, value)
            }

            override fun get(bundle: SavedState, key: String): Float? =
                bundle.read { if (contains(key) && !isNull(key)) getFloat(key) else null }

            override fun parseValue(value: String): Float? {
                return if (value == "null") null else FloatType.parseValue(value)
            }
        }

    /** [NavType] for `Long?`. */
    val LongNullableType: NavType<Long?> =
        object : NavType<Long?>(true) {
            override val name: String
                get() = "long_nullable"

            override fun put(bundle: SavedState, key: String, value: Long?) {
                if (value == null) bundle.write { putNull(key) }
                else LongType.put(bundle, key, value)
            }

            override fun get(bundle: SavedState, key: String): Long? =
                bundle.read { if (contains(key) && !isNull(key)) getLong(key) else null }

            override fun parseValue(value: String): Long? {
                return if (value == "null") null else LongType.parseValue(value)
            }
        }

    /**
     * [NavType] for non-null `String`.
     *
     * The text `"null"` is an ordinary value here: it is parsed and serialized as the four
     * character string. Reading a missing key or a stored null also yields the string `"null"`.
     */
    val StringNonNullableType: NavType<String> =
        object : NavType<String>(false) {
            override val name: String
                get() = "string_non_nullable"

            override fun put(bundle: SavedState, key: String, value: String) {
                bundle.write { putString(key, value) }
            }

            override fun get(bundle: SavedState, key: String): String =
                bundle.read { if (contains(key) && !isNull(key)) getString(key) else "null" }

            // "null" is still parsed as "null"
            override fun parseValue(value: String): String = value

            // "null" is still serialized as "null"
            override fun serializeAsValue(value: String): String = NavUriUtils.encode(value)
        }

    /**
     * [CollectionNavType] for `Array<String?>?`.
     *
     * Null elements are stored as the string `"null"` and turned back into `null` on read via
     * `StringType.parseValue`.
     */
    val StringNullableArrayType: NavType<Array<String?>?> =
        object : CollectionNavType<Array<String?>?>(true) {
            override val name: String
                get() = "string_nullable[]"

            override fun put(bundle: SavedState, key: String, value: Array<String?>?) {
                bundle.write {
                    if (value == null) putNull(key)
                    else putStringArray(key, value.map { it ?: "null" }.toTypedArray())
                }
            }

            override fun get(bundle: SavedState, key: String): Array<String?>? =
                bundle.read {
                    if (contains(key) && !isNull(key)) {
                        getStringArray(key).map { StringType.parseValue(it) }.toTypedArray()
                    } else null
                }

            // match String? behavior where null -> null, and "null" -> null
            override fun parseValue(value: String): Array<String?> =
                arrayOf(StringType.parseValue(value))

            // appends the newly parsed element to whatever has been accumulated so far
            override fun parseValue(
                value: String,
                previousValue: Array<String?>?
            ): Array<String?>? = previousValue?.plus(parseValue(value)) ?: parseValue(value)

            override fun valueEquals(value: Array<String?>?, other: Array<String?>?): Boolean =
                value.contentDeepEquals(other)

            // null elements are emitted as the unencoded literal "null"
            override fun serializeAsValues(value: Array<String?>?): List<String> =
                value?.map { it?.let { NavUriUtils.encode(it) } ?: "null" } ?: emptyList()

            override fun emptyCollection(): Array<String?>? = arrayOf()
        }

    /**
     * [CollectionNavType] for `List<String?>?`, stored as a string array.
     *
     * Null elements are stored as the string `"null"` and turned back into `null` on read via
     * `StringType.parseValue`.
     */
    val StringNullableListType: NavType<List<String?>?> =
        object : CollectionNavType<List<String?>?>(true) {
            override val name: String
                get() = "List<String?>"

            override fun put(bundle: SavedState, key: String, value: List<String?>?) {
                bundle.write {
                    if (value == null) putNull(key)
                    else putStringArray(key, value.map { it ?: "null" }.toTypedArray())
                }
            }

            override fun get(bundle: SavedState, key: String): List<String?>? =
                bundle.read {
                    if (contains(key) && !isNull(key)) {
                        getStringArray(key).toList().map { StringType.parseValue(it) }
                    } else null
                }

            override fun parseValue(value: String): List<String?> {
                return listOf(StringType.parseValue(value))
            }

            // appends the newly parsed element to whatever has been accumulated so far
            override fun parseValue(value: String, previousValue: List<String?>?): List<String?>? {
                return previousValue?.plus(parseValue(value)) ?: parseValue(value)
            }

            override fun valueEquals(value: List<String?>?, other: List<String?>?): Boolean {
                val valueArray = value?.toTypedArray()
                val otherArray = other?.toTypedArray()
                return valueArray.contentDeepEquals(otherArray)
            }

            // null elements are emitted as the unencoded literal "null"
            override fun serializeAsValues(value: List<String?>?): List<String> =
                value?.map { it?.let { NavUriUtils.encode(it) } ?: "null" } ?: emptyList()

            override fun emptyCollection(): List<String?> = emptyList()
        }

    /** [CollectionNavType] for `DoubleArray?`; elements are parsed with [DoubleType]. */
    val DoubleArrayType: NavType<DoubleArray?> =
        object : CollectionNavType<DoubleArray?>(true) {
            override val name: String
                get() = "double[]"

            override fun put(bundle: SavedState, key: String, value: DoubleArray?) {
                bundle.write { if (value == null) putNull(key) else putDoubleArray(key, value) }
            }

            override fun get(bundle: SavedState, key: String): DoubleArray? =
                bundle.read { if (contains(key) && !isNull(key)) getDoubleArray(key) else null }

            override fun parseValue(value: String): DoubleArray =
                doubleArrayOf(DoubleType.parseValue(value))

            // appends the newly parsed element to whatever has been accumulated so far
            override fun parseValue(value: String, previousValue: DoubleArray?): DoubleArray =
                previousValue?.plus(parseValue(value)) ?: parseValue(value)

            // Compares the primitive arrays directly instead of boxing both into Array<Double>.
            // contentEquals uses equals() semantics per element (NaN == NaN, -0.0 != 0.0) and
            // treats two nulls as equal, so the outcome matches a boxed deep comparison.
            override fun valueEquals(value: DoubleArray?, other: DoubleArray?): Boolean =
                value.contentEquals(other)

            override fun serializeAsValues(value: DoubleArray?): List<String> =
                value?.toList()?.map { it.toString() } ?: emptyList()

            override fun emptyCollection(): DoubleArray = doubleArrayOf()
        }

    /**
     * [CollectionNavType] for `List<Double>?`, stored as a double array; elements are parsed with
     * [DoubleType].
     */
    val DoubleListType: NavType<List<Double>?> =
        object : CollectionNavType<List<Double>?>(true) {
            override val name: String
                get() = "List<Double>"

            override fun put(bundle: SavedState, key: String, value: List<Double>?) {
                bundle.write {
                    if (value == null) putNull(key) else putDoubleArray(key, value.toDoubleArray())
                }
            }

            override fun get(bundle: SavedState, key: String): List<Double>? =
                bundle.read {
                    if (contains(key) && !isNull(key)) getDoubleArray(key).toList() else null
                }

            override fun parseValue(value: String): List<Double> =
                listOf(DoubleType.parseValue(value))

            // appends the newly parsed element to whatever has been accumulated so far
            override fun parseValue(value: String, previousValue: List<Double>?): List<Double>? =
                previousValue?.plus(parseValue(value)) ?: parseValue(value)

            override fun valueEquals(value: List<Double>?, other: List<Double>?): Boolean {
                val valueArray = value?.toTypedArray()
                val otherArray = other?.toTypedArray()
                return valueArray.contentDeepEquals(otherArray)
            }

            override fun serializeAsValues(value: List<Double>?): List<String> =
                value?.map { it.toString() } ?: emptyList()

            override fun emptyCollection(): List<Double> = emptyList()
        }
}
454