• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1
2import yaml
3import pprint
4
5import datetime
6import yaml.tokens
7
8# Import any packages here that need to be referenced in .code files.
9import signal
10
def execute(code):
    """Execute *code* at module level and return the module global ``value``.

    *code* runs with this module's globals as both its global and its local
    namespace, so a top-level ``value = ...`` assignment inside it rebinds the
    module-level ``value`` that is then returned.  (A bare ``exec(code)``
    inside a function stores such assignments in a throwaway locals mapping,
    leaving the global ``value`` unset or stale.)

    Only use this on trusted test fixtures: *code* is executed verbatim.
    """
    global value
    exec(code, globals())
    return value
15
def _make_objects():
    """Create the loaders, dumpers and sample classes used by the fixtures.

    Every name in the ``global`` statements below is bound at module level so
    that ``.code`` fixtures (evaluated by ``_load_code``) and YAML documents
    loaded through ``MyLoader`` / ``MyFullLoader`` can refer to them.  Each
    call builds fresh classes and registers the custom constructors and
    representers on the freshly created ``MyLoader`` / ``MyDumper``.
    """
    global MyLoader, MyDumper, MyFullLoader
    global MyTestClass1, MyTestClass2, MyTestClass3, YAMLObject1, YAMLObject2
    global AnObject, AnInstance, AState, ACustomState
    global InitArgs, InitArgsWithState, NewArgs, NewArgsWithState
    global Reduce, ReduceWithState, Slots, MyInt, MyList, MyDict
    global FixedOffset, today

    # Dedicated subclasses that the custom constructors/representers below
    # are registered on.
    class MyLoader(yaml.Loader):
        pass

    class MyDumper(yaml.Dumper):
        pass

    class MyTestClass1:
        def __init__(self, x, y=0, z=0):
            self.x = x
            self.y = y
            self.z = z

        def __eq__(self, other):
            # Both sides must be parenthesized; the unparenthesized form
            # built the always-truthy 3-tuple (cls, dict == cls, dict).
            if isinstance(other, MyTestClass1):
                return (self.__class__, self.__dict__) == (other.__class__, other.__dict__)
            return False

    def construct1(constructor, node):
        # "!tag1" mapping -> MyTestClass1(**mapping)
        mapping = constructor.construct_mapping(node)
        return MyTestClass1(**mapping)

    def represent1(representer, native):
        # MyTestClass1 -> "!tag1" mapping of its instance attributes
        return representer.represent_mapping("!tag1", native.__dict__)

    def my_time_constructor(constructor, node):
        # "!MyTime" sequence -> [first item, tz name of that item or None]
        seq = constructor.construct_sequence(node)
        dt = seq[0]
        tz = None
        try:
            tz = dt.tzinfo.tzname(dt)
        except Exception:
            # Best effort: e.g. a naive datetime (tzinfo is None) has no tz
            # name.  Catch Exception rather than using a bare except so that
            # KeyboardInterrupt/SystemExit still propagate.
            pass
        return [dt, tz]

    yaml.add_constructor("!tag1", construct1, Loader=MyLoader)
    yaml.add_constructor("!MyTime", my_time_constructor, Loader=MyLoader)
    yaml.add_representer(MyTestClass1, represent1, Dumper=MyDumper)

    class MyTestClass2(MyTestClass1, yaml.YAMLObject):
        # Round-trips as a "!tag2" scalar holding x (see from_yaml/to_yaml).
        yaml_loader = MyLoader
        yaml_dumper = MyDumper
        yaml_tag = "!tag2"

        @classmethod
        def from_yaml(cls, constructor, node):
            x = constructor.construct_yaml_int(node)
            return cls(x=x)

        @classmethod
        def to_yaml(cls, representer, native):
            return representer.represent_scalar(cls.yaml_tag, str(native.x))

    class MyTestClass3(MyTestClass2):
        # Round-trips as a "!tag3" mapping; on load, a "=" key stands for x.
        yaml_tag = "!tag3"

        @classmethod
        def from_yaml(cls, constructor, node):
            mapping = constructor.construct_mapping(node)
            if '=' in mapping:
                mapping['x'] = mapping.pop('=')
            return cls(**mapping)

        @classmethod
        def to_yaml(cls, representer, native):
            return representer.represent_mapping(cls.yaml_tag, native.__dict__)

    class YAMLObject1(yaml.YAMLObject):
        yaml_loader = MyLoader
        yaml_dumper = MyDumper
        yaml_tag = '!foo'

        def __init__(self, my_parameter=None, my_another_parameter=None):
            self.my_parameter = my_parameter
            self.my_another_parameter = my_another_parameter

        def __eq__(self, other):
            # Parenthesized tuple comparison (see MyTestClass1.__eq__).
            if isinstance(other, YAMLObject1):
                return (self.__class__, self.__dict__) == (other.__class__, other.__dict__)
            return False

    class YAMLObject2(yaml.YAMLObject):
        # Uses __getstate__/__setstate__ with integer keys for its state.
        yaml_loader = MyLoader
        yaml_dumper = MyDumper
        yaml_tag = '!bar'

        def __init__(self, foo=1, bar=2, baz=3):
            self.foo = foo
            self.bar = bar
            self.baz = baz

        def __getstate__(self):
            return {1: self.foo, 2: self.bar, 3: self.baz}

        def __setstate__(self, state):
            self.foo = state[1]
            self.bar = state[2]
            self.baz = state[3]

        def __eq__(self, other):
            # Parenthesized tuple comparison (see MyTestClass1.__eq__).
            if isinstance(other, YAMLObject2):
                return (self.__class__, self.__dict__) == (other.__class__, other.__dict__)
            return False

    class AnObject:
        # Attributes are assigned in __new__ (there is no __init__).
        def __new__(cls, foo=None, bar=None, baz=None):
            self = object.__new__(cls)
            self.foo = foo
            self.bar = bar
            self.baz = baz
            return self

        def __eq__(self, other):
            return type(self) is type(other) and \
                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)

    class AnInstance:
        # Attributes are assigned in __init__.
        def __init__(self, foo=None, bar=None, baz=None):
            self.foo = foo
            self.bar = bar
            self.baz = baz

        def __eq__(self, other):
            return type(self) is type(other) and \
                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)

    class AState(AnInstance):
        # State is a dict whose keys differ from the attribute names.
        def __getstate__(self):
            return {
                '_foo': self.foo,
                '_bar': self.bar,
                '_baz': self.baz,
            }

        def __setstate__(self, state):
            self.foo = state['_foo']
            self.bar = state['_bar']
            self.baz = state['_baz']

    class ACustomState(AnInstance):
        # State is a plain 3-tuple rather than a dict.
        def __getstate__(self):
            return (self.foo, self.bar, self.baz)

        def __setstate__(self, state):
            self.foo, self.bar, self.baz = state

    class NewArgs(AnObject):
        # All attributes travel as __new__ arguments; state is empty.
        def __getnewargs__(self):
            return (self.foo, self.bar, self.baz)

        def __getstate__(self):
            return {}

    class NewArgsWithState(AnObject):
        # foo/bar travel as __new__ arguments, baz as the state.
        def __getnewargs__(self):
            return (self.foo, self.bar)

        def __getstate__(self):
            return self.baz

        def __setstate__(self, state):
            self.baz = state

    InitArgs = NewArgs

    InitArgsWithState = NewArgsWithState

    class Reduce(AnObject):
        # __reduce__ returns (class, args) only.
        def __reduce__(self):
            return self.__class__, (self.foo, self.bar, self.baz)

    class ReduceWithState(AnObject):
        # __reduce__ returns (class, args, state); state restored via __setstate__.
        def __reduce__(self):
            return self.__class__, (self.foo, self.bar), self.baz

        def __setstate__(self, state):
            self.baz = state

    class Slots:
        # No instance __dict__: attributes live in __slots__.
        __slots__ = ("foo", "bar", "baz")

        def __init__(self, foo=None, bar=None, baz=None):
            self.foo = foo
            self.bar = bar
            self.baz = baz

        def __eq__(self, other):
            return type(self) is type(other) and \
                (self.foo, self.bar, self.baz) == (other.foo, other.bar, other.baz)

    class MyInt(int):
        def __eq__(self, other):
            return type(self) is type(other) and int(self) == int(other)

    class MyList(list):
        # MyList(n) holds n None items.
        def __init__(self, n=1):
            self.extend([None]*n)

        def __eq__(self, other):
            return type(self) is type(other) and list(self) == list(other)

    class MyDict(dict):
        # MyDict(n) maps 0..n-1 to None.
        def __init__(self, n=1):
            for k in range(n):
                self[k] = None

        def __eq__(self, other):
            return type(self) is type(other) and dict(self) == dict(other)

    class FixedOffset(datetime.tzinfo):
        # Fixed UTC offset given in minutes, with a fixed name and no DST.
        def __init__(self, offset, name):
            self.__offset = datetime.timedelta(minutes=offset)
            self.__name = name

        def utcoffset(self, dt):
            return self.__offset

        def tzname(self, dt):
            return self.__name

        def dst(self, dt):
            return datetime.timedelta(0)

    class MyFullLoader(yaml.FullLoader):
        # Extends the inherited state-key blacklist with two extra patterns.
        def get_state_keys_blacklist(self):
            return super().get_state_keys_blacklist() + ['^mymethod$', '^wrong_.*$']

    today = datetime.date.today()
233
def _load_code(expression):
    """Evaluate a ``.code`` fixture expression and return its value.

    *expression* may be ``str`` or ``bytes`` (the caller passes raw file
    bytes).  It is evaluated against this module's globals, so it can name
    anything bound at module level (e.g. the classes from ``_make_objects``).
    Trusted test fixtures only: the text is passed straight to ``eval``.
    """
    return eval(expression, globals(), locals())
236
def _serialize_value(data):
    """Render *data* as a canonical string for fallback comparisons.

    Lists are rendered element by element.  Dicts are rendered as
    ``key: value`` entries sorted by their rendered text, so key order does
    not matter.  Datetimes are rendered via ``repr`` of their UTC time tuple.
    NaN floats become ``'?'`` so two NaNs render identically.  Anything else
    falls back to ``str()``.
    """
    if isinstance(data, list):
        return '[%s]' % ', '.join([_serialize_value(item) for item in data])
    if isinstance(data, dict):
        entries = sorted(
            "%s: %s" % (_serialize_value(key), _serialize_value(val))
            for key, val in data.items()
        )
        return '{%s}' % ', '.join(entries)
    if isinstance(data, datetime.datetime):
        return repr(data.utctimetuple())
    if isinstance(data, float) and data != data:  # only NaN is unequal to itself
        return '?'
    return str(data)
254
def test_constructor_types(data_filename, code_filename, verbose=False):
    """Assert that a YAML fixture constructs the same value as its .code twin.

    All documents in *data_filename* are loaded with ``MyLoader``; a single
    document is unwrapped from the list.  The result is compared with the
    value of the expression in *code_filename*.  If ``==`` is false or raises
    ``TypeError``, both values are compared via ``_serialize_value`` instead.
    With *verbose*, serialized and raw values are printed.
    """
    _make_objects()
    native1 = native2 = None
    try:
        with open(data_filename, 'rb') as data_stream:
            documents = list(yaml.load_all(data_stream, Loader=MyLoader))
        native1 = documents[0] if len(documents) == 1 else documents

        with open(code_filename, 'rb') as code_stream:
            native2 = _load_code(code_stream.read())

        try:
            same = bool(native1 == native2)
        except TypeError:
            same = False
        if same:
            return

        if verbose:
            for label, native in (("SERIALIZED NATIVE1:", native1),
                                  ("SERIALIZED NATIVE2:", native2)):
                print(label)
                print(_serialize_value(native))
        assert _serialize_value(native1) == _serialize_value(native2), (native1, native2)
    finally:
        if verbose:
            for label, native in (("NATIVE1:", native1), ("NATIVE2:", native2)):
                print(label)
                pprint.pprint(native)

test_constructor_types.unittest = ['.data', '.code']
285
def test_subclass_blacklist_types(data_filename, verbose=False):
    """Assert that loading *data_filename* with ``MyFullLoader`` fails.

    Loading must raise ``yaml.YAMLError``; if it succeeds, AssertionError is
    raised.  With *verbose*, the exception class name and message are printed.
    """
    _make_objects()
    with open(data_filename, 'rb') as stream:
        payload = stream.read()
    try:
        yaml.load(payload, MyFullLoader)
    except yaml.YAMLError as exc:
        if verbose:
            print("%s:" % type(exc).__name__, exc)
        return
    raise AssertionError("expected an exception")

test_subclass_blacklist_types.unittest = ['.subclass_blacklist']
298
if __name__ == '__main__':
    # Run this module's tests through the shared test_appliance runner.
    import sys
    # Import this file under its module name, then point
    # sys.modules['test_constructor'] at the running __main__ module so that
    # subsequent imports of that name get this module.
    import test_constructor
    sys.modules['test_constructor'] = sys.modules['__main__']
    import test_appliance
    test_appliance.run(globals())
304
305