• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2012 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
"""A bare-bones test server for testing cloud policy support.

This implements a simple cloud policy test server that can be used to test
Chrome's device management service client. The policy information is read from
the file named device_management in the server's data directory. It contains
enforced and recommended policies for the device and user scope, and a list
of managed users.

The format of the file is JSON. The root dictionary contains a list under the
key "managed_users". It contains auth tokens for which the server will claim
that the user is managed. The token string "*" indicates that all users are
claimed to be managed. Other keys in the root dictionary identify request
scopes; scopes that carry a settings entity id (such as public accounts) append
it to the policy type after a '/'. The user-request scope is described by a
dictionary that holds two sub-dictionaries: "mandatory" and "recommended". Both
of these hold the policy definitions as key/value stores; their format is
identical to what the Linux implementation reads from /etc.
The device scope holds the policy definitions directly as key/value stores in
the protobuf format.

The remaining top-level keys in the example below ("current_key_index",
"robot_api_auth_code", "invalidation_source" and "invalidation_name") are
optional. They select the signing key and customize the contents of the
server's responses.

Example:

{
  "google/chromeos/device" : {
    "guest_mode_enabled" : false
  },
  "google/chromeos/user" : {
    "mandatory" : {
      "HomepageLocation" : "http://www.chromium.org",
      "IncognitoEnabled" : false
    },
    "recommended" : {
      "JavascriptEnabled": false
    }
  },
  "google/chromeos/publicaccount/user@example.com" : {
    "mandatory" : {
      "HomepageLocation" : "http://www.chromium.org"
    },
    "recommended" : {
    }
  },
  "managed_users" : [
    "secret123456"
  ],
  "current_key_index": 0,
  "robot_api_auth_code": "fake_auth_code",
  "invalidation_source": 1025,
  "invalidation_name": "UENUPOL"
}

"""
56
57import BaseHTTPServer
58import cgi
59import google.protobuf.text_format
60import hashlib
61import logging
62import os
63import random
64import re
65import sys
66import time
67import tlslite
68import tlslite.api
69import tlslite.utils
70import tlslite.utils.cryptomath
71import urlparse
72
73# The name and availability of the json module varies in python versions.
74try:
75  import simplejson as json
76except ImportError:
77  try:
78    import json
79  except ImportError:
80    json = None
81
82import asn1der
83import testserver_base
84
85import device_management_backend_pb2 as dm
86import cloud_policy_pb2 as cp
87import chrome_extension_policy_pb2 as ep
88
89# Device policy is only available on Chrome OS builds.
90try:
91  import chrome_device_policy_pb2 as dp
92except ImportError:
93  dp = None
94
# ASN.1 object identifier for PKCS#1/RSA.
PKCS1_RSA_OID = '\x2a\x86\x48\x86\xf7\x0d\x01\x01\x01'

# SHA256 sum of "0". The input is an explicit bytes literal: under Python 2 it
# is identical to '0', and it is also what hashlib requires under Python 3.
SHA256_0 = hashlib.sha256(b'0').digest()

# List of bad machine identifiers that trigger the |valid_serial_number_missing|
# flag to be set in the policy fetch response.
BAD_MACHINE_IDS = [ '123490EN400015' ]

# List of machines that trigger the server to send kiosk enrollment response
# for the register request.
KIOSK_MACHINE_IDS = [ 'KIOSK' ]
108
109
class PolicyRequestHandler(BaseHTTPServer.BaseHTTPRequestHandler):
  """Decodes and handles device management requests from clients.

  The handler implements all the request parsing and protobuf message decoding
  and encoding. It calls back into the server to lookup, register, and
  unregister clients.
  """

  def __init__(self, request, client_address, server):
    """Initialize the handler.

    Args:
      request: The request data received from the client as a string.
      client_address: The client address.
      server: The TestServer object to use for (un)registering clients.
    """
    BaseHTTPServer.BaseHTTPRequestHandler.__init__(self, request,
                                                   client_address, server)

  def GetUniqueParam(self, name):
    """Extracts a unique query parameter from the request.

    The query string is parsed on first use and cached in self._params.

    Args:
      name: Names the parameter to fetch.
    Returns:
      The parameter value or None if the parameter doesn't exist or is not
      unique.
    """
    if not hasattr(self, '_params'):
      # Only the text after the first '?' is the query string. partition()
      # yields '' when the path has no '?'; slicing at find('?') + 1 would
      # instead have parsed the entire path as if it were a query string.
      query = self.path.partition('?')[2]
      self._params = urlparse.parse_qs(query)

    param_list = self._params.get(name, [])
    if len(param_list) == 1:
      return param_list[0]
    return None

  def do_GET(self):
    """Handles GET requests.

    Currently this is only used to serve external policy data.
    """
    path = self.path.partition('?')[0]
    if path == '/externalpolicydata':
      http_response, raw_reply = self.HandleExternalPolicyDataRequest()
    else:
      http_response = 404
      raw_reply = 'Invalid path'
    self.send_response(http_response)
    self.end_headers()
    self.wfile.write(raw_reply)

  def do_POST(self):
    """Handles POST requests carrying DeviceManagementRequest protobufs."""
    http_response, raw_reply = self.HandleRequest()
    self.send_response(http_response)
    if http_response == 200:
      self.send_header('Content-Type', 'application/x-protobuffer')
    self.end_headers()
    self.wfile.write(raw_reply)

  def HandleExternalPolicyDataRequest(self):
    """Handles a request to download policy data for a component.

    Returns:
      A tuple of HTTP status code and response data: 400 if the 'key' query
      parameter is missing, 404 if no data file exists for it, else 200 and
      the raw data.
    """
    policy_key = self.GetUniqueParam('key')
    if not policy_key:
      return (400, 'Missing key parameter')
    data = self.server.ReadPolicyDataFromDataDir(policy_key)
    if data is None:
      return (404, 'Policy not found for ' + policy_key)
    return (200, data)

  def HandleRequest(self):
    """Handles a request.

    Reads and parses the DeviceManagementRequest from the request body,
    validates the common query parameters and dispatches on the 'request'
    query parameter.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    rmsg = dm.DeviceManagementRequest()
    # A missing Content-Length header is treated as an empty body instead of
    # crashing in int(None).
    length = int(self.headers.getheader('content-length') or 0)
    rmsg.ParseFromString(self.rfile.read(length))

    logging.debug('gaia auth token -> %s',
                  self.headers.getheader('Authorization', ''))
    logging.debug('oauth token -> %s', self.GetUniqueParam('oauth_token'))
    logging.debug('deviceid -> %s', self.GetUniqueParam('deviceid'))
    self.DumpMessage('Request', rmsg)

    request_type = self.GetUniqueParam('request')
    # GetUniqueParam() returns None for absent parameters. Treat those as
    # empty for the length checks below (len(None) would raise TypeError);
    # the per-request handlers report missing identifiers themselves.
    device_id = self.GetUniqueParam('deviceid') or ''
    agent = self.GetUniqueParam('agent') or ''
    # Check server side requirements, as defined in
    # device_management_backend.proto.
    if (self.GetUniqueParam('devicetype') != '2' or
        self.GetUniqueParam('apptype') != 'Chrome' or
        (request_type != 'ping' and len(device_id) >= 64) or
        len(agent) >= 64):
      return (400, 'Invalid request parameter')

    if request_type == 'register':
      return self.ProcessRegister(rmsg.register_request)
    elif request_type == 'api_authorization':
      return self.ProcessApiAuthorization(rmsg.service_api_access_request)
    elif request_type == 'unregister':
      return self.ProcessUnregister(rmsg.unregister_request)
    elif request_type in ('policy', 'ping'):
      return self.ProcessPolicy(rmsg.policy_request, request_type)
    elif request_type == 'enterprise_check':
      return self.ProcessAutoEnrollment(rmsg.auto_enrollment_request)
    else:
      return (400, 'Invalid request parameter')

  def CreatePolicyForExternalPolicyData(self, policy_key):
    """Returns an ExternalPolicyData protobuf for policy_key.

    If there is policy data for policy_key then the download url will be
    set so that it points to that data, and the appropriate hash is also set.
    Otherwise, the protobuf will be empty.

    Args:
      policy_key: the policy type and settings entity id, joined by '/'.

    Returns:
      A serialized ExternalPolicyData.
    """
    settings = ep.ExternalPolicyData()
    data = self.server.ReadPolicyDataFromDataDir(policy_key)
    if data:
      settings.download_url = urlparse.urljoin(
          self.server.GetBaseURL(), 'externalpolicydata?key=%s' % policy_key)
      settings.secure_hash = hashlib.sha1(data).digest()
    return settings.SerializeToString()

  def CheckGoogleLogin(self):
    """Extracts the auth token from the request and returns it.

    The token may either be an OAuth V2 token from the oauth_token query
    parameter (checked first) or a GoogleLogin token from an Authorization
    header.

    Returns:
      The token string, or None if no token is present.
    """
    oauth_token = self.GetUniqueParam('oauth_token')
    if oauth_token:
      return oauth_token

    match = re.match(r'GoogleLogin auth=(\w+)',
                     self.headers.getheader('Authorization', ''))
    if match:
      return match.group(1)

    return None

  def ProcessRegister(self, msg):
    """Handles a register request.

    Checks the query for authorization and device identifier, registers the
    device with the server and constructs a response.

    Args:
      msg: The DeviceRegisterRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    # Check the auth token and device ID.
    auth = self.CheckGoogleLogin()
    if not auth:
      return (403, 'No authorization')

    policy = self.server.GetPolicies()
    if ('*' not in policy['managed_users'] and
        auth not in policy['managed_users']):
      return (403, 'Unmanaged')

    device_id = self.GetUniqueParam('deviceid')
    if not device_id:
      return (400, 'Missing device identifier')

    token_info = self.server.RegisterDevice(device_id,
                                            msg.machine_id,
                                            msg.type)

    # Send back the reply.
    response = dm.DeviceManagementResponse()
    response.register_response.device_management_token = (
        token_info['device_token'])
    response.register_response.machine_name = token_info['machine_name']
    response.register_response.enrollment_type = token_info['enrollment_mode']

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessApiAuthorization(self, msg):
    """Handles an API authorization request.

    Args:
      msg: The DeviceServiceApiAccessRequest message received from the client.
          Its contents are not inspected.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    policy = self.server.GetPolicies()

    # Return the auth code from the config file if it's defined,
    # else return a descriptive default value.
    response = dm.DeviceManagementResponse()
    response.service_api_access_response.auth_code = policy.get(
        'robot_api_auth_code', 'policy_testserver.py-auth_code')
    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessUnregister(self, msg):
    """Handles an unregister request.

    Checks for authorization, unregisters the device and constructs the
    response.

    Args:
      msg: The DeviceUnregisterRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    # Check the management token.
    token, response = self.CheckToken()
    if not token:
      return response

    # Unregister the device.
    self.server.UnregisterDevice(token['device_token'])

    # Prepare and send the response.
    response = dm.DeviceManagementResponse()
    response.unregister_response.CopyFrom(dm.DeviceUnregisterResponse())

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def ProcessPolicy(self, msg, request_type):
    """Handles a policy request.

    Checks for authorization, encodes the policy into protobuf representation
    and constructs the response. Each fetch request gets its own fetch
    response; unsupported policy types are reported per fetch response rather
    than failing the whole HTTP request.

    Args:
      msg: The DevicePolicyRequest message received from the client.
      request_type: The 'request' query parameter. Only 'policy' requests are
          served; for any other value each supported fetch request receives
          an 'Invalid request type' error.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    token_info, error = self.CheckToken()
    if not token_info:
      return error

    response = dm.DeviceManagementResponse()
    for request in msg.request:
      fetch_response = response.policy_response.response.add()
      if (request.policy_type in
             ('google/chrome/user',
              'google/chromeos/user',
              'google/chromeos/device',
              'google/chromeos/publicaccount',
              'google/chrome/extension')):
        if request_type != 'policy':
          fetch_response.error_code = 400
          fetch_response.error_message = 'Invalid request type'
        else:
          self.ProcessCloudPolicy(request, token_info, fetch_response)
      else:
        fetch_response.error_code = 400
        fetch_response.error_message = 'Invalid policy_type'

    return (200, response.SerializeToString())

  def ProcessAutoEnrollment(self, msg):
    """Handles an auto-enrollment check request.

    The reply depends on the value of the modulus:
      1: replies with no new modulus and the sha256 hash of "0"
      2: replies with a new modulus, 4.
      4: replies with a new modulus, 2.
      8: fails with error 400.
      16: replies with a new modulus, 16.
      32: replies with a new modulus, 1.
      anything else: replies with no new modulus and an empty list of hashes

    These allow the client to pick the testing scenario it wants to simulate.

    Args:
      msg: The DeviceAutoEnrollmentRequest message received from the client.

    Returns:
      A tuple of HTTP status code and response data to send to the client.
    """
    auto_enrollment_response = dm.DeviceAutoEnrollmentResponse()

    if msg.modulus == 1:
      auto_enrollment_response.hash.append(SHA256_0)
    elif msg.modulus == 2:
      auto_enrollment_response.expected_modulus = 4
    elif msg.modulus == 4:
      auto_enrollment_response.expected_modulus = 2
    elif msg.modulus == 8:
      return (400, 'Server error')
    elif msg.modulus == 16:
      auto_enrollment_response.expected_modulus = 16
    elif msg.modulus == 32:
      auto_enrollment_response.expected_modulus = 1

    response = dm.DeviceManagementResponse()
    response.auto_enrollment_response.CopyFrom(auto_enrollment_response)
    return (200, response.SerializeToString())

  def SetProtobufMessageField(self, group_message, field, field_value):
    '''Sets a field in a protobuf message.

    Supported field kinds: repeated scalars, repeated sub-messages (given as
    a list of dicts and filled recursively), bool, string, int64 and
    StringList messages. Any other field type raises an Exception.

    Args:
      group_message: The protobuf message.
      field: The field of the message to set, it should be a member of
          group_message.DESCRIPTOR.fields.
      field_value: The value to set.
    '''
    if field.label == field.LABEL_REPEATED:
      assert type(field_value) == list
      entries = getattr(group_message, field.name)
      if field.message_type is None:
        entries.extend(field_value)
      else:
        # This field is itself a protobuf.
        for sub_value in field_value:
          assert type(sub_value) == dict
          # Add a new sub-protobuf per list entry.
          sub_message = entries.add()
          # Now iterate over its fields and recursively add them.
          for sub_field in sub_message.DESCRIPTOR.fields:
            if sub_field.name in sub_value:
              value = sub_value[sub_field.name]
              self.SetProtobufMessageField(sub_message, sub_field, value)
      return
    elif field.type == field.TYPE_BOOL:
      assert type(field_value) == bool
    elif field.type == field.TYPE_STRING:
      assert type(field_value) in (str, unicode)
    elif field.type == field.TYPE_INT64:
      assert type(field_value) == int
    elif (field.type == field.TYPE_MESSAGE and
          field.message_type.name == 'StringList'):
      assert type(field_value) == list
      getattr(group_message, field.name).entries.extend(field_value)
      return
    else:
      raise Exception('Unknown field type %s' % field.type)
    setattr(group_message, field.name, field_value)

  def GatherDevicePolicySettings(self, settings, policies):
    '''Copies all the policies from a dictionary into a protobuf of type
    ChromeDeviceSettingsProto.

    Each top-level field of |settings| is a group message; a group is only
    copied into |settings| if at least one of its fields appears in
    |policies|.

    Args:
      settings: The destination ChromeDeviceSettingsProto protobuf.
      policies: The source dictionary containing policies in JSON format,
          keyed directly by the group messages' field names.
    '''
    for group in settings.DESCRIPTOR.fields:
      # Create protobuf message for group; its class is looked up by name in
      # the device policy proto module.
      group_message = getattr(dp, group.message_type.name)()
      # Indicates if at least one field was set in |group_message|.
      got_fields = False
      # Iterate over fields of the message and feed them from the
      # policy config file.
      for field in group_message.DESCRIPTOR.fields:
        if field.name in policies:
          got_fields = True
          self.SetProtobufMessageField(group_message, field,
                                       policies[field.name])
      if got_fields:
        getattr(settings, group.name).CopyFrom(group_message)

  def GatherUserPolicySettings(self, settings, policies):
    '''Copies all the policies from a dictionary into a protobuf of type
    CloudPolicySettings.

    A policy listed under 'mandatory' takes precedence over the same policy
    under 'recommended'.

    Args:
      settings: The destination: a CloudPolicySettings protobuf.
      policies: The source: a dictionary containing policies under keys
          'recommended' and 'mandatory'.
    '''
    for field in settings.DESCRIPTOR.fields:
      # |field| is the entry for a specific policy in the top-level
      # CloudPolicySettings proto.

      # Look for this policy's value in the mandatory or recommended dicts.
      if field.name in policies.get('mandatory', {}):
        mode = cp.PolicyOptions.MANDATORY
        value = policies['mandatory'][field.name]
      elif field.name in policies.get('recommended', {}):
        mode = cp.PolicyOptions.RECOMMENDED
        value = policies['recommended'][field.name]
      else:
        continue

      # Create protobuf message for this policy; its class is looked up by
      # name in the cloud policy proto module.
      policy_message = getattr(cp, field.message_type.name)()
      policy_message.policy_options.mode = mode
      field_descriptor = policy_message.DESCRIPTOR.fields_by_name['value']
      self.SetProtobufMessageField(policy_message, field_descriptor, value)
      getattr(settings, field.name).CopyFrom(policy_message)

  def ProcessCloudPolicy(self, msg, token_info, response):
    """Handles a cloud policy request. (New protocol for policy requests.)

    Encodes the policy into protobuf representation, signs it and constructs
    the response.

    Args:
      msg: The CloudPolicyRequest message received from the client.
      token_info: the token extracted from the request.
      response: A PolicyFetchResponse message that should be filled with the
                response data.

    Returns:
      None when an error was recorded in |response|; otherwise a tuple of
      HTTP status 200 and the serialized |response|. ProcessPolicy() only
      relies on |response| being filled in.
    """
    if msg.machine_id:
      self.server.UpdateMachineId(token_info['device_token'], msg.machine_id)

    policy = self.server.GetPolicies()
    policy_key = msg.policy_type
    if msg.settings_entity_id:
      policy_key += '/' + msg.settings_entity_id

    # Only policy types permitted for the token's registration type are
    # served; the allowed list is recorded by the server at registration.
    if msg.policy_type not in token_info['allowed_policy_types']:
      response.error_code = 400
      response.error_message = 'Request not allowed for the token used'
      return

    # A data-dir file for |policy_key| takes precedence over the JSON config.
    if msg.policy_type in ('google/chromeos/user',
                           'google/chrome/user',
                           'google/chromeos/publicaccount'):
      settings = cp.CloudPolicySettings()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        self.GatherUserPolicySettings(settings, policy.get(policy_key, {}))
        payload = settings.SerializeToString()
    elif dp is not None and msg.policy_type == 'google/chromeos/device':
      settings = dp.ChromeDeviceSettingsProto()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        self.GatherDevicePolicySettings(settings, policy.get(policy_key, {}))
        payload = settings.SerializeToString()
    elif msg.policy_type == 'google/chrome/extension':
      settings = ep.ExternalPolicyData()
      payload = self.server.ReadPolicyFromDataDir(policy_key, settings)
      if payload is None:
        payload = self.CreatePolicyForExternalPolicyData(policy_key)
    else:
      response.error_code = 400
      response.error_message = 'Invalid policy type'
      return

    # Sign with 'current_key_index', defaulting to key 0.
    signing_key = None
    req_key = None
    current_key_index = policy.get('current_key_index', 0)
    nkeys = len(self.server.keys)
    if (msg.signature_type == dm.PolicyFetchRequest.SHA1_RSA and
        current_key_index in range(nkeys)):
      signing_key = self.server.keys[current_key_index]
      if msg.public_key_version in range(1, nkeys + 1):
        # The key version the client reports (1-based) exists; it is used
        # below to sign the new public key when rotating.
        req_key = self.server.keys[msg.public_key_version - 1]['private_key']

    # Fill the policy data protobuf.
    policy_data = dm.PolicyData()
    policy_data.policy_type = msg.policy_type
    policy_data.timestamp = int(time.time() * 1000)
    policy_data.request_token = token_info['device_token']
    policy_data.policy_value = payload
    policy_data.machine_name = token_info['machine_name']
    policy_data.valid_serial_number_missing = (
        token_info['machine_id'] in BAD_MACHINE_IDS)
    policy_data.settings_entity_id = msg.settings_entity_id
    policy_data.service_account_identity = policy.get(
        'service_account_identity',
        'policy_testserver.py-service_account_identity')
    invalidation_source = policy.get('invalidation_source')
    if invalidation_source is not None:
      policy_data.invalidation_source = invalidation_source
    # Since invalidation_name is type bytes in the proto, the Unicode name
    # provided needs to be encoded as ASCII to set the correct byte pattern.
    invalidation_name = policy.get('invalidation_name')
    if invalidation_name is not None:
      policy_data.invalidation_name = invalidation_name.encode('ascii')

    if signing_key is not None:
      policy_data.public_key_version = current_key_index + 1
    if msg.policy_type == 'google/chromeos/publicaccount':
      policy_data.username = msg.settings_entity_id
    else:
      # For regular user/device policy, there is no way for the testserver to
      # know the user name belonging to the GAIA auth token we received (short
      # of actually talking to GAIA). To address this, we read the username from
      # the policy configuration dictionary, or use a default.
      policy_data.username = policy.get('policy_user', 'user@example.com')
    policy_data.device_id = token_info['device_id']
    signed_data = policy_data.SerializeToString()

    response.policy_data = signed_data
    if signing_key is not None:
      response.policy_data_signature = (
          signing_key['private_key'].hashAndSign(signed_data).tostring())
      if msg.public_key_version != current_key_index + 1:
        # The client does not report the current key version: ship the
        # current public key, signed by its old key when that key exists.
        response.new_public_key = signing_key['public_key']
        if req_key is not None:
          response.new_public_key_signature = (
              req_key.hashAndSign(response.new_public_key).tostring())

    self.DumpMessage('Response', response)

    return (200, response.SerializeToString())

  def CheckToken(self):
    """Helper for checking whether the client supplied a valid DM token.

    Extracts the token from the request and passed to the server in order to
    look up the client.

    Returns:
      A pair of token information record and error response. If the first
      element is None, then the second contains an error code to send back to
      the client: 401 when no DM token was supplied, 410 when the token is
      unknown or does not belong to the 'deviceid' query parameter. Otherwise
      the first element is the same structure that is returned by
      LookupToken().
    """
    request_device_id = self.GetUniqueParam('deviceid')
    match = re.match(r'GoogleDMToken token=(\w+)',
                     self.headers.getheader('Authorization', ''))
    dmtoken = match.group(1) if match else None

    if not dmtoken:
      error = 401
    else:
      token_info = self.server.LookupToken(dmtoken)
      if (token_info and request_device_id and
          token_info['device_id'] == request_device_id):
        return (token_info, None)
      error = 410

    logging.debug('Token check failed with error %d', error)

    return (None, (error, 'Server error %d' % error))

  def DumpMessage(self, label, msg):
    """Helper for logging an ASCII dump of a protobuf message."""
    logging.debug('%s\n%s', label, msg)
672
673
674class PolicyTestServer(testserver_base.BrokenPipeHandlerMixIn,
675                       testserver_base.StoppableHTTPServer):
676  """Handles requests and keeps global service state."""
677
678  def __init__(self, server_address, data_dir, policy_path, client_state_file,
679               private_key_paths, server_base_url):
680    """Initializes the server.
681
682    Args:
683      server_address: Server host and port.
684      policy_path: Names the file to read JSON-formatted policy from.
685      private_key_paths: List of paths to read private keys from.
686    """
687    testserver_base.StoppableHTTPServer.__init__(self, server_address,
688                                                 PolicyRequestHandler)
689    self._registered_tokens = {}
690    self.data_dir = data_dir
691    self.policy_path = policy_path
692    self.client_state_file = client_state_file
693    self.server_base_url = server_base_url
694
695    self.keys = []
696    if private_key_paths:
697      # Load specified keys from the filesystem.
698      for key_path in private_key_paths:
699        try:
700          key_str = open(key_path).read()
701        except IOError:
702          print 'Failed to load private key from %s' % key_path
703          continue
704
705        try:
706          key = tlslite.api.parsePEMKey(key_str, private=True)
707        except SyntaxError:
708          key = tlslite.utils.Python_RSAKey.Python_RSAKey._parsePKCS8(
709              tlslite.utils.cryptomath.stringToBytes(key_str))
710
711        assert key is not None
712        self.keys.append({ 'private_key' : key })
713    else:
714      # Generate 2 private keys if none were passed from the command line.
715      for i in range(2):
716        key = tlslite.api.generateRSAKey(512)
717        assert key is not None
718        self.keys.append({ 'private_key' : key })
719
720    # Derive the public keys from the private keys.
721    for entry in self.keys:
722      key = entry['private_key']
723
724      algorithm = asn1der.Sequence(
725          [ asn1der.Data(asn1der.OBJECT_IDENTIFIER, PKCS1_RSA_OID),
726            asn1der.Data(asn1der.NULL, '') ])
727      rsa_pubkey = asn1der.Sequence([ asn1der.Integer(key.n),
728                                      asn1der.Integer(key.e) ])
729      pubkey = asn1der.Sequence([ algorithm, asn1der.Bitstring(rsa_pubkey) ])
730      entry['public_key'] = pubkey
731
732    # Load client state.
733    if self.client_state_file is not None:
734      try:
735        file_contents = open(self.client_state_file).read()
736        self._registered_tokens = json.loads(file_contents, strict=False)
737      except IOError:
738        pass
739
740  def GetPolicies(self):
741    """Returns the policies to be used, reloaded form the backend file every
742       time this is called.
743    """
744    policy = {}
745    if json is None:
746      print 'No JSON module, cannot parse policy information'
747    else :
748      try:
749        policy = json.loads(open(self.policy_path).read(), strict=False)
750      except IOError:
751        print 'Failed to load policy from %s' % self.policy_path
752    return policy
753
754  def RegisterDevice(self, device_id, machine_id, type):
755    """Registers a device or user and generates a DM token for it.
756
757    Args:
758      device_id: The device identifier provided by the client.
759
760    Returns:
761      The newly generated device token for the device.
762    """
763    dmtoken_chars = []
764    while len(dmtoken_chars) < 32:
765      dmtoken_chars.append(random.choice('0123456789abcdef'))
766    dmtoken = ''.join(dmtoken_chars)
767    allowed_policy_types = {
768      dm.DeviceRegisterRequest.BROWSER: [
769          'google/chrome/user',
770          'google/chrome/extension'
771      ],
772      dm.DeviceRegisterRequest.USER: [
773          'google/chromeos/user',
774          'google/chrome/extension'
775      ],
776      dm.DeviceRegisterRequest.DEVICE: [
777          'google/chromeos/device',
778          'google/chromeos/publicaccount'
779      ],
780      dm.DeviceRegisterRequest.TT: ['google/chromeos/user',
781                                    'google/chrome/user'],
782    }
783    if machine_id in KIOSK_MACHINE_IDS:
784      enrollment_mode = dm.DeviceRegisterResponse.RETAIL
785    else:
786      enrollment_mode = dm.DeviceRegisterResponse.ENTERPRISE
787    self._registered_tokens[dmtoken] = {
788      'device_id': device_id,
789      'device_token': dmtoken,
790      'allowed_policy_types': allowed_policy_types[type],
791      'machine_name': 'chromeos-' + machine_id,
792      'machine_id': machine_id,
793      'enrollment_mode': enrollment_mode,
794    }
795    self.WriteClientState()
796    return self._registered_tokens[dmtoken]
797
798  def UpdateMachineId(self, dmtoken, machine_id):
799    """Updates the machine identifier for a registered device.
800
801    Args:
802      dmtoken: The device management token provided by the client.
803      machine_id: Updated hardware identifier value.
804    """
805    if dmtoken in self._registered_tokens:
806      self._registered_tokens[dmtoken]['machine_id'] = machine_id
807      self.WriteClientState()
808
809  def LookupToken(self, dmtoken):
810    """Looks up a device or a user by DM token.
811
812    Args:
813      dmtoken: The device management token provided by the client.
814
815    Returns:
816      A dictionary with information about a device or user that is registered by
817      dmtoken, or None if the token is not found.
818    """
819    return self._registered_tokens.get(dmtoken, None)
820
821  def UnregisterDevice(self, dmtoken):
822    """Unregisters a device identified by the given DM token.
823
824    Args:
825      dmtoken: The device management token provided by the client.
826    """
827    if dmtoken in self._registered_tokens.keys():
828      del self._registered_tokens[dmtoken]
829      self.WriteClientState()
830
831  def WriteClientState(self):
832    """Writes the client state back to the file."""
833    if self.client_state_file is not None:
834      json_data = json.dumps(self._registered_tokens)
835      open(self.client_state_file, 'w').write(json_data)
836
837  def GetBaseFilename(self, policy_selector):
838    """Returns the base filename for the given policy_selector.
839
840    Args:
841      policy_selector: the policy type and settings entity id, joined by '/'.
842
843    Returns:
844      The filename corresponding to the policy_selector, without a file
845      extension.
846    """
847    sanitized_policy_selector = re.sub('[^A-Za-z0-9.@-]', '_', policy_selector)
848    return os.path.join(self.data_dir or '',
849                        'policy_%s' % sanitized_policy_selector)
850
851  def ReadPolicyFromDataDir(self, policy_selector, proto_message):
852    """Tries to read policy payload from a file in the data directory.
853
854    First checks for a binary rendition of the policy protobuf in
855    <data_dir>/policy_<sanitized_policy_selector>.bin. If that exists, returns
856    it. If that file doesn't exist, tries
857    <data_dir>/policy_<sanitized_policy_selector>.txt and decodes that as a
858    protobuf using proto_message. If that fails as well, returns None.
859
860    Args:
861      policy_selector: Selects which policy to read.
862      proto_message: Optional protobuf message object used for decoding the
863          proto text format.
864
865    Returns:
866      The binary payload message, or None if not found.
867    """
868    base_filename = self.GetBaseFilename(policy_selector)
869
870    # Try the binary payload file first.
871    try:
872      return open(base_filename + '.bin').read()
873    except IOError:
874      pass
875
876    # If that fails, try the text version instead.
877    if proto_message is None:
878      return None
879
880    try:
881      text = open(base_filename + '.txt').read()
882      google.protobuf.text_format.Merge(text, proto_message)
883      return proto_message.SerializeToString()
884    except IOError:
885      return None
886    except google.protobuf.text_format.ParseError:
887      return None
888
889  def ReadPolicyDataFromDataDir(self, policy_selector):
890    """Returns the external policy data for |policy_selector| if found.
891
892    Args:
893      policy_selector: Selects which policy to read.
894
895    Returns:
896      The data for the corresponding policy type and entity id, if found.
897    """
898    base_filename = self.GetBaseFilename(policy_selector)
899    try:
900      return open(base_filename + '.data').read()
901    except IOError:
902      return None
903
904  def GetBaseURL(self):
905    """Returns the server base URL.
906
907    Respects the |server_base_url| configuration parameter, if present. Falls
908    back to construct the URL from the server hostname and port otherwise.
909
910    Returns:
911      The URL to use for constructing URLs that get returned to clients.
912    """
913    base_url = self.server_base_url
914    if base_url is None:
915      base_url = 'http://%s:%s' % self.server_address[:2]
916
917    return base_url
918
919
class PolicyServerRunner(testserver_base.TestServerRunner):
  """Command-line runner that builds and starts a PolicyTestServer."""

  def create_server(self, server_data):
    """Creates the policy server from the parsed command-line options.

    Args:
      server_data: Dictionary reported back to the launcher; the port the
          server is listening on is stored under 'port'.

    Returns:
      The newly constructed PolicyTestServer.
    """
    data_dir = self.options.data_dir or ''
    # Without --config-file, fall back to <data_dir>/device_management.
    config_file = (self.options.config_file or
                   os.path.join(data_dir, 'device_management'))
    server = PolicyTestServer((self.options.host, self.options.port),
                              data_dir, config_file,
                              self.options.client_state_file,
                              self.options.policy_keys,
                              self.options.server_base_url)
    # Report the port taken from the server object itself.
    server_data['port'] = server.server_port
    return server

  def add_options(self):
    """Registers the policy-server-specific command-line options."""
    super(PolicyServerRunner, self).add_options()
    self.option_parser.add_option('--client-state', dest='client_state_file',
                                  help='File that client state should be '
                                  'persisted to. This allows the server to be '
                                  'seeded by a list of pre-registered clients '
                                  'and restarts without abandoning registered '
                                  'clients.')
    self.option_parser.add_option('--policy-key', action='append',
                                  dest='policy_keys',
                                  help='Specify a path to a PEM-encoded '
                                  'private key to use for policy signing. May '
                                  'be specified multiple times in order to '
                                  'load multiple keys into the server. If the '
                                  'server has multiple keys, it will rotate '
                                  'through them at each request in a '
                                  'round-robin fashion. The server will '
                                  'generate a random key if none is specified '
                                  'on the command line.')
    self.option_parser.add_option('--log-level', dest='log_level',
                                  default='WARN',
                                  help='Log level threshold to use.')
    self.option_parser.add_option('--config-file', dest='config_file',
                                  help='Specify a configuration file to use '
                                  'instead of the default '
                                  '<data_dir>/device_management')
    self.option_parser.add_option('--server-base-url', dest='server_base_url',
                                  help='The server base URL to use when '
                                  'constructing URLs to return to the client.')

  def run_server(self):
    """Configures root logging from the options, then runs the server.

    Raises:
      ValueError: if --log-level does not name a numeric logging level.
    """
    logger = logging.getLogger()
    level = getattr(logging, str(self.options.log_level).upper(), None)
    # Reject names that are not logging levels (typos, or module attributes
    # such as BASIC_FORMAT) with a clear message instead of failing obscurely.
    if not isinstance(level, int):
      raise ValueError('Invalid --log-level value: %r' %
                       (self.options.log_level,))
    logger.setLevel(level)
    if self.options.log_to_console:
      logger.addHandler(logging.StreamHandler())
    if self.options.log_file:
      logger.addHandler(logging.FileHandler(self.options.log_file))

    super(PolicyServerRunner, self).run_server()
976
977
if __name__ == '__main__':
  # Exit with whatever status the runner's main() returns.
  runner = PolicyServerRunner()
  sys.exit(runner.main())
980