• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
"""JSON support for message types.

Public classes:
  MessageJSONEncoder: JSON encoder for message objects.
  ProtoJson: JSON protocol implementation for encoding and decoding messages.

Public functions:
  encode_message: Encodes a message into a JSON string.
  decode_message: Decodes a JSON string into a new message of a given type.
"""
27import base64
28import binascii
29import logging
30
31import six
32
33from apitools.base.protorpclite import message_types
34from apitools.base.protorpclite import messages
35from apitools.base.protorpclite import util
36
# Public names exported by a wildcard import of this module.
__all__ = [
    'ALTERNATIVE_CONTENT_TYPES', 'CONTENT_TYPE', 'MessageJSONEncoder',
    'encode_message', 'decode_message', 'ProtoJson',
]
45
46
def _load_json_module():
    """Try to load a valid json module.

    More than one json module may be installed (the standard library
    ``json`` and ``simplejson``).  They are mostly compatible with one
    another, but some versions may differ, so this function tries them in a
    preferred order and does a basic check (the module must expose
    ``JSONEncoder``) to guess whether a loaded module is compatible.

    Returns:
      Compatible json module.

    Raises:
      ImportError: if no json module can be imported, or an imported module
        is not compatible with ProtoRPC.  The first error encountered is the
        one that is re-raised.
    """
    import importlib  # Function-scope import; only needed at module load.

    first_import_error = None
    for module_name in ('json', 'simplejson'):
        try:
            # import_module returns the named module itself.  The previous
            # __import__(name, {}, {}, 'json') call passed the fromlist as a
            # string, which is iterated character by character and made the
            # import machinery try nonexistent submodules ("json.j", ...).
            module = importlib.import_module(module_name)
            if not hasattr(module, 'JSONEncoder'):
                message = (
                    'json library "%s" is not compatible with ProtoRPC' %
                    module_name)
                logging.warning(message)
                raise ImportError(message)
            return module
        except ImportError as err:
            # Remember only the first failure; it is the most relevant one
            # because 'json' is the preferred module.
            if first_import_error is None:
                first_import_error = err

    logging.error('Must use valid json library (json or simplejson)')
    raise first_import_error  # pylint:disable=raising-bad-type


json = _load_json_module()
84
85
# TODO: Rename this to MessageJsonEncoder.
class MessageJSONEncoder(json.JSONEncoder):
    """JSON encoder that can serialize ProtoRPC message objects.

    Messages become JSON objects keyed by field name, enum values become
    their string form and, on Python 3, ``bytes`` are decoded as UTF-8.
    Everything else is handed to the base encoder's ``default``.
    """

    def __init__(self, protojson_protocol=None, **kwargs):
        """Constructor.

        Args:
          protojson_protocol: ProtoJson instance used to encode individual
            field values.  When omitted (or falsy), the instance returned by
            ``ProtoJson.get_default()`` is used.
          **kwargs: Passed unchanged to the base ``JSONEncoder``.
        """
        super(MessageJSONEncoder, self).__init__(**kwargs)
        if not protojson_protocol:
            protojson_protocol = ProtoJson.get_default()
        self.__protojson_protocol = protojson_protocol

    def default(self, value):
        """Return a JSON-serializable stand-in for ``value``.

        Args:
          value: Object the base encoder could not serialize.  Enums,
            Python 3 ``bytes`` and messages are handled here; any other
            value is passed to the superclass ``default`` (which normally
            raises ``TypeError``).

        Returns:
          ``str`` for enums and bytes, ``dict`` for messages.
        """
        if isinstance(value, messages.Enum):
            return str(value)
        if six.PY3 and isinstance(value, bytes):
            return value.decode('utf8')
        if not isinstance(value, messages.Message):
            return super(MessageJSONEncoder, self).default(value)

        protocol = self.__protojson_protocol
        encoded = {}
        for fld in value.all_fields():
            assigned = value.get_assigned_value(fld.name)
            # Values equal to None, [] or () are left out of the output.
            if assigned in (None, [], ()):
                continue
            encoded[fld.name] = protocol.encode_field(fld, assigned)

        # Copy unrecognized fields through verbatim (not via encode_field) so
        # that a decode-then-encode round trip keeps them.
        for unknown_key in value.all_unrecognized_fields():
            raw_value, _ = value.get_unrecognized_field_info(unknown_key)
            encoded[unknown_key] = raw_value
        return encoded
134
135
class ProtoJson(object):
    """ProtoRPC JSON implementation class.

    Implementation of the JSON based protocol used for serializing and
    deserializing message objects.  Instances are meant to be passed to the
    remote.ProtocolConfig constructor or used with
    remote.Protocols.add_protocol; see the remote.py module for more details.
    """

    CONTENT_TYPE = 'application/json'
    ALTERNATIVE_CONTENT_TYPES = [
        'application/x-javascript',
        'text/javascript',
        'text/x-javascript',
        'text/x-json',
        'text/json',
    ]

    def encode_field(self, field, value):
        """Encode a python field value to a JSON value.

        Args:
          field: A ProtoRPC field instance.
          value: A python value supported by field.

        Returns:
          A JSON serializable value appropriate for field.  Bytes fields are
          base64 encoded, date-time fields are rendered with ``isoformat()``
          and all other values are returned unchanged.
        """
        if isinstance(field, messages.BytesField):
            if field.repeated:
                value = [base64.b64encode(byte) for byte in value]
            else:
                value = base64.b64encode(value)
        elif isinstance(field, message_types.DateTimeField):
            # DateTimeField stores its data as a RFC 3339 compliant string.
            if field.repeated:
                value = [i.isoformat() for i in value]
            else:
                value = value.isoformat()
        return value

    def encode_message(self, message):
        """Encode Message instance to JSON string.

        Args:
          message: Message instance to encode into a JSON string.

        Returns:
          String encoding of Message instance in protocol JSON format.

        Raises:
          messages.ValidationError if message is not initialized.
        """
        message.check_initialized()

        return json.dumps(message, cls=MessageJSONEncoder,
                          protojson_protocol=self)

    def decode_message(self, message_type, encoded_message):
        """Decode a JSON string into a new Message instance.

        Args:
          message_type: Message class to decode data to.
          encoded_message: JSON encoded version of message.  Empty or
            whitespace-only input yields a default ``message_type()``.

        Returns:
          Decoded instance of message_type.

        Raises:
          ValueError: If encoded_message is not valid JSON.
          messages.ValidationError if merged message is not initialized.
        """
        encoded_message = six.ensure_str(encoded_message)
        if not encoded_message.strip():
            return message_type()

        dictionary = json.loads(encoded_message)
        message = self.__decode_dictionary(message_type, dictionary)
        message.check_initialized()
        return message

    def __find_variant(self, value):
        """Find the messages.Variant type that describes this value.

        Args:
          value: The value whose variant type is being determined.

        Returns:
          The messages.Variant value that best describes value's type,
          or None if it's a type we don't know how to handle.
        """
        if isinstance(value, bool):
            return messages.Variant.BOOL
        elif isinstance(value, six.integer_types):
            return messages.Variant.INT64
        elif isinstance(value, float):
            return messages.Variant.DOUBLE
        elif isinstance(value, six.string_types):
            return messages.Variant.STRING
        elif isinstance(value, (list, tuple)):
            # Find the most specific variant that covers all elements.
            variant_priority = [None,
                                messages.Variant.INT64,
                                messages.Variant.DOUBLE,
                                messages.Variant.STRING]
            chosen_priority = 0
            for v in value:
                variant = self.__find_variant(v)
                try:
                    priority = variant_priority.index(variant)
                except ValueError:
                    # list.index raises ValueError (not IndexError) when the
                    # variant is not in the list, e.g. BOOL.  Such elements
                    # do not influence the chosen variant.
                    priority = -1
                if priority > chosen_priority:
                    chosen_priority = priority
            return variant_priority[chosen_priority]
        # Unrecognized type.
        return None

    def __decode_dictionary(self, message_type, dictionary):
        """Build a message_type instance from a parsed JSON dictionary.

        Args:
          message_type: Message class to instantiate and populate.
          dictionary: Dictionary to extract information from.  Dictionary
            is as parsed from JSON.  Nested objects will also be dictionaries.

        Returns:
          A new message_type instance.  Keys whose value is null reset the
          field; keys that are not fields of message_type, and enum values
          the enum cannot decode, are stored as unrecognized fields when a
          variant can be determined for them and dropped otherwise.

        Raises:
          messages.DecodeError: propagated from decode_field for non-enum
            fields.
        """
        message = message_type()
        for key, value in six.iteritems(dictionary):
            if value is None:
                try:
                    message.reset(key)
                except AttributeError:
                    pass  # This is an unrecognized field, skip it.
                continue

            try:
                field = message.field_by_name(key)
            except KeyError:
                # Save unknown values.
                variant = self.__find_variant(value)
                if variant:
                    message.set_unrecognized_field(key, value, variant)
                continue

            if field.repeated:
                # This should be unnecessary? Or in fact become an error.
                if not isinstance(value, list):
                    value = [value]
                valid_value = [self.decode_field(field, item)
                               for item in value]
                setattr(message, field.name, valid_value)
                continue
            # This is just for consistency with the old behavior.
            if value == []:
                continue
            try:
                setattr(message, field.name, self.decode_field(field, value))
            except messages.DecodeError:
                # Save unknown enum values.
                if not isinstance(field, messages.EnumField):
                    raise
                variant = self.__find_variant(value)
                if variant:
                    message.set_unrecognized_field(key, value, variant)

        return message

    def decode_field(self, field, value):
        """Decode a JSON value to a python value.

        Args:
          field: A ProtoRPC field instance.
          value: A serialized JSON value.

        Returns:
          A Python value compatible with field.  Float and integer fields
          coerce numeric strings (and, for floats, ints); if coercion fails
          the value is returned unchanged.

        Raises:
          messages.DecodeError: for invalid enum, base64 or date-time values.
        """
        if isinstance(field, messages.EnumField):
            try:
                return field.type(value)
            except TypeError:
                raise messages.DecodeError(
                    'Invalid enum value "%s"' % (value or ''))

        elif isinstance(field, messages.BytesField):
            try:
                return base64.b64decode(value)
            except (binascii.Error, TypeError, ValueError) as err:
                # ValueError covers non-ASCII str input, which b64decode
                # rejects without raising binascii.Error.
                raise messages.DecodeError('Base64 decoding error: %s' % err)

        elif isinstance(field, message_types.DateTimeField):
            try:
                return util.decode_datetime(value)
            except ValueError as err:
                raise messages.DecodeError(err)

        elif (isinstance(field, messages.MessageField) and
              issubclass(field.type, messages.Message)):
            return self.__decode_dictionary(field.type, value)

        elif (isinstance(field, messages.FloatField) and
              isinstance(value, (six.integer_types, six.string_types))):
            try:
                return float(value)
            except (ValueError, OverflowError):
                # Not a parseable number (or an int too large for a float);
                # fall through and return the raw value.
                pass

        elif (isinstance(field, messages.IntegerField) and
              isinstance(value, six.string_types)):
            try:
                return int(value)
            except ValueError:
                # Not a parseable integer; return the raw value.
                pass

        return value

    @staticmethod
    def get_default():
        """Return the default ProtoJson instance, creating it on first use."""
        try:
            return ProtoJson.__default
        except AttributeError:
            ProtoJson.__default = ProtoJson()
            return ProtoJson.__default

    @staticmethod
    def set_default(protocol):
        """Set the default instance of ProtoJson.

        Args:
          protocol: A ProtoJson instance.

        Raises:
          TypeError: if protocol is not a ProtoJson instance.
        """
        if not isinstance(protocol, ProtoJson):
            raise TypeError('Expected protocol of type ProtoJson')
        ProtoJson.__default = protocol
373
374
# Module-level shortcuts: the class's content-type constants, plus the
# encode/decode methods of the default ProtoJson instance.  The methods are
# bound here at import time, so a later ProtoJson.set_default() call does
# not change which instance these two functions use.
CONTENT_TYPE, ALTERNATIVE_CONTENT_TYPES = (
    ProtoJson.CONTENT_TYPE, ProtoJson.ALTERNATIVE_CONTENT_TYPES)

encode_message, decode_message = (
    ProtoJson.get_default().encode_message,
    ProtoJson.get_default().decode_message)
382