• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2#
3# Copyright 2017 The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17# pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18from errno import *  # pylint: disable=wildcard-import
19import os
20import itertools
21from scapy import all as scapy
22from socket import *  # pylint: disable=wildcard-import
23import subprocess
24import threading
25import unittest
26
27import multinetwork_base
28import net_test
29from tun_twister import TapTwister
30import util
31import xfrm
32import xfrm_base
33
# Encryption algorithms exercised by the ParamTests: AES-CBC at each listed
# key length (in bits).
CRYPT_ALGOS = list(
    xfrm.XfrmAlgo((xfrm.XFRM_EALG_CBC_AES, key_bits))
    for key_bits in (128, 192, 256))
40
# Auth algorithms exercised by the ParamTests. Each row of the table below is
# (name, key_len, trunc_len) and becomes one xfrm.XfrmAlgoAuth, in row order.
AUTH_ALGOS = list(
    xfrm.XfrmAlgoAuth((name, key_len, trunc_len))
    for name, key_len, trunc_len in (
        # RFC 4868 specifies that the only supported truncation length is
        # half the hash size.
        (xfrm.XFRM_AALG_HMAC_MD5, 128, 96),
        (xfrm.XFRM_AALG_HMAC_SHA1, 160, 96),
        (xfrm.XFRM_AALG_HMAC_SHA256, 256, 128),
        (xfrm.XFRM_AALG_HMAC_SHA384, 384, 192),
        (xfrm.XFRM_AALG_HMAC_SHA512, 512, 256),
        # Test larger truncation lengths for good measure.
        (xfrm.XFRM_AALG_HMAC_MD5, 128, 128),
        (xfrm.XFRM_AALG_HMAC_SHA1, 160, 160),
        (xfrm.XFRM_AALG_HMAC_SHA256, 256, 256),
        (xfrm.XFRM_AALG_HMAC_SHA384, 384, 384),
        (xfrm.XFRM_AALG_HMAC_SHA512, 512, 512),
    ))
57
# AEAD algorithms exercised by the ParamTests: AES-GCM for every pairing of
# key size and ICV size, ordered by key size first and ICV size second.
#
# RFC 4106 specifies that key length must be 128, 192 or 256 bits, with an
#   additional 4 bytes (32 bits) of salt. The salt must be unique for each
#   new SA using the same key.
# RFC 4106 specifies that ICV length must be 8, 12, or 16 bytes; it is
#   converted to bits here.
AEAD_ALGOS = list(
    xfrm.XfrmAlgoAead((xfrm.XFRM_AEAD_GCM_AES, key_bits + 32, icv_bytes * 8))
    for key_bits in (128, 192, 256)
    for icv_bytes in (8, 12, 16))
74
def InjectTests():
  """Module-level hook: runs XfrmAlgorithmTest's own test injection."""
  inject = XfrmAlgorithmTest.InjectTests
  inject()
77
78
class XfrmAlgorithmTest(xfrm_base.XfrmLazyTest):
  """Transport-mode ESP traffic tests parameterized over XFRM algorithms.

  The concrete test cases are not written out by hand: InjectTests() hands
  every parameter combination (IP version, socket type, and either an
  auth+crypt pair or an AEAD algorithm) to util.InjectParameterizedTest
  together with TestNameGenerator.
  """

  @classmethod
  def InjectTests(cls):
    """Registers all auth/crypt and all AEAD parameter combinations."""
    VERSIONS = (4, 6)
    TYPES = (SOCK_DGRAM, SOCK_STREAM)

    # Tests all combinations of auth & crypt. Mutually exclusive with aead.
    param_list = itertools.product(VERSIONS, TYPES, AUTH_ALGOS, CRYPT_ALGOS,
                                   [None])
    util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)

    # Tests all combinations of aead. Mutually exclusive with auth/crypt.
    param_list = itertools.product(VERSIONS, TYPES, [None], [None], AEAD_ALGOS)
    util.InjectParameterizedTest(cls, param_list, cls.TestNameGenerator)

  @staticmethod
  def TestNameGenerator(version, proto, auth, crypt, aead):
    """Produces a unique and readable name for one parameter combination.

    The resulting test is named along the lines of
        testSocketPolicySimple_cbc-aes_256_hmac-sha512_512_256_IPv6_UDP

    Args:
      version: 4 is rendered as "IPv4"; anything else as "IPv6".
      proto: SOCK_DGRAM is rendered as "UDP"; anything else as "TCP".
      auth: Auth algorithm (name, key_len, trunc_len are used), or None.
      crypt: Encryption algorithm (name, key_len are used), or None.
      aead: AEAD algorithm (name, key_len, icv_len are used), or None.

    Returns:
      The name fragment: crypt, auth and aead parts (each only when not
      None), followed by the IP version and the transport protocol.
    """
    param_string = ""
    if crypt is not None:
      param_string += "%s_%d_" % (crypt.name, crypt.key_len)

    if auth is not None:
      param_string += "%s_%d_%d_" % (auth.name, auth.key_len,
                                     auth.trunc_len)

    if aead is not None:
      param_string += "%s_%d_%d_" % (aead.name, aead.key_len,
                                     aead.icv_len)

    param_string += "%s_%s" % ("IPv4" if version == 4 else "IPv6",
                               "UDP" if proto == SOCK_DGRAM else "TCP")
    return param_string

  def ParamTestSocketPolicySimple(self, version, proto, auth, crypt, aead):
    """Test two-way traffic using transport mode and socket policies.

    Args:
      version: IP version of the sockets under test.
      proto: Socket type, SOCK_DGRAM or SOCK_STREAM.
      auth: Auth algorithm to use, or None.
      crypt: Encryption algorithm to use, or None.
      aead: AEAD algorithm to use, or None.

    Raises:
      Exception: the first error recorded by the server thread, re-raised in
        the test thread once the exchange has finished.
    """

    def AssertEncrypted(packet):
      # This gives a free pass to ICMP and ICMPv6 packets, which show up
      # nondeterministically in tests.
      self.assertEquals(None,
                        packet.getlayer(scapy.UDP),
                        "UDP packet sent in the clear")
      self.assertEquals(None,
                        packet.getlayer(scapy.TCP),
                        "TCP packet sent in the clear")

    def NewKeyedAlgos():
      """Returns AddSaInfo algorithm kwargs with freshly generated keys.

      Each value is an (algorithm, random key) pair, or None when that kind
      of algorithm is not part of this parameter combination. key_len is in
      bits; "//" keeps the byte count an integer on every Python version.
      """
      return {
          "encryption": (xfrm.XfrmAlgo((crypt.name, crypt.key_len)),
                         os.urandom(crypt.key_len // 8)) if crypt else None,
          "auth_trunc": (xfrm.XfrmAlgoAuth((auth.name, auth.key_len,
                                            auth.trunc_len)),
                         os.urandom(auth.key_len // 8)) if auth else None,
          "aead": (xfrm.XfrmAlgoAead((aead.name, aead.key_len,
                                      aead.icv_len)),
                   os.urandom(aead.key_len // 8)) if aead else None,
      }

    def AddTransportSa(src, dst, spi, reqid, algos):
      """Adds one transport-mode SA with no encap, mark or output mark."""
      self.xfrm.AddSaInfo(
          src=src,
          dst=dst,
          spi=spi,
          mode=xfrm.XFRM_MODE_TRANSPORT,
          reqid=reqid,
          encap=None,
          mark=None,
          output_mark=None,
          **algos)

    def MakeSocket():
      """Creates a socket of the tested type with a 2s timeout on netid."""
      sock = socket(family, proto, 0)
      sock.settimeout(2.0)
      sock.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
      self.SelectInterface(sock, netid, "mark")
      return sock

    # We create a pair of sockets, "left" and "right", that will talk to each
    # other using transport mode ESP. Because of TapTwister, both sockets
    # perceive each other as owning "remote_addr".
    netid = self.RandomNetid()
    family = net_test.GetAddressFamily(version)
    local_addr = self.MyAddress(version, netid)
    remote_addr = self.GetRemoteSocketAddress(version)
    # Each SPI gets its own independently generated key material.
    algos_left = NewKeyedAlgos()
    algos_right = NewKeyedAlgos()
    spi_left = 0xbeefface
    spi_right = 0xcafed00d
    req_ids = [100, 200, 300, 400]  # Used to match templates and SAs.

    # Left outbound SA
    AddTransportSa(local_addr, remote_addr, spi_right, req_ids[0], algos_right)
    # Right inbound SA
    AddTransportSa(remote_addr, local_addr, spi_right, req_ids[1], algos_right)
    # Right outbound SA
    AddTransportSa(local_addr, remote_addr, spi_left, req_ids[2], algos_left)
    # Left inbound SA
    AddTransportSa(remote_addr, local_addr, spi_left, req_ids[3], algos_left)

    # Make two sockets.
    sock_left = MakeSocket()
    sock_right = MakeSocket()

    # For TCP, have net_test disable FIN-wait handling (the original note
    # describes this as setting SO_LINGER to 0) to prevent TCP sockets from
    # hanging around in a TIME_WAIT state.
    if proto == SOCK_STREAM:
      net_test.DisableFinWait(sock_left)
      net_test.DisableFinWait(sock_right)

    # Apply the left outbound socket policy.
    xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_OUT,
                                spi_right, req_ids[0], None)
    # Apply right inbound socket policy.
    xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_IN,
                                spi_right, req_ids[1], None)
    # Apply right outbound socket policy.
    xfrm_base.ApplySocketPolicy(sock_right, family, xfrm.XFRM_POLICY_OUT,
                                spi_left, req_ids[2], None)
    # Apply left inbound socket policy.
    xfrm_base.ApplySocketPolicy(sock_left, family, xfrm.XFRM_POLICY_IN,
                                spi_left, req_ids[3], None)

    server_ready = threading.Event()
    # Exceptions thrown by the server thread. This must be a mutable container:
    # a plain "server_error = e" inside the nested server functions would only
    # bind a new local there and the failure would never reach this scope.
    server_errors = []

    def TcpServer(sock, client_port):
      accepted = None
      try:
        sock.listen(1)
        server_ready.set()
        accepted, peer = sock.accept()
        self.assertEquals(remote_addr, peer[0])
        self.assertEquals(client_port, peer[1])
        data = accepted.recv(2048)
        self.assertEquals("hello request", data)
        accepted.send("hello response")
      except Exception as e:
        server_errors.append(e)
      finally:
        if accepted is not None:
          accepted.close()
        sock.close()

    def UdpServer(sock, client_port):
      try:
        server_ready.set()
        data, peer = sock.recvfrom(2048)
        self.assertEquals(remote_addr, peer[0])
        self.assertEquals(client_port, peer[1])
        self.assertEquals("hello request", data)
        sock.sendto("hello response", peer)
      except Exception as e:
        server_errors.append(e)
      finally:
        sock.close()

    # Server and client need to know each other's port numbers in advance.
    wildcard_addr = net_test.GetWildcardAddress(version)
    sock_left.bind((wildcard_addr, 0))
    sock_right.bind((wildcard_addr, 0))
    left_port = sock_left.getsockname()[1]
    right_port = sock_right.getsockname()[1]

    # Start the appropriate server type on sock_right.
    target = TcpServer if proto == SOCK_STREAM else UdpServer
    server = threading.Thread(
        target=target,
        args=(sock_right, left_port),
        name="SocketServer")
    server.start()
    # Wait for server to be ready before attempting to connect. TCP retries
    # hide this problem, but UDP will fail outright if the server socket has
    # not bound when we send.
    self.assertTrue(server_ready.wait(2.0), "Timed out waiting for server thread")

    with TapTwister(fd=self.tuns[netid].fileno(), validator=AssertEncrypted):
      try:
        sock_left.connect((remote_addr, right_port))
        sock_left.send("hello request")
        data = sock_left.recv(2048)
        self.assertEquals("hello response", data)
      finally:
        # Always release the client socket and wait for the server thread,
        # even when the client side of the exchange failed.
        sock_left.close()
        server.join()
    if server_errors:
      raise server_errors[0]
291
292
if __name__ == "__main__":
  # Inject the parameterized tests first, then hand control to unittest.
  InjectTests()
  unittest.main()
296