• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2# Copyright 2013 The Chromium Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6"""Tests exercising the various classes in xmppserver.py."""
7
8import unittest
9
10import base64
11import xmppserver
12
class XmlUtilsTest(unittest.TestCase):
  """Tests for the module-level XML helpers ParseXml() and CloneXml()."""

  def testParseXml(self):
    """Parsing then re-serializing nested elements round-trips the text."""
    xml_text = """<foo xmlns=""><bar xmlns=""><baz/></bar></foo>"""
    xml = xmppserver.ParseXml(xml_text)
    self.assertEqual(xml.toxml(), xml_text)

  def testCloneXml(self):
    """A clone is a distinct copy: mutating it leaves the original intact."""
    xml = xmppserver.ParseXml('<foo/>')
    xml_clone = xmppserver.CloneXml(xml)
    xml_clone.setAttribute('bar', 'baz')
    # The clone must be a different object...
    self.assertFalse(xml is xml_clone)
    # ...and the attribute set on it must not leak back into the original.
    # (Comparing the nodes themselves only checks identity, so compare the
    # serialized content instead.)
    self.assertEqual(xml.toxml(), '<foo/>')
    self.assertEqual(xml_clone.toxml(), '<foo bar="baz"/>')

  def testCloneXmlUnlink(self):
    """Unlinking the original leaves the clone attached and unchanged."""
    xml_text = '<foo/>'
    xml = xmppserver.ParseXml(xml_text)
    xml_clone = xmppserver.CloneXml(xml)
    xml.unlink()
    self.assertEqual(xml.parentNode, None)
    self.assertNotEqual(xml_clone.parentNode, None)
    self.assertEqual(xml_clone.toxml(), xml_text)
36
class StanzaParserTest(unittest.TestCase):
  """Tests for xmppserver.StanzaParser's incremental stanza splitting.

  The test case itself is the parser's delegate: each complete stanza is
  delivered to FeedStanza() and recorded in self.stanzas.
  """

  def setUp(self):
    # Serialized stanzas received so far, in arrival order.
    self.stanzas = []

  def FeedStanza(self, stanza):
    """Delegate callback: records each complete stanza as serialized XML."""
    # We can't append stanza directly because it is unlinked after
    # this callback.
    self.stanzas.append(stanza.toxml())

  def testBasic(self):
    """Two sibling stanzas arriving in split chunks are both delivered."""
    parser = xmppserver.StanzaParser(self)
    parser.FeedString('<foo')
    self.assertEqual(self.stanzas, [])
    parser.FeedString('/><bar></bar>')
    # Compare the whole list so extra or missing stanzas are caught too.
    self.assertEqual(self.stanzas, ['<foo/>', '<bar/>'])

  def testStream(self):
    """An unclosed <stream:stream> opener is delivered as a stanza."""
    parser = xmppserver.StanzaParser(self)
    parser.FeedString('<stream')
    self.assertEqual(self.stanzas, [])
    parser.FeedString(':stream foo="bar" xmlns:stream="baz">')
    self.assertEqual(self.stanzas,
                     ['<stream:stream foo="bar" xmlns:stream="baz"/>'])

  def testNested(self):
    """A nested stanza split over several chunks is delivered once, whole."""
    parser = xmppserver.StanzaParser(self)
    parser.FeedString('<foo')
    self.assertEqual(self.stanzas, [])
    parser.FeedString(' bar="baz"')
    # The opening tag is still incomplete, so nothing can be emitted yet.
    self.assertEqual(self.stanzas, [])
    parser.FeedString('><baz/><blah>meh</blah></foo>')
    self.assertEqual(self.stanzas,
                     ['<foo bar="baz"><baz/><blah>meh</blah></foo>'])
71
72
class JidTest(unittest.TestCase):
  """Checks the string form of xmppserver.Jid objects."""

  def _AssertJidString(self, jid, expected_text):
    """Asserts that str(jid) equals expected_text."""
    self.assertEqual(str(jid), expected_text)

  def testBasic(self):
    self._AssertJidString(xmppserver.Jid('foo', 'bar.com'), 'foo@bar.com')

  def testResource(self):
    full_jid = xmppserver.Jid('foo', 'bar.com', 'resource')
    self._AssertJidString(full_jid, 'foo@bar.com/resource')

  def testGetBareJid(self):
    # The bare JID drops the resource part.
    full_jid = xmppserver.Jid('foo', 'bar.com', 'resource')
    self._AssertJidString(full_jid.GetBareJid(), 'foo@bar.com')
86
87
class IdGeneratorTest(unittest.TestCase):
  """Checks that xmppserver.IdGenerator yields sequential prefixed ids."""

  def testBasic(self):
    prefix = 'foo'
    generator = xmppserver.IdGenerator(prefix)
    # Ids have the form '<prefix>.<n>', with n counting up from zero.
    for expected_index in range(100):
      expected_id = '%s.%d' % (prefix, expected_index)
      self.assertEqual(expected_id, generator.GetNextId())
94
95
class HandshakeTaskTest(unittest.TestCase):
  """Drives xmppserver.HandshakeTask through complete client handshakes.

  The test case is the task's delegate. It counts every SendData() /
  SendStanza() reply in self.data_received and records the JID passed to
  HandshakeDone().
  """

  def setUp(self):
    self.Reset()

  def Reset(self):
    """Clears the state recorded by the delegate callbacks."""
    self.data_received = 0
    self.handshake_done = False
    self.jid = None

  # HandshakeTask delegate methods.
  def SendData(self, _):
    self.data_received += 1

  def SendStanza(self, _, unused=True):
    self.data_received += 1

  def HandshakeDone(self, jid):
    self.handshake_done = True
    self.jid = jid

  # Helpers shared by the handshake drivers below.
  def _StartStream(self, handshake_task, stream_domain):
    """Feeds a <stream:stream> opener whose 'to' is stream_domain."""
    stream_xml = xmppserver.ParseXml('<stream:stream xmlns:stream="foo"/>')
    stream_xml.setAttribute('to', stream_domain)
    handshake_task.FeedStanza(stream_xml)

  def _Authenticate(self, handshake_task, username_domain):
    """Feeds an <auth> stanza for username_domain."""
    # Payload: NUL, username_domain, NUL, the literal 'bar'; base64-encoded.
    auth_string = base64.b64encode('\0%s\0bar' % username_domain)
    auth_xml = xmppserver.ParseXml('<auth>%s</auth>' % auth_string)
    handshake_task.FeedStanza(auth_xml)

  def _AssertStanzasIgnored(self, handshake_task, expected_data_received):
    """Checks that a stanza fed after completion triggers no reply."""
    # Feed a real parsed stanza (not a raw string) so this exercises the
    # same kind of input the task receives in practice.
    handshake_task.FeedStanza(xmppserver.ParseXml('<ignored/>'))
    self.assertEqual(self.data_received, expected_data_received)

  def DoHandshake(self, resource_prefix, resource, username,
                  initial_stream_domain, auth_domain, auth_stream_domain):
    """Runs an authenticated handshake and checks the replies at each step.

    Args:
      resource_prefix: Prefix the resulting JID resource should carry
          ('<resource_prefix>.<resource>').
      resource: Resource requested in the <bind> stanza.
      username: Username sent in the <auth> stanza.
      initial_stream_domain: 'to' domain of the first stream opener.
      auth_domain: Domain appended to the username in <auth>, or '' for
          none.
      auth_stream_domain: 'to' domain of the post-auth stream opener.
    """
    self.Reset()
    # Third argument selects the authenticated flow (False in
    # DoHandshakeUnauthenticated).
    handshake_task = xmppserver.HandshakeTask(self, resource_prefix, True)

    self.assertEqual(self.data_received, 0)
    self._StartStream(handshake_task, initial_stream_domain)
    self.assertEqual(self.data_received, 2)

    if auth_domain:
      username_domain = '%s@%s' % (username, auth_domain)
    else:
      username_domain = username
    self._Authenticate(handshake_task, username_domain)
    self.assertEqual(self.data_received, 3)

    self._StartStream(handshake_task, auth_stream_domain)
    self.assertEqual(self.data_received, 5)

    bind_xml = xmppserver.ParseXml(
        '<iq type="set"><bind><resource>%s</resource></bind></iq>' % resource)
    handshake_task.FeedStanza(bind_xml)
    self.assertEqual(self.data_received, 6)

    # The handshake is not complete until the session stanza arrives.
    self.assertFalse(self.handshake_done)

    session_xml = xmppserver.ParseXml(
        '<iq type="set"><session></session></iq>')
    handshake_task.FeedStanza(session_xml)
    self.assertEqual(self.data_received, 7)

    self.assertTrue(self.handshake_done)

    self.assertEqual(self.jid.username, username)
    # Expected domain: the first non-empty of the post-auth stream domain,
    # the auth domain and the initial stream domain.
    self.assertEqual(self.jid.domain,
                     auth_stream_domain or auth_domain or
                     initial_stream_domain)
    self.assertEqual(self.jid.resource,
                     '%s.%s' % (resource_prefix, resource))

    self._AssertStanzasIgnored(handshake_task, 7)

  def DoHandshakeUnauthenticated(self, resource_prefix, resource, username,
                                 initial_stream_domain):
    """Runs a handshake with the unauthenticated flow.

    The task is expected to finish right after the <auth> stanza and to
    report no JID. `resource` is accepted for symmetry with DoHandshake but
    is not used.
    """
    self.Reset()
    handshake_task = xmppserver.HandshakeTask(self, resource_prefix, False)

    self.assertEqual(self.data_received, 0)
    self._StartStream(handshake_task, initial_stream_domain)
    self.assertEqual(self.data_received, 2)

    self.assertFalse(self.handshake_done)

    self._Authenticate(handshake_task, username)
    self.assertEqual(self.data_received, 3)

    self.assertTrue(self.handshake_done)
    self.assertTrue(self.jid is None)

    self._AssertStanzasIgnored(handshake_task, 3)

  def testBasic(self):
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', 'quux.com')

  def testDomainBehavior(self):
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', 'quux.com')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', 'baz.com', '')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', 'bar.com', '', '')
    self.DoHandshake('resource_prefix', 'resource',
                     'foo', '', '', '')

  def testBasicUnauthenticated(self):
    self.DoHandshakeUnauthenticated('resource_prefix', 'resource',
                                    'foo', 'bar.com')
207
208
class FakeSocket(object):
  """A fake socket object used for testing.

  Everything passed to send() is recorded instead of being written anywhere.
  """

  def __init__(self):
    self._sent_data = []

  def GetSentData(self):
    """Returns the live list of chunks passed to send(), in order.

    The same list object is returned on every call, so a caller holding a
    reference sees later sends as well.
    """
    return self._sent_data

  # socket-like methods.
  def fileno(self):
    return 0

  def setblocking(self, flag):
    # Parameter renamed from `int` to avoid shadowing the builtin; it
    # mirrors socket.setblocking(flag), which callers pass positionally.
    pass

  def getpeername(self):
    return ('', 0)

  def send(self, data):
    """Records data and reports it as fully sent.

    Like socket.send(), this returns the number of bytes sent. Callers that
    advance their output buffer by the returned count (as asynchat does)
    therefore see the whole chunk as consumed, instead of treating None as
    "nothing sent" and keeping stale data queued.
    """
    self._sent_data.append(data)
    return len(data)

  def close(self):
    pass
235
236
class XmppConnectionTest(unittest.TestCase):
  """Tests xmppserver.XmppConnection over a FakeSocket.

  The test case is also the connection's delegate and tracks, in
  self.connections, which connections have completed their handshake.
  """

  def setUp(self):
    self.connections = set()
    self.fake_socket = FakeSocket()

  # XmppConnection delegate methods.
  def OnXmppHandshakeDone(self, xmpp_connection):
    self.connections.add(xmpp_connection)

  def OnXmppConnectionClosed(self, xmpp_connection):
    self.connections.discard(xmpp_connection)

  def ForwardNotification(self, unused_xmpp_connection, notification_stanza):
    # Fan the notification out to every handshake-completed connection.
    for conn in self.connections:
      conn.ForwardNotification(notification_stanza)

  def _AssertCounts(self, socket_map, expected_sockets, expected_connections):
    """Checks the socket-map size, then the number of tracked connections."""
    self.assertEqual(len(socket_map), expected_sockets)
    self.assertEqual(len(self.connections), expected_connections)

  def testBasic(self):
    socket_map = {}
    connection = xmppserver.XmppConnection(
        self.fake_socket, socket_map, self, ('', 0), True)
    self._AssertCounts(socket_map, 1, 0)
    connection.HandshakeDone(xmppserver.Jid('foo', 'bar'))
    self._AssertCounts(socket_map, 1, 1)

    sent = self.fake_socket.GetSentData()
    feed = connection.collect_incoming_data

    # A subscription request produces exactly one outgoing chunk.
    self.assertEqual(len(sent), 0)
    feed('<iq><subscribe xmlns="google:push"></subscribe></iq>')
    self.assertEqual(len(sent), 1)

    # An ack produces nothing.
    feed('<iq type="result"/>')
    self.assertEqual(len(sent), 1)

    # A notification produces one more outgoing chunk.
    feed('<message><push xmlns="google:push"/></message>')
    self.assertEqual(len(sent), 2)

    # Unknown stanzas and unknown notifier commands are rejected.
    self.assertRaises(xmppserver.UnexpectedXml, feed, '<foo/>')
    self.assertRaises(xmppserver.UnexpectedXml, feed,
                      '<iq><foo xmlns="google:notifier"/></iq>')

    # Closing unregisters both the socket and the connection.
    connection.close()
    self._AssertCounts(socket_map, 0, 0)

  def testBasicUnauthenticated(self):
    socket_map = {}
    connection = xmppserver.XmppConnection(
        self.fake_socket, socket_map, self, ('', 0), False)
    self._AssertCounts(socket_map, 1, 0)
    # Without authentication, finishing the handshake removes the socket
    # from the map and never registers the connection.
    connection.HandshakeDone(None)
    self._AssertCounts(socket_map, 0, 0)

    self.assertRaises(xmppserver.UnexpectedXml,
                      connection.collect_incoming_data, '<foo/>')

    # A redundant close must be harmless.
    connection.close()
    self._AssertCounts(socket_map, 0, 0)
317
318
class FakeXmppServer(xmppserver.XmppServer):
  """An XmppServer whose accept() hands out FakeSockets for testing.

  Every FakeSocket it hands out is remembered so tests can inspect what was
  sent through it.
  """

  def __init__(self):
    # These must exist before the base constructor runs; the socket map is
    # passed to it directly.
    self._socket_map = {}
    self._fake_sockets = set()
    self._next_jid_suffix = 1
    xmppserver.XmppServer.__init__(self, self._socket_map, ('', 0))

  def GetSocketMap(self):
    return self._socket_map

  def GetFakeSockets(self):
    return self._fake_sockets

  def AddHandshakeCompletedConnection(self):
    """Accepts a new connection and immediately completes its handshake.

    Each call uses a fresh JID: user1@domain.com, user2@domain.com, ...
    """
    connection = self.handle_accept()
    suffix = self._next_jid_suffix
    self._next_jid_suffix = suffix + 1
    connection.HandshakeDone(xmppserver.Jid('user%s' % suffix, 'domain.com'))

  # XmppServer overrides.
  def accept(self):
    sock = FakeSocket()
    self._fake_sockets.add(sock)
    return sock, ('', 0)

  def close(self):
    # Forget the fake sockets first, then let the base class tear down.
    self._fake_sockets.clear()
    xmppserver.XmppServer.close(self)
352
353
class XmppServerTest(unittest.TestCase):
  """Tests XmppServer connection bookkeeping and notification fan-out."""

  def setUp(self):
    self.xmpp_server = FakeXmppServer()

  def tearDown(self):
    # Only testBasic closes the server itself; without this, every other
    # test leaks the server and its connections. Closing again after
    # testBasic is expected to be harmless — NOTE(review): relies on
    # XmppServer.close being idempotent, as XmppConnection.close is.
    self.xmpp_server.close()

  def AssertSentDataLength(self, expected_length):
    """Asserts every fake socket has had exactly expected_length sends."""
    for fake_socket in self.xmpp_server.GetFakeSockets():
      self.assertEqual(len(fake_socket.GetSentData()), expected_length)

  def testBasic(self):
    socket_map = self.xmpp_server.GetSocketMap()
    # The server itself occupies one socket-map entry.
    self.assertEqual(len(socket_map), 1)
    self.xmpp_server.AddHandshakeCompletedConnection()
    self.assertEqual(len(socket_map), 2)
    self.xmpp_server.close()
    self.assertEqual(len(socket_map), 0)

  def testMakeNotification(self):
    notification = self.xmpp_server.MakeNotification('channel', 'data')
    # Adjacent literals are joined before '%' applies, so the whole
    # template is formatted.
    expected_xml = (
        '<message>'
        '  <push channel="channel" xmlns="google:push">'
        '    <data>%s</data>'
        '  </push>'
        '</message>' % base64.b64encode('data'))
    self.assertEqual(notification.toxml(), expected_xml)

  def testSendNotification(self):
    # Add a few connections.
    for _ in range(7):
      self.xmpp_server.AddHandshakeCompletedConnection()

    self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 7)

    self.AssertSentDataLength(0)
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(1)

  def testEnableDisableNotifications(self):
    # Add a few connections.
    for _ in range(5):
      self.xmpp_server.AddHandshakeCompletedConnection()

    self.assertEqual(len(self.xmpp_server.GetFakeSockets()), 5)

    # Notifications are delivered by default.
    self.AssertSentDataLength(0)
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(1)

    self.xmpp_server.EnableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(2)

    # While disabled, nothing is sent (disabling twice is harmless).
    self.xmpp_server.DisableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(2)

    self.xmpp_server.DisableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(2)

    self.xmpp_server.EnableNotifications()
    self.xmpp_server.SendNotification('channel', 'data')
    self.AssertSentDataLength(3)
418
419
if __name__ == '__main__':
  # Run every TestCase defined in this module.
  unittest.main(module='__main__')
422