• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2010 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16#
17
18"""JSON support for message types.
19
20Public classes:
21  MessageJSONEncoder: JSON encoder for message objects.
22
23Public functions:
24  encode_message: Encodes a message in to a JSON string.
25  decode_message: Merge from a JSON string in to a message.
26"""
27import base64
28import binascii
29import logging
30
31import six
32
33from apitools.base.protorpclite import message_types
34from apitools.base.protorpclite import messages
35from apitools.base.protorpclite import util
36
37__all__ = [
38    'ALTERNATIVE_CONTENT_TYPES',
39    'CONTENT_TYPE',
40    'MessageJSONEncoder',
41    'encode_message',
42    'decode_message',
43    'ProtoJson',
44]
45
46
def _load_json_module():
    """Import and return the first usable json implementation.

    Candidates are tried in preference order: the ``json`` module first,
    then ``simplejson``.  A candidate is considered usable when it can be
    imported and exposes a ``JSONEncoder`` attribute.

    Returns:
      The first candidate module that passes the compatibility check.

    Raises:
      ImportError: If no candidate is usable.  The error raised is the one
        recorded for the first candidate that failed.
    """
    earliest_failure = None
    for candidate in ('json', 'simplejson'):
        try:
            loaded = __import__(candidate, {}, {}, 'json')
            if hasattr(loaded, 'JSONEncoder'):
                return loaded
            # Importable but missing the encoder class: report it through
            # the same ImportError path as a module that is not installed.
            message = (
                'json library "%s" is not compatible with ProtoRPC' %
                candidate)
            logging.warning(message)
            raise ImportError(message)
        except ImportError as err:
            # Only the first failure is kept; later ones are dropped.
            if earliest_failure is None:
                earliest_failure = err

    logging.error('Must use valid json library (json or simplejson)')
    raise earliest_failure  # pylint:disable=raising-bad-type
81
82
# Resolve the json implementation once, at import time.  The encoder class
# and the ProtoJson encode/decode methods below all go through this name.
json = _load_json_module()
84
85
86# TODO: Rename this to MessageJsonEncoder.
class MessageJSONEncoder(json.JSONEncoder):
    """JSONEncoder subclass that can serialize message objects.

    Enum values, bytes (on Python 3) and Message instances are converted to
    JSON-compatible values; everything else is left to the base encoder.
    """

    def __init__(self, protojson_protocol=None, **kwargs):
        """Constructor.

        Args:
          protojson_protocol: ProtoJson instance used to encode individual
            field values.  When omitted (or falsy) the ProtoJson default
            instance is used.
          **kwargs: Passed through unchanged to json.JSONEncoder.
        """
        super(MessageJSONEncoder, self).__init__(**kwargs)
        if not protojson_protocol:
            protojson_protocol = ProtoJson.get_default()
        self.__protojson_protocol = protojson_protocol

    def default(self, value):
        """Convert a value the base encoder cannot handle.

        Args:
          value: Value to convert.  Enums become their string form, bytes
            are decoded as UTF-8 on Python 3, and messages become a
            dictionary keyed by field name.  Any other value is delegated
            to the superclass default method.

        Returns:
          A JSON-serializable replacement for value.
        """
        if isinstance(value, messages.Enum):
            return str(value)

        if six.PY3 and isinstance(value, bytes):
            return value.decode('utf8')

        if not isinstance(value, messages.Message):
            return super(MessageJSONEncoder, self).default(value)

        encoded = {}
        for field in value.all_fields():
            assigned = value.get_assigned_value(field.name)
            # Unset fields and empty repeated fields are left out entirely.
            if assigned in (None, [], ()):
                continue
            encoded[field.name] = (
                self.__protojson_protocol.encode_field(field, assigned))

        # Carry unrecognized fields over verbatim so that a message which is
        # decoded and then re-encoded does not lose them.  Their values are
        # stored as-is, without passing through encode_field.
        for name in value.all_unrecognized_fields():
            stored_value, _variant = value.get_unrecognized_field_info(name)
            encoded[name] = stored_value
        return encoded
134
135
class ProtoJson(object):
    """ProtoRPC JSON implementation class.

    Implementation of the JSON based protocol used for serializing and
    deserializing message objects.  Instances are meant to be handed to the
    remote.ProtocolConfig constructor or registered with
    remote.Protocols.add_protocol.  See the remote.py module for more
    details.
    """

    CONTENT_TYPE = 'application/json'
    ALTERNATIVE_CONTENT_TYPES = [
        'application/x-javascript',
        'text/javascript',
        'text/x-javascript',
        'text/x-json',
        'text/json',
    ]

    def encode_field(self, field, value):
        """Encode a python field value to a JSON value.

        Bytes fields are base64 encoded and date-time fields are rendered
        with isoformat(); repeated fields are converted element by element.
        Values of every other field type are returned unchanged.

        Args:
          field: A ProtoRPC field instance.
          value: A python value supported by field.

        Returns:
          A JSON serializable value appropriate for field.
        """
        if isinstance(field, messages.BytesField):
            if field.repeated:
                value = [base64.b64encode(byte) for byte in value]
            else:
                value = base64.b64encode(value)
        elif isinstance(field, message_types.DateTimeField):
            # DateTimeField stores its data as a RFC 3339 compliant string.
            if field.repeated:
                value = [i.isoformat() for i in value]
            else:
                value = value.isoformat()
        return value

    def encode_message(self, message):
        """Encode Message instance to JSON string.

        Args:
          message: Message instance to encode in to JSON string.

        Returns:
          String encoding of Message instance in protocol JSON format.

        Raises:
          messages.ValidationError if message is not initialized.
        """
        message.check_initialized()

        return json.dumps(message, cls=MessageJSONEncoder,
                          protojson_protocol=self)

    def decode_message(self, message_type, encoded_message):
        """Merge JSON structure to Message instance.

        Empty or whitespace-only input produces a default-constructed
        message_type without parsing or an initialization check.

        Args:
          message_type: Message to decode data to.
          encoded_message: JSON encoded version of message.

        Returns:
          Decoded instance of message_type.

        Raises:
          ValueError: If encoded_message is not valid JSON.
          messages.ValidationError if merged message is not initialized.
        """
        encoded_message = six.ensure_str(encoded_message)
        if not encoded_message.strip():
            return message_type()

        dictionary = json.loads(encoded_message)
        message = self.__decode_dictionary(message_type, dictionary)
        message.check_initialized()
        return message

    def __find_variant(self, value):
        """Find the messages.Variant type that describes this value.

        For lists and tuples the result is the highest-priority variant
        among the elements, in the order INT64 < DOUBLE < STRING.  Elements
        whose variant is not in that ordering (for example booleans) are
        ignored when choosing.

        Args:
          value: The value whose variant type is being determined.

        Returns:
          The messages.Variant value that best describes value's type,
          or None if it's a type we don't know how to handle.

        """
        if isinstance(value, bool):
            return messages.Variant.BOOL
        elif isinstance(value, six.integer_types):
            return messages.Variant.INT64
        elif isinstance(value, float):
            return messages.Variant.DOUBLE
        elif isinstance(value, six.string_types):
            return messages.Variant.STRING
        elif isinstance(value, (list, tuple)):
            # Find the most specific variant that covers all elements.
            variant_priority = [None,
                                messages.Variant.INT64,
                                messages.Variant.DOUBLE,
                                messages.Variant.STRING]
            chosen_priority = 0
            for v in value:
                variant = self.__find_variant(v)
                try:
                    priority = variant_priority.index(variant)
                except ValueError:
                    # list.index signals a missing element with ValueError.
                    # Such elements must never win the priority contest.
                    priority = -1
                if priority > chosen_priority:
                    chosen_priority = priority
            return variant_priority[chosen_priority]
        # Unrecognized type.
        return None

    def __decode_dictionary(self, message_type, dictionary):
        """Build a message of message_type from a parsed JSON dictionary.

        Keys that do not name a field of message_type, and enum values that
        fail to decode, are stored on the message as unrecognized fields
        when a variant can be determined for them.

        Args:
          message_type: Message class to instantiate and populate.
          dictionary: Dictionary to extract information from.  Dictionary
            is as parsed from JSON.  Nested objects will also be dictionaries.

        Returns:
          A new message_type instance populated from dictionary.

        Raises:
          messages.DecodeError: If a non-enum field value cannot be decoded.
        """
        message = message_type()
        for key, value in six.iteritems(dictionary):
            if value is None:
                # An explicit JSON null resets the field.
                try:
                    message.reset(key)
                except AttributeError:
                    pass  # This is an unrecognized field, skip it.
                continue

            try:
                field = message.field_by_name(key)
            except KeyError:
                # Save unknown values.
                variant = self.__find_variant(value)
                if variant:
                    message.set_unrecognized_field(key, value, variant)
                continue

            is_enum_field = isinstance(field, messages.EnumField)
            is_unrecognized_field = False
            if field.repeated:
                # This should be unnecessary? Or in fact become an error.
                if not isinstance(value, list):
                    value = [value]
                valid_value = []
                for item in value:
                    try:
                        v = self.decode_field(field, item)
                        if is_enum_field and v is None:
                            continue
                    except messages.DecodeError:
                        if not is_enum_field:
                            raise

                        # Drop the undecodable enum item from the field and
                        # remember to keep the raw list afterwards.
                        is_unrecognized_field = True
                        continue
                    valid_value.append(v)

                setattr(message, field.name, valid_value)
                if is_unrecognized_field:
                    # The complete original list is stored, not only the
                    # items that failed to decode.
                    variant = self.__find_variant(value)
                    if variant:
                        message.set_unrecognized_field(key, value, variant)
                continue

            # This is just for consistency with the old behavior.
            if value == []:
                continue
            try:
                setattr(message, field.name, self.decode_field(field, value))
            except messages.DecodeError:
                # Save unknown enum values.
                if not is_enum_field:
                    raise
                variant = self.__find_variant(value)
                if variant:
                    message.set_unrecognized_field(key, value, variant)

        return message

    def decode_field(self, field, value):
        """Decode a JSON value to a python value.

        Args:
          field: A ProtoRPC field instance.
          value: A serialized JSON value.

        Returns:
          A Python value compatible with field.  For float and integer
          fields a value that cannot be converted is returned unchanged.

        Raises:
          messages.DecodeError: If value is not a valid enum value, is not
            valid base64, or cannot be parsed as a date-time.
        """
        if isinstance(field, messages.EnumField):
            try:
                return field.type(value)
            except TypeError:
                raise messages.DecodeError(
                    'Invalid enum value "%s"' % (value or ''))

        elif isinstance(field, messages.BytesField):
            try:
                return base64.b64decode(value)
            except (binascii.Error, TypeError) as err:
                raise messages.DecodeError('Base64 decoding error: %s' % err)

        elif isinstance(field, message_types.DateTimeField):
            try:
                return util.decode_datetime(value, truncate_time=True)
            except ValueError as err:
                raise messages.DecodeError(err)

        elif (isinstance(field, messages.MessageField) and
              issubclass(field.type, messages.Message)):
            return self.__decode_dictionary(field.type, value)

        elif (isinstance(field, messages.FloatField) and
              isinstance(value, (six.integer_types, six.string_types))):
            # Best effort: fall through and return the raw value when the
            # conversion fails, without masking KeyboardInterrupt or
            # SystemExit.
            try:
                return float(value)
            except (TypeError, ValueError, OverflowError):
                pass

        elif (isinstance(field, messages.IntegerField) and
              isinstance(value, six.string_types)):
            # Same best-effort handling as the float conversion above.
            try:
                return int(value)
            except (TypeError, ValueError, OverflowError):
                pass

        return value

    @staticmethod
    def get_default():
        """Get default instance of ProtoJson.

        The instance is created on first use and then reused.

        Returns:
          The current default ProtoJson instance.
        """
        try:
            return ProtoJson.__default
        except AttributeError:
            ProtoJson.__default = ProtoJson()
            return ProtoJson.__default

    @staticmethod
    def set_default(protocol):
        """Set the default instance of ProtoJson.

        Args:
          protocol: A ProtoJson instance.

        Raises:
          TypeError: If protocol is not a ProtoJson instance.
        """
        if not isinstance(protocol, ProtoJson):
            raise TypeError('Expected protocol of type ProtoJson')
        ProtoJson.__default = protocol
392
393
# Module-level aliases for the ProtoJson content-type constants.
CONTENT_TYPE = ProtoJson.CONTENT_TYPE

ALTERNATIVE_CONTENT_TYPES = ProtoJson.ALTERNATIVE_CONTENT_TYPES

# Module-level convenience functions.  They are bound methods of whatever
# instance ProtoJson.get_default() returns when this module is imported; a
# later ProtoJson.set_default() call does not rebind these two names.
encode_message = ProtoJson.get_default().encode_message

decode_message = ProtoJson.get_default().decode_message
401