1 /*
2  * Copyright 2023 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 package androidx.credentials.webauthn
18 
19 import androidx.annotation.RestrictTo
20 import java.lang.IllegalArgumentException
21 
22 @RestrictTo(RestrictTo.Scope.LIBRARY)
class Cbor {
    /**
     * A decoded CBOR data item.
     *
     * @property item the decoded value ([Long], [ByteArray], [String], [List] or [Map])
     * @property len number of input bytes the item occupied, header included
     */
    data class Item(val item: Any, val len: Int)

    /**
     * The argument carried by a CBOR item header.
     *
     * @property arg the argument value: the integer itself, a byte length, or an element count,
     *   depending on the major type
     * @property len size of the header in bytes (initial byte plus any argument bytes)
     */
    data class Arg(val arg: Long, val len: Int)

    // CBOR major types, stored in the top three bits of an item's initial byte.
    val TYPE_UNSIGNED_INT = 0x00
    val TYPE_NEGATIVE_INT = 0x01
    val TYPE_BYTE_STRING = 0x02
    val TYPE_TEXT_STRING = 0x03
    val TYPE_ARRAY = 0x04
    val TYPE_MAP = 0x05
    val TYPE_TAG = 0x06
    val TYPE_FLOAT = 0x07

    /**
     * Decodes the CBOR data item that starts at the beginning of [data].
     *
     * Integers decode to [Long], byte strings to [ByteArray], text strings to [String], arrays to
     * [List] and maps to [Map]. Any bytes after the first item are ignored.
     *
     * @throws IllegalArgumentException if the input uses an unsupported major type (tags, floats,
     *   simple values) or an unsupported argument encoding (e.g. indefinite lengths), or an
     *   8-byte argument that does not fit in a [Long]
     * @throws IndexOutOfBoundsException if the input is truncated or declares a length/count
     *   larger than the remaining input
     */
    fun decode(data: ByteArray): Any = parseItem(data, 0).item

    /**
     * Encodes [data] as CBOR.
     *
     * Supported values: integral [Number]s (unsigned/negative integers), [ByteArray] (byte
     * string), [String] (UTF-8 text string), [List] (array) and [Map] (map). Map entries are
     * emitted in CTAP2 canonical key order, see
     * https://fidoalliance.org/specs/fido-v2.1-ps-20210615/fido-client-to-authenticator-protocol-v2.1-ps-20210615.html#ctap2-canonical-cbor-encoding-form
     *
     * @throws IllegalArgumentException if [data] or any nested element is a [Double] or [Float],
     *   is null, or has an unsupported type
     */
    fun encode(data: Any): ByteArray =
        when (data) {
            is Double -> throw IllegalArgumentException("Don't support doubles yet")
            // Reject rather than letting toLong() silently drop the fractional part.
            is Float -> throw IllegalArgumentException("Don't support floats yet")
            is Number -> encodeInteger(data.toLong())
            is ByteArray -> createArg(TYPE_BYTE_STRING, data.size.toLong()) + data
            is String -> {
                // The header carries the UTF-8 byte count, not the UTF-16 char count.
                val utf8 = data.encodeToByteArray()
                createArg(TYPE_TEXT_STRING, utf8.size.toLong()) + utf8
            }
            is List<*> -> encodeArray(data)
            is Map<*, *> -> encodeMap(data)
            else -> throw IllegalArgumentException("Bad type")
        }

    /** Encodes [value] as major type 0 (non-negative) or 1 (negative, stored as -1 - value). */
    private fun encodeInteger(value: Long): ByteArray =
        if (value >= 0) {
            createArg(TYPE_UNSIGNED_INT, value)
        } else {
            createArg(TYPE_NEGATIVE_INT, -1 - value)
        }

    /** Encodes [value], rejecting null elements/keys/values with an IllegalArgumentException. */
    private fun encodeNonNull(value: Any?): ByteArray =
        encode(requireNotNull(value) { "Null values are not supported" })

    /** Encodes [list] as a definite-length array. */
    private fun encodeArray(list: List<*>): ByteArray {
        var ret = createArg(TYPE_ARRAY, list.size.toLong())
        for (element in list) {
            ret += encodeNonNull(element)
        }
        return ret
    }

    /** Encodes [map] as a definite-length map with keys in CTAP2 canonical order. */
    private fun encodeMap(map: Map<*, *>): ByteArray {
        // Encode first so keys can be ordered by their encoded bytes.
        val encodedEntries =
            map.entries
                .map { entry -> encodeNonNull(entry.key) to encodeNonNull(entry.value) }
                .sortedWith(Comparator { a, b -> compareCanonicalKeys(a.first, b.first) })
        var ret = createArg(TYPE_MAP, encodedEntries.size.toLong())
        for ((key, value) in encodedEntries) {
            ret += key
            ret += value
        }
        return ret
    }

    /**
     * Orders two encoded map keys per the CTAP2 canonical rules: the lower major type sorts
     * earlier; otherwise the shorter encoding sorts earlier; otherwise the lower value in
     * byte-wise (unsigned) lexical order sorts earlier.
     */
    private fun compareCanonicalKeys(a: ByteArray, b: ByteArray): Int {
        val typeOrder = getType(a, 0).compareTo(getType(b, 0))
        if (typeOrder != 0) return typeOrder
        if (a.size != b.size) return a.size.compareTo(b.size)
        for (i in a.indices) {
            val byteOrder = (a[i].toInt() and 0xFF).compareTo(b[i].toInt() and 0xFF)
            if (byteOrder != 0) return byteOrder
        }
        return 0
    }

    /** Returns the major type (top three bits) of the byte at [offset]. */
    private fun getType(data: ByteArray, offset: Int): Int {
        val d = data[offset].toInt()
        return (d and 0xFF) shr 5
    }

    /**
     * Reads the argument of the item header at [offset]: either inline (additional info 0..23) or
     * a big-endian 1/2/4/8-byte value (additional info 24..27).
     */
    private fun getArg(data: ByteArray, offset: Int): Arg {
        val info = data[offset].toInt() and 0x1F
        val extraBytes =
            when (info) {
                in 0..23 -> return Arg(info.toLong(), 1)
                24 -> 1
                25 -> 2
                26 -> 4
                27 -> 8
                else -> throw IllegalArgumentException("Bad arg")
            }
        var value = 0L
        for (i in 1..extraBytes) {
            value = (value shl 8) or (data[offset + i].toLong() and 0xFF)
        }
        // An 8-byte argument above Long.MAX_VALUE wraps negative; it cannot be represented.
        if (value < 0) throw IllegalArgumentException("Bad arg")
        return Arg(value, 1 + extraBytes)
    }

    /** Parses the item at [offset], returning its value and the number of bytes it spans. */
    private fun parseItem(data: ByteArray, offset: Int): Item {
        val itemType = getType(data, offset)
        val arg = getArg(data, offset)

        return when (itemType) {
            TYPE_UNSIGNED_INT -> Item(arg.arg, arg.len)
            TYPE_NEGATIVE_INT -> Item(-1 - arg.arg, arg.len)
            TYPE_BYTE_STRING -> {
                val bytes = readPayload(data, offset, arg)
                Item(bytes, arg.len + bytes.size)
            }
            TYPE_TEXT_STRING -> {
                val bytes = readPayload(data, offset, arg)
                Item(bytes.toString(Charsets.UTF_8), arg.len + bytes.size)
            }
            TYPE_ARRAY -> {
                val count = checkedCount(data, offset, arg)
                val elements = ArrayList<Any>(count)
                var consumed = arg.len
                repeat(count) {
                    val element = parseItem(data, offset + consumed)
                    elements.add(element.item)
                    consumed += element.len
                }
                Item(elements.toList(), consumed)
            }
            TYPE_MAP -> {
                val count = checkedCount(data, offset, arg)
                val entries = mutableMapOf<Any, Any>()
                var consumed = arg.len
                repeat(count) {
                    val key = parseItem(data, offset + consumed)
                    consumed += key.len
                    val value = parseItem(data, offset + consumed)
                    consumed += value.len
                    entries[key.item] = value.item
                }
                Item(entries.toMap(), consumed)
            }
            else -> throw IllegalArgumentException("Bad type")
        }
    }

    /**
     * Returns the payload bytes of a byte/text string whose header is at [offset]. The end index
     * is computed in Long so a huge declared length cannot wrap to a small or negative Int.
     */
    private fun readPayload(data: ByteArray, offset: Int, arg: Arg): ByteArray {
        val start = offset + arg.len
        val end = start.toLong() + arg.arg
        if (end > data.size) {
            throw IndexOutOfBoundsException(
                "CBOR string of length ${arg.arg} extends past end of input"
            )
        }
        return data.sliceArray(start until end.toInt())
    }

    /**
     * Converts an array/map element count to Int. Every element occupies at least one byte, so a
     * count larger than the remaining input can only come from truncated or malformed data.
     */
    private fun checkedCount(data: ByteArray, offset: Int, arg: Arg): Int {
        val remaining = data.size - offset - arg.len
        if (arg.arg > remaining) {
            throw IndexOutOfBoundsException(
                "CBOR container count ${arg.arg} exceeds remaining input"
            )
        }
        return arg.arg.toInt()
    }

    /**
     * Builds an item header for major [type] carrying the non-negative argument [arg], using the
     * shortest form: inline (< 24) or a big-endian 1, 2, 4 or 8-byte argument.
     */
    private fun createArg(type: Int, arg: Long): ByteArray {
        require(arg >= 0) { "bad Arg" }
        val initialByte = type shl 5
        val (info, extraBytes) =
            when {
                arg < 24 -> return byteArrayOf((initialByte or arg.toInt()).toByte())
                arg <= 0xFF -> 24 to 1
                arg <= 0xFFFF -> 25 to 2
                arg <= 0xFFFFFFFFL -> 26 to 4
                else -> 27 to 8
            }
        return ByteArray(1 + extraBytes) { i ->
            if (i == 0) {
                (initialByte or info).toByte()
            } else {
                // Most significant byte first.
                (arg shr (8 * (extraBytes - i))).toByte()
            }
        }
    }
}
215