• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * libjingle
3  * Copyright 2003-2008, Google Inc.
4  *
5  * Redistribution and use in source and binary forms, with or without
6  * modification, are permitted provided that the following conditions are met:
7  *
8  *  1. Redistributions of source code must retain the above copyright notice,
9  *     this list of conditions and the following disclaimer.
10  *  2. Redistributions in binary form must reproduce the above copyright notice,
11  *     this list of conditions and the following disclaimer in the documentation
12  *     and/or other materials provided with the distribution.
13  *  3. The name of the author may not be used to endorse or promote products
14  *     derived from this software without specific prior written permission.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED
17  * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
18  * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
19  * EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
20  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
22  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
23  * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
24  * OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
25  * ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
28 // Registry configuration wrapper class implementation
29 //
30 // Change made by S. Ganesh - ganesh@google.com:
31 //   Use SHQueryValueEx instead of RegQueryValueEx throughout.
32 //   A call to the SHLWAPI function is essentially a call to the standard
33 //   function but with post-processing:
34 //   * to fix REG_SZ or REG_EXPAND_SZ data that is not properly null-terminated;
35 //   * to expand REG_EXPAND_SZ data.
36 
37 #include "talk/base/win32regkey.h"
38 
39 #include <shlwapi.h>
40 
41 #include "talk/base/common.h"
42 #include "talk/base/logging.h"
43 #include "talk/base/scoped_ptr.h"
44 
45 namespace talk_base {
46 
RegKey()47 RegKey::RegKey() {
48   h_key_ = NULL;
49 }
50 
~RegKey()51 RegKey::~RegKey() {
52   Close();
53 }
54 
Create(HKEY parent_key,const wchar_t * key_name)55 HRESULT RegKey::Create(HKEY parent_key, const wchar_t* key_name) {
56   return Create(parent_key,
57                 key_name,
58                 REG_NONE,
59                 REG_OPTION_NON_VOLATILE,
60                 KEY_ALL_ACCESS,
61                 NULL,
62                 NULL);
63 }
64 
Open(HKEY parent_key,const wchar_t * key_name)65 HRESULT RegKey::Open(HKEY parent_key, const wchar_t* key_name) {
66   return Open(parent_key, key_name, KEY_ALL_ACCESS);
67 }
68 
HasValue(const TCHAR * value_name) const69 bool RegKey::HasValue(const TCHAR* value_name) const {
70   return (ERROR_SUCCESS == ::RegQueryValueEx(h_key_, value_name, NULL,
71                                              NULL, NULL, NULL));
72 }
73 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD value)74 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
75                          const wchar_t* value_name,
76                          DWORD value) {
77   ASSERT(full_key_name != NULL);
78 
79   return SetValueStaticHelper(full_key_name, value_name, REG_DWORD, &value);
80 }
81 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD64 value)82 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
83                          const wchar_t* value_name,
84                          DWORD64 value) {
85   ASSERT(full_key_name != NULL);
86 
87   return SetValueStaticHelper(full_key_name, value_name, REG_QWORD, &value);
88 }
89 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,float value)90 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
91                          const wchar_t* value_name,
92                          float value) {
93   ASSERT(full_key_name != NULL);
94 
95   return SetValueStaticHelper(full_key_name, value_name,
96                               REG_BINARY, &value, sizeof(value));
97 }
98 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,double value)99 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
100                          const wchar_t* value_name,
101                          double value) {
102   ASSERT(full_key_name != NULL);
103 
104   return SetValueStaticHelper(full_key_name, value_name,
105                               REG_BINARY, &value, sizeof(value));
106 }
107 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,const TCHAR * value)108 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
109                          const wchar_t* value_name,
110                          const TCHAR* value) {
111   ASSERT(full_key_name != NULL);
112   ASSERT(value != NULL);
113 
114   return SetValueStaticHelper(full_key_name, value_name,
115                               REG_SZ, const_cast<wchar_t*>(value));
116 }
117 
SetValue(const wchar_t * full_key_name,const wchar_t * value_name,const uint8 * value,DWORD byte_count)118 HRESULT RegKey::SetValue(const wchar_t* full_key_name,
119                          const wchar_t* value_name,
120                          const uint8* value,
121                          DWORD byte_count) {
122   ASSERT(full_key_name != NULL);
123 
124   return SetValueStaticHelper(full_key_name, value_name, REG_BINARY,
125                               const_cast<uint8*>(value), byte_count);
126 }
127 
SetValueMultiSZ(const wchar_t * full_key_name,const wchar_t * value_name,const uint8 * value,DWORD byte_count)128 HRESULT RegKey::SetValueMultiSZ(const wchar_t* full_key_name,
129                                 const wchar_t* value_name,
130                                 const uint8* value,
131                                 DWORD byte_count) {
132   ASSERT(full_key_name != NULL);
133 
134   return SetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ,
135                               const_cast<uint8*>(value), byte_count);
136 }
137 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD * value)138 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
139                          const wchar_t* value_name,
140                          DWORD* value) {
141   ASSERT(full_key_name != NULL);
142   ASSERT(value != NULL);
143 
144   return GetValueStaticHelper(full_key_name, value_name, REG_DWORD, value);
145 }
146 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,DWORD64 * value)147 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
148                          const wchar_t* value_name,
149                          DWORD64* value) {
150   ASSERT(full_key_name != NULL);
151   ASSERT(value != NULL);
152 
153   return GetValueStaticHelper(full_key_name, value_name, REG_QWORD, value);
154 }
155 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,float * value)156 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
157                          const wchar_t* value_name,
158                          float* value) {
159   ASSERT(value != NULL);
160   ASSERT(full_key_name != NULL);
161 
162   DWORD byte_count = 0;
163   scoped_ptr<byte[]> buffer;
164   HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
165                                     REG_BINARY, buffer.accept(), &byte_count);
166   if (SUCCEEDED(hr)) {
167     ASSERT(byte_count == sizeof(*value));
168     if (byte_count == sizeof(*value)) {
169       *value = *reinterpret_cast<float*>(buffer.get());
170     }
171   }
172   return hr;
173 }
174 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,double * value)175 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
176                          const wchar_t* value_name,
177                          double* value) {
178   ASSERT(value != NULL);
179   ASSERT(full_key_name != NULL);
180 
181   DWORD byte_count = 0;
182   scoped_ptr<byte[]> buffer;
183   HRESULT hr = GetValueStaticHelper(full_key_name, value_name,
184                                     REG_BINARY, buffer.accept(), &byte_count);
185   if (SUCCEEDED(hr)) {
186     ASSERT(byte_count == sizeof(*value));
187     if (byte_count == sizeof(*value)) {
188       *value = *reinterpret_cast<double*>(buffer.get());
189     }
190   }
191   return hr;
192 }
193 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,wchar_t ** value)194 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
195                          const wchar_t* value_name,
196                          wchar_t** value) {
197   ASSERT(full_key_name != NULL);
198   ASSERT(value != NULL);
199 
200   return GetValueStaticHelper(full_key_name, value_name, REG_SZ, value);
201 }
202 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,std::wstring * value)203 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
204                          const wchar_t* value_name,
205                          std::wstring* value) {
206   ASSERT(full_key_name != NULL);
207   ASSERT(value != NULL);
208 
209   scoped_ptr<wchar_t[]> buffer;
210   HRESULT hr = RegKey::GetValue(full_key_name, value_name, buffer.accept());
211   if (SUCCEEDED(hr)) {
212     value->assign(buffer.get());
213   }
214   return hr;
215 }
216 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,std::vector<std::wstring> * value)217 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
218                          const wchar_t* value_name,
219                          std::vector<std::wstring>* value) {
220   ASSERT(full_key_name != NULL);
221   ASSERT(value != NULL);
222 
223   return GetValueStaticHelper(full_key_name, value_name, REG_MULTI_SZ, value);
224 }
225 
GetValue(const wchar_t * full_key_name,const wchar_t * value_name,uint8 ** value,DWORD * byte_count)226 HRESULT RegKey::GetValue(const wchar_t* full_key_name,
227                          const wchar_t* value_name,
228                          uint8** value,
229                          DWORD* byte_count) {
230   ASSERT(full_key_name != NULL);
231   ASSERT(value != NULL);
232   ASSERT(byte_count != NULL);
233 
234   return GetValueStaticHelper(full_key_name, value_name,
235                               REG_BINARY, value, byte_count);
236 }
237 
DeleteSubKey(const wchar_t * key_name)238 HRESULT RegKey::DeleteSubKey(const wchar_t* key_name) {
239   ASSERT(key_name != NULL);
240   ASSERT(h_key_ != NULL);
241 
242   LONG res = ::RegDeleteKey(h_key_, key_name);
243   HRESULT hr = HRESULT_FROM_WIN32(res);
244   if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
245       hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
246     hr = S_FALSE;
247   }
248   return hr;
249 }
250 
DeleteValue(const wchar_t * value_name)251 HRESULT RegKey::DeleteValue(const wchar_t* value_name) {
252   ASSERT(h_key_ != NULL);
253 
254   LONG res = ::RegDeleteValue(h_key_, value_name);
255   HRESULT hr = HRESULT_FROM_WIN32(res);
256   if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
257       hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
258     hr = S_FALSE;
259   }
260   return hr;
261 }
262 
Close()263 HRESULT RegKey::Close() {
264   HRESULT hr = S_OK;
265   if (h_key_ != NULL) {
266     LONG res = ::RegCloseKey(h_key_);
267     hr = HRESULT_FROM_WIN32(res);
268     h_key_ = NULL;
269   }
270   return hr;
271 }
272 
Create(HKEY parent_key,const wchar_t * key_name,wchar_t * lpszClass,DWORD options,REGSAM sam_desired,LPSECURITY_ATTRIBUTES lpSecAttr,LPDWORD lpdwDisposition)273 HRESULT RegKey::Create(HKEY parent_key,
274                        const wchar_t* key_name,
275                        wchar_t* lpszClass,
276                        DWORD options,
277                        REGSAM sam_desired,
278                        LPSECURITY_ATTRIBUTES lpSecAttr,
279                        LPDWORD lpdwDisposition) {
280   ASSERT(key_name != NULL);
281   ASSERT(parent_key != NULL);
282 
283   DWORD dw = 0;
284   HKEY h_key = NULL;
285   LONG res = ::RegCreateKeyEx(parent_key, key_name, 0, lpszClass, options,
286                               sam_desired, lpSecAttr, &h_key, &dw);
287   HRESULT hr = HRESULT_FROM_WIN32(res);
288 
289   if (lpdwDisposition) {
290     *lpdwDisposition = dw;
291   }
292 
293   // we have to close the currently opened key
294   // before replacing it with the new one
295   if (hr == S_OK) {
296     hr = Close();
297     ASSERT(hr == S_OK);
298     h_key_ = h_key;
299   }
300   return hr;
301 }
302 
Open(HKEY parent_key,const wchar_t * key_name,REGSAM sam_desired)303 HRESULT RegKey::Open(HKEY parent_key,
304                      const wchar_t* key_name,
305                      REGSAM sam_desired) {
306   ASSERT(key_name != NULL);
307   ASSERT(parent_key != NULL);
308 
309   HKEY h_key = NULL;
310   LONG res = ::RegOpenKeyEx(parent_key, key_name, 0, sam_desired, &h_key);
311   HRESULT hr = HRESULT_FROM_WIN32(res);
312 
313   // we have to close the currently opened key
314   // before replacing it with the new one
315   if (hr == S_OK) {
316     // close the currently opened key if any
317     hr = Close();
318     ASSERT(hr == S_OK);
319     h_key_ = h_key;
320   }
321   return hr;
322 }
323 
324 // save the key and all of its subkeys and values to a file
Save(const wchar_t * full_key_name,const wchar_t * file_name)325 HRESULT RegKey::Save(const wchar_t* full_key_name, const wchar_t* file_name) {
326   ASSERT(full_key_name != NULL);
327   ASSERT(file_name != NULL);
328 
329   std::wstring key_name(full_key_name);
330   HKEY h_key = GetRootKeyInfo(&key_name);
331   if (!h_key) {
332     return E_FAIL;
333   }
334 
335   RegKey key;
336   HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
337   if (FAILED(hr)) {
338     return hr;
339   }
340 
341   AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, true);
342   LONG res = ::RegSaveKey(key.h_key_, file_name, NULL);
343   AdjustCurrentProcessPrivilege(SE_BACKUP_NAME, false);
344 
345   return HRESULT_FROM_WIN32(res);
346 }
347 
348 // restore the key and all of its subkeys and values which are saved into a file
Restore(const wchar_t * full_key_name,const wchar_t * file_name)349 HRESULT RegKey::Restore(const wchar_t* full_key_name,
350                         const wchar_t* file_name) {
351   ASSERT(full_key_name != NULL);
352   ASSERT(file_name != NULL);
353 
354   std::wstring key_name(full_key_name);
355   HKEY h_key = GetRootKeyInfo(&key_name);
356   if (!h_key) {
357     return E_FAIL;
358   }
359 
360   RegKey key;
361   HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_WRITE);
362   if (FAILED(hr)) {
363     return hr;
364   }
365 
366   AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, true);
367   LONG res = ::RegRestoreKey(key.h_key_, file_name, REG_FORCE_RESTORE);
368   AdjustCurrentProcessPrivilege(SE_RESTORE_NAME, false);
369 
370   return HRESULT_FROM_WIN32(res);
371 }
372 
373 // check if the current key has the specified subkey
HasSubkey(const wchar_t * key_name) const374 bool RegKey::HasSubkey(const wchar_t* key_name) const {
375   ASSERT(key_name != NULL);
376 
377   RegKey key;
378   HRESULT hr = key.Open(h_key_, key_name, KEY_READ);
379   key.Close();
380   return hr == S_OK;
381 }
382 
383 // static flush key
FlushKey(const wchar_t * full_key_name)384 HRESULT RegKey::FlushKey(const wchar_t* full_key_name) {
385   ASSERT(full_key_name != NULL);
386 
387   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
388   // get the root HKEY
389   std::wstring key_name(full_key_name);
390   HKEY h_key = GetRootKeyInfo(&key_name);
391 
392   if (h_key != NULL) {
393     LONG res = ::RegFlushKey(h_key);
394     hr = HRESULT_FROM_WIN32(res);
395   }
396   return hr;
397 }
398 
399 // static SET helper
SetValueStaticHelper(const wchar_t * full_key_name,const wchar_t * value_name,DWORD type,LPVOID value,DWORD byte_count)400 HRESULT RegKey::SetValueStaticHelper(const wchar_t* full_key_name,
401                                      const wchar_t* value_name,
402                                      DWORD type,
403                                      LPVOID value,
404                                      DWORD byte_count) {
405   ASSERT(full_key_name != NULL);
406 
407   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
408   // get the root HKEY
409   std::wstring key_name(full_key_name);
410   HKEY h_key = GetRootKeyInfo(&key_name);
411 
412   if (h_key != NULL) {
413     RegKey key;
414     hr = key.Create(h_key, key_name.c_str());
415     if (hr == S_OK) {
416       switch (type) {
417         case REG_DWORD:
418           hr = key.SetValue(value_name, *(static_cast<DWORD*>(value)));
419           break;
420         case REG_QWORD:
421           hr = key.SetValue(value_name, *(static_cast<DWORD64*>(value)));
422           break;
423         case REG_SZ:
424           hr = key.SetValue(value_name, static_cast<const wchar_t*>(value));
425           break;
426         case REG_BINARY:
427           hr = key.SetValue(value_name, static_cast<const uint8*>(value),
428                             byte_count);
429           break;
430         case REG_MULTI_SZ:
431           hr = key.SetValue(value_name, static_cast<const uint8*>(value),
432                             byte_count, type);
433           break;
434         default:
435           ASSERT(false);
436           hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
437           break;
438       }
439       // close the key after writing
440       HRESULT temp_hr = key.Close();
441       if (hr == S_OK) {
442         hr = temp_hr;
443       }
444     }
445   }
446   return hr;
447 }
448 
449 // static GET helper
GetValueStaticHelper(const wchar_t * full_key_name,const wchar_t * value_name,DWORD type,LPVOID value,DWORD * byte_count)450 HRESULT RegKey::GetValueStaticHelper(const wchar_t* full_key_name,
451                                      const wchar_t* value_name,
452                                      DWORD type,
453                                      LPVOID value,
454                                      DWORD* byte_count) {
455   ASSERT(full_key_name != NULL);
456 
457   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
458   // get the root HKEY
459   std::wstring key_name(full_key_name);
460   HKEY h_key = GetRootKeyInfo(&key_name);
461 
462   if (h_key != NULL) {
463     RegKey key;
464     hr = key.Open(h_key, key_name.c_str(), KEY_READ);
465     if (hr == S_OK) {
466       switch (type) {
467         case REG_DWORD:
468           hr = key.GetValue(value_name, reinterpret_cast<DWORD*>(value));
469           break;
470         case REG_QWORD:
471           hr = key.GetValue(value_name, reinterpret_cast<DWORD64*>(value));
472           break;
473         case REG_SZ:
474           hr = key.GetValue(value_name, reinterpret_cast<wchar_t**>(value));
475           break;
476         case REG_MULTI_SZ:
477           hr = key.GetValue(value_name, reinterpret_cast<
478                                             std::vector<std::wstring>*>(value));
479           break;
480         case REG_BINARY:
481           hr = key.GetValue(value_name, reinterpret_cast<uint8**>(value),
482                             byte_count);
483           break;
484         default:
485           ASSERT(false);
486           hr = HRESULT_FROM_WIN32(ERROR_DATATYPE_MISMATCH);
487           break;
488       }
489       // close the key after writing
490       HRESULT temp_hr = key.Close();
491       if (hr == S_OK) {
492         hr = temp_hr;
493       }
494     }
495   }
496   return hr;
497 }
498 
499 // GET helper
GetValueHelper(const wchar_t * value_name,DWORD * type,uint8 ** value,DWORD * byte_count) const500 HRESULT RegKey::GetValueHelper(const wchar_t* value_name,
501                                DWORD* type,
502                                uint8** value,
503                                DWORD* byte_count) const {
504   ASSERT(byte_count != NULL);
505   ASSERT(value != NULL);
506   ASSERT(type != NULL);
507 
508   // init return buffer
509   *value = NULL;
510 
511   // get the size of the return data buffer
512   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, type, NULL, byte_count);
513   HRESULT hr = HRESULT_FROM_WIN32(res);
514 
515   if (hr == S_OK) {
516     // if the value length is 0, nothing to do
517     if (*byte_count != 0) {
518       // allocate the buffer
519       *value = new byte[*byte_count];
520       ASSERT(*value != NULL);
521 
522       // make the call again to get the data
523       res = ::SHQueryValueEx(h_key_, value_name, NULL,
524                              type, *value, byte_count);
525       hr = HRESULT_FROM_WIN32(res);
526       ASSERT(hr == S_OK);
527     }
528   }
529   return hr;
530 }
531 
532 // Int32 Get
GetValue(const wchar_t * value_name,DWORD * value) const533 HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD* value) const {
534   ASSERT(value != NULL);
535 
536   DWORD type = 0;
537   DWORD byte_count = sizeof(DWORD);
538   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
539                               value, &byte_count);
540   HRESULT hr = HRESULT_FROM_WIN32(res);
541   ASSERT((hr != S_OK) || (type == REG_DWORD));
542   ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD)));
543   return hr;
544 }
545 
546 // Int64 Get
GetValue(const wchar_t * value_name,DWORD64 * value) const547 HRESULT RegKey::GetValue(const wchar_t* value_name, DWORD64* value) const {
548   ASSERT(value != NULL);
549 
550   DWORD type = 0;
551   DWORD byte_count = sizeof(DWORD64);
552   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
553                               value, &byte_count);
554   HRESULT hr = HRESULT_FROM_WIN32(res);
555   ASSERT((hr != S_OK) || (type == REG_QWORD));
556   ASSERT((hr != S_OK) || (byte_count == sizeof(DWORD64)));
557   return hr;
558 }
559 
560 // String Get
GetValue(const wchar_t * value_name,wchar_t ** value) const561 HRESULT RegKey::GetValue(const wchar_t* value_name, wchar_t** value) const {
562   ASSERT(value != NULL);
563 
564   DWORD byte_count = 0;
565   DWORD type = 0;
566 
567   // first get the size of the string buffer
568   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
569                               &type, NULL, &byte_count);
570   HRESULT hr = HRESULT_FROM_WIN32(res);
571 
572   if (hr == S_OK) {
573     // allocate room for the string and a terminating \0
574     *value = new wchar_t[(byte_count / sizeof(wchar_t)) + 1];
575 
576     if ((*value) != NULL) {
577       if (byte_count != 0) {
578         // make the call again
579         res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
580                                *value, &byte_count);
581         hr = HRESULT_FROM_WIN32(res);
582       } else {
583         (*value)[0] = L'\0';
584       }
585 
586       ASSERT((hr != S_OK) || (type == REG_SZ) ||
587              (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
588     } else {
589       hr = E_OUTOFMEMORY;
590     }
591   }
592 
593   return hr;
594 }
595 
596 // get a string value
GetValue(const wchar_t * value_name,std::wstring * value) const597 HRESULT RegKey::GetValue(const wchar_t* value_name, std::wstring* value) const {
598   ASSERT(value != NULL);
599 
600   DWORD byte_count = 0;
601   DWORD type = 0;
602 
603   // first get the size of the string buffer
604   LONG res = ::SHQueryValueEx(h_key_, value_name, NULL,
605                               &type, NULL, &byte_count);
606   HRESULT hr = HRESULT_FROM_WIN32(res);
607 
608   if (hr == S_OK) {
609     if (byte_count != 0) {
610       // Allocate some memory and make the call again
611       value->resize(byte_count / sizeof(wchar_t) + 1);
612       res = ::SHQueryValueEx(h_key_, value_name, NULL, &type,
613                              &value->at(0), &byte_count);
614       hr = HRESULT_FROM_WIN32(res);
615       value->resize(wcslen(value->data()));
616     } else {
617       value->clear();
618     }
619 
620     ASSERT((hr != S_OK) || (type == REG_SZ) ||
621            (type == REG_MULTI_SZ) || (type == REG_EXPAND_SZ));
622   }
623 
624   return hr;
625 }
626 
627 // convert REG_MULTI_SZ bytes to string array
MultiSZBytesToStringArray(const uint8 * buffer,DWORD byte_count,std::vector<std::wstring> * value)628 HRESULT RegKey::MultiSZBytesToStringArray(const uint8* buffer,
629                                           DWORD byte_count,
630                                           std::vector<std::wstring>* value) {
631   ASSERT(buffer != NULL);
632   ASSERT(value != NULL);
633 
634   const wchar_t* data = reinterpret_cast<const wchar_t*>(buffer);
635   DWORD data_len = byte_count / sizeof(wchar_t);
636   value->clear();
637   if (data_len > 1) {
638     // must be terminated by two null characters
639     if (data[data_len - 1] != 0 || data[data_len - 2] != 0) {
640       return E_INVALIDARG;
641     }
642 
643     // put null-terminated strings into arrays
644     while (*data) {
645       std::wstring str(data);
646       value->push_back(str);
647       data += str.length() + 1;
648     }
649   }
650   return S_OK;
651 }
652 
653 // get a std::vector<std::wstring> value from REG_MULTI_SZ type
GetValue(const wchar_t * value_name,std::vector<std::wstring> * value) const654 HRESULT RegKey::GetValue(const wchar_t* value_name,
655                          std::vector<std::wstring>* value) const {
656   ASSERT(value != NULL);
657 
658   DWORD byte_count = 0;
659   DWORD type = 0;
660   uint8* buffer = 0;
661 
662   // first get the size of the buffer
663   HRESULT hr = GetValueHelper(value_name, &type, &buffer, &byte_count);
664   ASSERT((hr != S_OK) || (type == REG_MULTI_SZ));
665 
666   if (SUCCEEDED(hr)) {
667     hr = MultiSZBytesToStringArray(buffer, byte_count, value);
668   }
669 
670   return hr;
671 }
672 
673 // Binary data Get
GetValue(const wchar_t * value_name,uint8 ** value,DWORD * byte_count) const674 HRESULT RegKey::GetValue(const wchar_t* value_name,
675                          uint8** value,
676                          DWORD* byte_count) const {
677   ASSERT(byte_count != NULL);
678   ASSERT(value != NULL);
679 
680   DWORD type = 0;
681   HRESULT hr = GetValueHelper(value_name, &type, value, byte_count);
682   ASSERT((hr != S_OK) || (type == REG_MULTI_SZ) || (type == REG_BINARY));
683   return hr;
684 }
685 
686 // Raw data get
GetValue(const wchar_t * value_name,uint8 ** value,DWORD * byte_count,DWORD * type) const687 HRESULT RegKey::GetValue(const wchar_t* value_name,
688                          uint8** value,
689                          DWORD* byte_count,
690                          DWORD*type) const {
691   ASSERT(type != NULL);
692   ASSERT(byte_count != NULL);
693   ASSERT(value != NULL);
694 
695   return GetValueHelper(value_name, type, value, byte_count);
696 }
697 
698 // Int32 set
SetValue(const wchar_t * value_name,DWORD value) const699 HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD value) const {
700   ASSERT(h_key_ != NULL);
701 
702   LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_DWORD,
703                              reinterpret_cast<const uint8*>(&value),
704                              sizeof(DWORD));
705   return HRESULT_FROM_WIN32(res);
706 }
707 
708 // Int64 set
SetValue(const wchar_t * value_name,DWORD64 value) const709 HRESULT RegKey::SetValue(const wchar_t* value_name, DWORD64 value) const {
710   ASSERT(h_key_ != NULL);
711 
712   LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_QWORD,
713                              reinterpret_cast<const uint8*>(&value),
714                              sizeof(DWORD64));
715   return HRESULT_FROM_WIN32(res);
716 }
717 
718 // String set
SetValue(const wchar_t * value_name,const wchar_t * value) const719 HRESULT RegKey::SetValue(const wchar_t* value_name,
720                          const wchar_t* value) const {
721   ASSERT(value != NULL);
722   ASSERT(h_key_ != NULL);
723 
724   LONG res = ::RegSetValueEx(h_key_, value_name, NULL, REG_SZ,
725                              reinterpret_cast<const uint8*>(value),
726                              (lstrlen(value) + 1) * sizeof(wchar_t));
727   return HRESULT_FROM_WIN32(res);
728 }
729 
730 // Binary data set
SetValue(const wchar_t * value_name,const uint8 * value,DWORD byte_count) const731 HRESULT RegKey::SetValue(const wchar_t* value_name,
732                          const uint8* value,
733                          DWORD byte_count) const {
734   ASSERT(h_key_ != NULL);
735 
736   // special case - if 'value' is NULL make sure byte_count is zero
737   if (value == NULL) {
738     byte_count = 0;
739   }
740 
741   LONG res = ::RegSetValueEx(h_key_, value_name, NULL,
742                              REG_BINARY, value, byte_count);
743   return HRESULT_FROM_WIN32(res);
744 }
745 
746 // Raw data set
SetValue(const wchar_t * value_name,const uint8 * value,DWORD byte_count,DWORD type) const747 HRESULT RegKey::SetValue(const wchar_t* value_name,
748                          const uint8* value,
749                          DWORD byte_count,
750                          DWORD type) const {
751   ASSERT(value != NULL);
752   ASSERT(h_key_ != NULL);
753 
754   LONG res = ::RegSetValueEx(h_key_, value_name, NULL, type, value, byte_count);
755   return HRESULT_FROM_WIN32(res);
756 }
757 
HasKey(const wchar_t * full_key_name)758 bool RegKey::HasKey(const wchar_t* full_key_name) {
759   ASSERT(full_key_name != NULL);
760 
761   // get the root HKEY
762   std::wstring key_name(full_key_name);
763   HKEY h_key = GetRootKeyInfo(&key_name);
764 
765   if (h_key != NULL) {
766     RegKey key;
767     HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
768     key.Close();
769     return S_OK == hr;
770   }
771   return false;
772 }
773 
774 // static version of HasValue
HasValue(const wchar_t * full_key_name,const wchar_t * value_name)775 bool RegKey::HasValue(const wchar_t* full_key_name, const wchar_t* value_name) {
776   ASSERT(full_key_name != NULL);
777 
778   bool has_value = false;
779   // get the root HKEY
780   std::wstring key_name(full_key_name);
781   HKEY h_key = GetRootKeyInfo(&key_name);
782 
783   if (h_key != NULL) {
784     RegKey key;
785     if (key.Open(h_key, key_name.c_str(), KEY_READ) == S_OK) {
786       has_value = key.HasValue(value_name);
787       key.Close();
788     }
789   }
790   return has_value;
791 }
792 
GetValueType(const wchar_t * full_key_name,const wchar_t * value_name,DWORD * value_type)793 HRESULT RegKey::GetValueType(const wchar_t* full_key_name,
794                              const wchar_t* value_name,
795                              DWORD* value_type) {
796   ASSERT(full_key_name != NULL);
797   ASSERT(value_type != NULL);
798 
799   *value_type = REG_NONE;
800 
801   std::wstring key_name(full_key_name);
802   HKEY h_key = GetRootKeyInfo(&key_name);
803 
804   RegKey key;
805   HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
806   if (SUCCEEDED(hr)) {
807     LONG res = ::SHQueryValueEx(key.h_key_, value_name, NULL, value_type,
808                                 NULL, NULL);
809     if (res != ERROR_SUCCESS) {
810       hr = HRESULT_FROM_WIN32(res);
811     }
812   }
813 
814   return hr;
815 }
816 
DeleteKey(const wchar_t * full_key_name)817 HRESULT RegKey::DeleteKey(const wchar_t* full_key_name) {
818   ASSERT(full_key_name != NULL);
819 
820   return DeleteKey(full_key_name, true);
821 }
822 
DeleteKey(const wchar_t * full_key_name,bool recursively)823 HRESULT RegKey::DeleteKey(const wchar_t* full_key_name, bool recursively) {
824   ASSERT(full_key_name != NULL);
825 
826   // need to open the parent key first
827   // get the root HKEY
828   std::wstring key_name(full_key_name);
829   HKEY h_key = GetRootKeyInfo(&key_name);
830 
831   // get the parent key
832   std::wstring parent_key(GetParentKeyInfo(&key_name));
833 
834   RegKey key;
835   HRESULT hr = key.Open(h_key, parent_key.c_str());
836 
837   if (hr == S_OK) {
838     hr = recursively ? key.RecurseDeleteSubKey(key_name.c_str())
839                      : key.DeleteSubKey(key_name.c_str());
840   } else if (hr == HRESULT_FROM_WIN32(ERROR_FILE_NOT_FOUND) ||
841              hr == HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND)) {
842     hr = S_FALSE;
843   }
844 
845   key.Close();
846   return hr;
847 }
848 
DeleteValue(const wchar_t * full_key_name,const wchar_t * value_name)849 HRESULT RegKey::DeleteValue(const wchar_t* full_key_name,
850                             const wchar_t* value_name) {
851   ASSERT(full_key_name != NULL);
852 
853   HRESULT hr = HRESULT_FROM_WIN32(ERROR_PATH_NOT_FOUND);
854   // get the root HKEY
855   std::wstring key_name(full_key_name);
856   HKEY h_key = GetRootKeyInfo(&key_name);
857 
858   if (h_key != NULL) {
859     RegKey key;
860     hr = key.Open(h_key, key_name.c_str());
861     if (hr == S_OK) {
862       hr = key.DeleteValue(value_name);
863       key.Close();
864     }
865   }
866   return hr;
867 }
868 
RecurseDeleteSubKey(const wchar_t * key_name)869 HRESULT RegKey::RecurseDeleteSubKey(const wchar_t* key_name) {
870   ASSERT(key_name != NULL);
871 
872   RegKey key;
873   HRESULT hr = key.Open(h_key_, key_name);
874 
875   if (hr == S_OK) {
876     // enumerate all subkeys of this key and recursivelly delete them
877     FILETIME time = {0};
878     wchar_t key_name_buf[kMaxKeyNameChars] = {0};
879     DWORD key_name_buf_size = kMaxKeyNameChars;
880     while (hr == S_OK &&
881         ::RegEnumKeyEx(key.h_key_, 0, key_name_buf, &key_name_buf_size,
882                        NULL, NULL, NULL,  &time) == ERROR_SUCCESS) {
883       hr = key.RecurseDeleteSubKey(key_name_buf);
884 
885       // restore the buffer size
886       key_name_buf_size = kMaxKeyNameChars;
887     }
888     // close the top key
889     key.Close();
890   }
891 
892   if (hr == S_OK) {
893     // the key has no more children keys
894     // delete the key and all of its values
895     hr = DeleteSubKey(key_name);
896   }
897 
898   return hr;
899 }
900 
GetRootKeyInfo(std::wstring * full_key_name)901 HKEY RegKey::GetRootKeyInfo(std::wstring* full_key_name) {
902   ASSERT(full_key_name != NULL);
903 
904   HKEY h_key = NULL;
905   // get the root HKEY
906   size_t index = full_key_name->find(L'\\');
907   std::wstring root_key;
908 
909   if (index == -1) {
910     root_key = *full_key_name;
911     *full_key_name = L"";
912   } else {
913     root_key = full_key_name->substr(0, index);
914     *full_key_name = full_key_name->substr(index + 1,
915                                            full_key_name->length() - index - 1);
916   }
917 
918   for (std::wstring::iterator iter = root_key.begin();
919        iter != root_key.end(); ++iter) {
920     *iter = toupper(*iter);
921   }
922 
923   if (!root_key.compare(L"HKLM") ||
924       !root_key.compare(L"HKEY_LOCAL_MACHINE")) {
925     h_key = HKEY_LOCAL_MACHINE;
926   } else if (!root_key.compare(L"HKCU") ||
927              !root_key.compare(L"HKEY_CURRENT_USER")) {
928     h_key = HKEY_CURRENT_USER;
929   } else if (!root_key.compare(L"HKU") ||
930              !root_key.compare(L"HKEY_USERS")) {
931     h_key = HKEY_USERS;
932   } else if (!root_key.compare(L"HKCR") ||
933              !root_key.compare(L"HKEY_CLASSES_ROOT")) {
934     h_key = HKEY_CLASSES_ROOT;
935   }
936 
937   return h_key;
938 }
939 
940 
941 // Returns true if this key name is 'safe' for deletion
942 // (doesn't specify a key root)
SafeKeyNameForDeletion(const wchar_t * key_name)943 bool RegKey::SafeKeyNameForDeletion(const wchar_t* key_name) {
944   ASSERT(key_name != NULL);
945   std::wstring key(key_name);
946 
947   HKEY root_key = GetRootKeyInfo(&key);
948 
949   if (!root_key) {
950     key = key_name;
951   }
952   if (key.empty()) {
953     return false;
954   }
955   bool found_subkey = false, backslash_found = false;
956   for (size_t i = 0 ; i < key.length() ; ++i) {
957     if (key[i] == L'\\') {
958       backslash_found = true;
959     } else if (backslash_found) {
960       found_subkey = true;
961       break;
962     }
963   }
964   return (root_key == HKEY_USERS) ? found_subkey : true;
965 }
966 
GetParentKeyInfo(std::wstring * key_name)967 std::wstring RegKey::GetParentKeyInfo(std::wstring* key_name) {
968   ASSERT(key_name != NULL);
969 
970   // get the parent key
971   size_t index = key_name->rfind(L'\\');
972   std::wstring parent_key;
973   if (index == -1) {
974     parent_key = L"";
975   } else {
976     parent_key = key_name->substr(0, index);
977     *key_name = key_name->substr(index + 1, key_name->length() - index - 1);
978   }
979 
980   return parent_key;
981 }
982 
983 // get the number of values for this key
GetValueCount()984 uint32 RegKey::GetValueCount() {
985   DWORD num_values = 0;
986 
987   LONG res = ::RegQueryInfoKey(
988         h_key_,                  // key handle
989         NULL,                    // buffer for class name
990         NULL,                    // size of class string
991         NULL,                    // reserved
992         NULL,                    // number of subkeys
993         NULL,                    // longest subkey size
994         NULL,                    // longest class string
995         &num_values,             // number of values for this key
996         NULL,                    // longest value name
997         NULL,                    // longest value data
998         NULL,                    // security descriptor
999         NULL);                   // last write time
1000 
1001   ASSERT(res == ERROR_SUCCESS);
1002   return num_values;
1003 }
1004 
1005 // Enumerators for the value_names for this key
1006 
1007 // Called to get the value name for the given value name index
1008 // Use GetValueCount() to get the total value_name count for this key
1009 // Returns failure if no key at the specified index
GetValueNameAt(int index,std::wstring * value_name,DWORD * type)1010 HRESULT RegKey::GetValueNameAt(int index, std::wstring* value_name,
1011                                DWORD* type) {
1012   ASSERT(value_name != NULL);
1013 
1014   LONG res = ERROR_SUCCESS;
1015   wchar_t value_name_buf[kMaxValueNameChars] = {0};
1016   DWORD value_name_buf_size = kMaxValueNameChars;
1017   res = ::RegEnumValue(h_key_, index, value_name_buf, &value_name_buf_size,
1018                        NULL, type, NULL, NULL);
1019 
1020   if (res == ERROR_SUCCESS) {
1021     value_name->assign(value_name_buf);
1022   }
1023 
1024   return HRESULT_FROM_WIN32(res);
1025 }
1026 
GetSubkeyCount()1027 uint32 RegKey::GetSubkeyCount() {
1028   // number of values for key
1029   DWORD num_subkeys = 0;
1030 
1031   LONG res = ::RegQueryInfoKey(
1032     h_key_,                  // key handle
1033     NULL,                    // buffer for class name
1034     NULL,                    // size of class string
1035     NULL,                    // reserved
1036     &num_subkeys,            // number of subkeys
1037     NULL,                    // longest subkey size
1038     NULL,                    // longest class string
1039     NULL,                    // number of values for this key
1040     NULL,                    // longest value name
1041     NULL,                    // longest value data
1042     NULL,                    // security descriptor
1043     NULL);                   // last write time
1044 
1045   ASSERT(res == ERROR_SUCCESS);
1046   return num_subkeys;
1047 }
1048 
GetSubkeyNameAt(int index,std::wstring * key_name)1049 HRESULT RegKey::GetSubkeyNameAt(int index, std::wstring* key_name) {
1050   ASSERT(key_name != NULL);
1051 
1052   LONG res = ERROR_SUCCESS;
1053   wchar_t key_name_buf[kMaxKeyNameChars] = {0};
1054   DWORD key_name_buf_size = kMaxKeyNameChars;
1055 
1056   res = ::RegEnumKeyEx(h_key_, index, key_name_buf, &key_name_buf_size,
1057                        NULL, NULL, NULL, NULL);
1058 
1059   if (res == ERROR_SUCCESS) {
1060     key_name->assign(key_name_buf);
1061   }
1062 
1063   return HRESULT_FROM_WIN32(res);
1064 }
1065 
1066 // Is the key empty: having no sub-keys and values
IsKeyEmpty(const wchar_t * full_key_name)1067 bool RegKey::IsKeyEmpty(const wchar_t* full_key_name) {
1068   ASSERT(full_key_name != NULL);
1069 
1070   bool is_empty = true;
1071 
1072   // Get the root HKEY
1073   std::wstring key_name(full_key_name);
1074   HKEY h_key = GetRootKeyInfo(&key_name);
1075 
1076   // Open the key to check
1077   if (h_key != NULL) {
1078     RegKey key;
1079     HRESULT hr = key.Open(h_key, key_name.c_str(), KEY_READ);
1080     if (SUCCEEDED(hr)) {
1081       is_empty = key.GetSubkeyCount() == 0 && key.GetValueCount() == 0;
1082       key.Close();
1083     }
1084   }
1085 
1086   return is_empty;
1087 }
1088 
AdjustCurrentProcessPrivilege(const TCHAR * privilege,bool to_enable)1089 bool AdjustCurrentProcessPrivilege(const TCHAR* privilege, bool to_enable) {
1090   ASSERT(privilege != NULL);
1091 
1092   bool ret = false;
1093   HANDLE token;
1094   if (::OpenProcessToken(::GetCurrentProcess(),
1095                          TOKEN_ADJUST_PRIVILEGES | TOKEN_QUERY, &token)) {
1096     LUID luid;
1097     memset(&luid, 0, sizeof(luid));
1098     if (::LookupPrivilegeValue(NULL, privilege, &luid)) {
1099       TOKEN_PRIVILEGES privs;
1100       privs.PrivilegeCount = 1;
1101       privs.Privileges[0].Luid = luid;
1102       privs.Privileges[0].Attributes = to_enable ? SE_PRIVILEGE_ENABLED : 0;
1103       if (::AdjustTokenPrivileges(token, FALSE, &privs, 0, NULL, 0)) {
1104         ret = true;
1105       } else {
1106         LOG_GLE(LS_ERROR) << "AdjustTokenPrivileges failed";
1107       }
1108     } else {
1109       LOG_GLE(LS_ERROR) << "LookupPrivilegeValue failed";
1110     }
1111     CloseHandle(token);
1112   } else {
1113     LOG_GLE(LS_ERROR) << "OpenProcessToken(GetCurrentProcess) failed";
1114   }
1115 
1116   return ret;
1117 }
1118 
1119 }  // namespace talk_base
1120