1 // Copyright 2013 The Chromium Authors. All rights reserved.
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 #include "media/cdm/json_web_key.h"
6
7 #include "base/base64.h"
8 #include "base/json/json_reader.h"
9 #include "base/json/json_string_value_serializer.h"
10 #include "base/json/string_escape.h"
11 #include "base/logging.h"
12 #include "base/memory/scoped_ptr.h"
13 #include "base/strings/string_util.h"
14 #include "base/values.h"
15
16 namespace media {
17
// Dictionary keys of a JSON Web Key (JWK) and of the JWK Set that wraps it.
const char kKeysTag[] = "keys";
const char kKeyTypeTag[] = "kty";
const char kKeyTag[] = "k";
const char kKeyIdTag[] = "kid";

// The only |kKeyTypeTag| value accepted: a symmetric ("octet sequence") key.
const char kSymmetricKeyValue[] = "oct";

// Dictionary key holding the list of requested key IDs in a license request.
const char kKeyIdsTag[] = "kids";

// Optional session type entry, used by both JWK Sets and license requests,
// together with the two values it may take.
const char kTypeTag[] = "type";
const char kPersistentType[] = "persistent";
const char kTemporaryType[] = "temporary";

// Base64 padding character: stripped when encoding, rejected when decoding.
const char kBase64Padding = '=';
28
29 // Encodes |input| into a base64 string without padding.
EncodeBase64(const uint8 * input,int input_length)30 static std::string EncodeBase64(const uint8* input, int input_length) {
31 std::string encoded_text;
32 base::Base64Encode(
33 std::string(reinterpret_cast<const char*>(input), input_length),
34 &encoded_text);
35
36 // Remove any padding characters added by Base64Encode().
37 size_t found = encoded_text.find_last_not_of(kBase64Padding);
38 if (found != std::string::npos)
39 encoded_text.erase(found + 1);
40
41 return encoded_text;
42 }
43
44 // Decodes an unpadded base64 string. Returns empty string on error.
DecodeBase64(const std::string & encoded_text)45 static std::string DecodeBase64(const std::string& encoded_text) {
46 // EME spec doesn't allow padding characters.
47 if (encoded_text.find_first_of(kBase64Padding) != std::string::npos)
48 return std::string();
49
50 // Since base::Base64Decode() requires padding characters, add them so length
51 // of |encoded_text| is exactly a multiple of 4.
52 size_t num_last_grouping_chars = encoded_text.length() % 4;
53 std::string modified_text = encoded_text;
54 if (num_last_grouping_chars > 0)
55 modified_text.append(4 - num_last_grouping_chars, kBase64Padding);
56
57 std::string decoded_text;
58 if (!base::Base64Decode(modified_text, &decoded_text))
59 return std::string();
60
61 return decoded_text;
62 }
63
GenerateJWKSet(const uint8 * key,int key_length,const uint8 * key_id,int key_id_length)64 std::string GenerateJWKSet(const uint8* key, int key_length,
65 const uint8* key_id, int key_id_length) {
66 // Both |key| and |key_id| need to be base64 encoded strings in the JWK.
67 std::string key_base64 = EncodeBase64(key, key_length);
68 std::string key_id_base64 = EncodeBase64(key_id, key_id_length);
69
70 // Create the JWK, and wrap it into a JWK Set.
71 scoped_ptr<base::DictionaryValue> jwk(new base::DictionaryValue());
72 jwk->SetString(kKeyTypeTag, kSymmetricKeyValue);
73 jwk->SetString(kKeyTag, key_base64);
74 jwk->SetString(kKeyIdTag, key_id_base64);
75 scoped_ptr<base::ListValue> list(new base::ListValue());
76 list->Append(jwk.release());
77 base::DictionaryValue jwk_set;
78 jwk_set.Set(kKeysTag, list.release());
79
80 // Finally serialize |jwk_set| into a string and return it.
81 std::string serialized_jwk;
82 JSONStringValueSerializer serializer(&serialized_jwk);
83 serializer.Serialize(jwk_set);
84 return serialized_jwk;
85 }
86
87 // Processes a JSON Web Key to extract the key id and key value. Sets |jwk_key|
88 // to the id/value pair and returns true on success.
ConvertJwkToKeyPair(const base::DictionaryValue & jwk,KeyIdAndKeyPair * jwk_key)89 static bool ConvertJwkToKeyPair(const base::DictionaryValue& jwk,
90 KeyIdAndKeyPair* jwk_key) {
91 // Have found a JWK, start by checking that it is a symmetric key.
92 std::string type;
93 if (!jwk.GetString(kKeyTypeTag, &type) || type != kSymmetricKeyValue) {
94 DVLOG(1) << "JWK is not a symmetric key";
95 return false;
96 }
97
98 // Get the key id and actual key parameters.
99 std::string encoded_key_id;
100 std::string encoded_key;
101 if (!jwk.GetString(kKeyIdTag, &encoded_key_id)) {
102 DVLOG(1) << "Missing '" << kKeyIdTag << "' parameter";
103 return false;
104 }
105 if (!jwk.GetString(kKeyTag, &encoded_key)) {
106 DVLOG(1) << "Missing '" << kKeyTag << "' parameter";
107 return false;
108 }
109
110 // Key ID and key are base64-encoded strings, so decode them.
111 std::string raw_key_id = DecodeBase64(encoded_key_id);
112 if (raw_key_id.empty()) {
113 DVLOG(1) << "Invalid '" << kKeyIdTag << "' value: " << encoded_key_id;
114 return false;
115 }
116
117 std::string raw_key = DecodeBase64(encoded_key);
118 if (raw_key.empty()) {
119 DVLOG(1) << "Invalid '" << kKeyTag << "' value: " << encoded_key;
120 return false;
121 }
122
123 // Add the decoded key ID and the decoded key to the list.
124 *jwk_key = std::make_pair(raw_key_id, raw_key);
125 return true;
126 }
127
ExtractKeysFromJWKSet(const std::string & jwk_set,KeyIdAndKeyPairs * keys,MediaKeys::SessionType * session_type)128 bool ExtractKeysFromJWKSet(const std::string& jwk_set,
129 KeyIdAndKeyPairs* keys,
130 MediaKeys::SessionType* session_type) {
131 if (!base::IsStringASCII(jwk_set))
132 return false;
133
134 scoped_ptr<base::Value> root(base::JSONReader().ReadToValue(jwk_set));
135 if (!root.get() || root->GetType() != base::Value::TYPE_DICTIONARY)
136 return false;
137
138 // Locate the set from the dictionary.
139 base::DictionaryValue* dictionary =
140 static_cast<base::DictionaryValue*>(root.get());
141 base::ListValue* list_val = NULL;
142 if (!dictionary->GetList(kKeysTag, &list_val)) {
143 DVLOG(1) << "Missing '" << kKeysTag
144 << "' parameter or not a list in JWK Set";
145 return false;
146 }
147
148 // Create a local list of keys, so that |jwk_keys| only gets updated on
149 // success.
150 KeyIdAndKeyPairs local_keys;
151 for (size_t i = 0; i < list_val->GetSize(); ++i) {
152 base::DictionaryValue* jwk = NULL;
153 if (!list_val->GetDictionary(i, &jwk)) {
154 DVLOG(1) << "Unable to access '" << kKeysTag << "'[" << i
155 << "] in JWK Set";
156 return false;
157 }
158 KeyIdAndKeyPair key_pair;
159 if (!ConvertJwkToKeyPair(*jwk, &key_pair)) {
160 DVLOG(1) << "Error from '" << kKeysTag << "'[" << i << "]";
161 return false;
162 }
163 local_keys.push_back(key_pair);
164 }
165
166 // Successfully processed all JWKs in the set. Now check if "type" is
167 // specified.
168 base::Value* value = NULL;
169 std::string type_id;
170 if (!dictionary->Get(kTypeTag, &value)) {
171 // Not specified, so use the default type.
172 *session_type = MediaKeys::TEMPORARY_SESSION;
173 } else if (!value->GetAsString(&type_id)) {
174 DVLOG(1) << "Invalid '" << kTypeTag << "' value";
175 return false;
176 } else if (type_id == kPersistentType) {
177 *session_type = MediaKeys::PERSISTENT_SESSION;
178 } else if (type_id == kTemporaryType) {
179 *session_type = MediaKeys::TEMPORARY_SESSION;
180 } else {
181 DVLOG(1) << "Invalid '" << kTypeTag << "' value: " << type_id;
182 return false;
183 }
184
185 // All done.
186 keys->swap(local_keys);
187 return true;
188 }
189
CreateLicenseRequest(const uint8 * key_id,int key_id_length,MediaKeys::SessionType session_type,std::vector<uint8> * license)190 void CreateLicenseRequest(const uint8* key_id,
191 int key_id_length,
192 MediaKeys::SessionType session_type,
193 std::vector<uint8>* license) {
194 // Create the license request.
195 scoped_ptr<base::DictionaryValue> request(new base::DictionaryValue());
196 scoped_ptr<base::ListValue> list(new base::ListValue());
197 list->AppendString(EncodeBase64(key_id, key_id_length));
198 request->Set(kKeyIdsTag, list.release());
199
200 switch (session_type) {
201 case MediaKeys::TEMPORARY_SESSION:
202 request->SetString(kTypeTag, kTemporaryType);
203 break;
204 case MediaKeys::PERSISTENT_SESSION:
205 request->SetString(kTypeTag, kPersistentType);
206 break;
207 }
208
209 // Serialize the license request as a string.
210 std::string json;
211 JSONStringValueSerializer serializer(&json);
212 serializer.Serialize(*request);
213
214 // Convert the serialized license request into std::vector and return it.
215 std::vector<uint8> result(json.begin(), json.end());
216 license->swap(result);
217 }
218
ExtractFirstKeyIdFromLicenseRequest(const std::vector<uint8> & license,std::vector<uint8> * first_key)219 bool ExtractFirstKeyIdFromLicenseRequest(const std::vector<uint8>& license,
220 std::vector<uint8>* first_key) {
221 const std::string license_as_str(
222 reinterpret_cast<const char*>(!license.empty() ? &license[0] : NULL),
223 license.size());
224 if (!base::IsStringASCII(license_as_str))
225 return false;
226
227 scoped_ptr<base::Value> root(base::JSONReader().ReadToValue(license_as_str));
228 if (!root.get() || root->GetType() != base::Value::TYPE_DICTIONARY)
229 return false;
230
231 // Locate the set from the dictionary.
232 base::DictionaryValue* dictionary =
233 static_cast<base::DictionaryValue*>(root.get());
234 base::ListValue* list_val = NULL;
235 if (!dictionary->GetList(kKeyIdsTag, &list_val)) {
236 DVLOG(1) << "Missing '" << kKeyIdsTag << "' parameter or not a list";
237 return false;
238 }
239
240 // Get the first key.
241 if (list_val->GetSize() < 1) {
242 DVLOG(1) << "Empty '" << kKeyIdsTag << "' list";
243 return false;
244 }
245
246 std::string encoded_key;
247 if (!list_val->GetString(0, &encoded_key)) {
248 DVLOG(1) << "First entry in '" << kKeyIdsTag << "' not a string";
249 return false;
250 }
251
252 std::string decoded_string = DecodeBase64(encoded_key);
253 if (decoded_string.empty()) {
254 DVLOG(1) << "Invalid '" << kKeyIdsTag << "' value: " << encoded_key;
255 return false;
256 }
257
258 std::vector<uint8> result(decoded_string.begin(), decoded_string.end());
259 first_key->swap(result);
260 return true;
261 }
262
263 } // namespace media
264