• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python3
2# -*- coding: utf-8 -*-
3#
4#  Project                     ___| | | |  _ \| |
5#                             / __| | | | |_) | |
6#                            | (__| |_| |  _ <| |___
7#                             \___|\___/|_| \_\_____|
8#
9# Copyright (C) Daniel Stenberg, <daniel@haxx.se>, et al.
10#
11# This software is licensed as described in the file COPYING, which
12# you should have received as part of this distribution. The terms
13# are also available at https://curl.se/docs/copyright.html.
14#
15# You may opt to use, copy, modify, merge, publish, distribute and/or sell
16# copies of the Software, and permit persons to whom the Software is
17# furnished to do so, under the terms of the COPYING file.
18#
19# This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY
20# KIND, either express or implied.
21#
22# SPDX-License-Identifier: curl
23#
24"""Server for testing SMB"""
25
26from __future__ import (absolute_import, division, print_function,
27                        unicode_literals)
28
29import argparse
30import logging
31import os
32import signal
33import sys
34import tempfile
35import threading
36
37# Import our curl test data helper
38from util import ClosingFileHandler, TestData
39
40if sys.version_info.major >= 3:
41    import configparser
42else:
43    import ConfigParser as configparser
44
45# impacket needs to be installed in the Python environment
46try:
47    import impacket
48except ImportError:
49    sys.stderr.write('Python package impacket needs to be installed!\n')
50    sys.stderr.write('Use pip or your package manager to install it.\n')
51    sys.exit(1)
52from impacket import smb as imp_smb
53from impacket import smbserver as imp_smbserver
54from impacket.nt_errors import (STATUS_ACCESS_DENIED, STATUS_NO_SUCH_FILE,
55                                STATUS_SUCCESS)
56
log = logging.getLogger(__name__)

# Share "paths" that are not real directories: create requests on these
# shares are answered with files generated on the fly by TestSmbServer.
SERVER_MAGIC = "SERVER_MAGIC"
TESTS_MAGIC = "TESTS_MAGIC"

# File a client reads from the SERVER share to check the server is alive,
# and the template for the reply written into it.
VERIFIED_REQ = "verifiedserver"
VERIFIED_RSP = "WE ROOLZ: {pid}\n"
62
63
class ShutdownHandler(threading.Thread):
    """Cleanly shut down the SMB server

    This can only be done from another thread while the server is in
    serve_forever(), so a thread is spawned here that waits for a shutdown
    signal before doing its thing. Use in a with statement around the
    serve_forever() call.

    On leaving the with statement the SIGINT/SIGTERM handlers that were
    active before entry are restored, and every path listed in
    server.tmpfiles is deleted.
    """

    # Signals that trigger a clean shutdown.
    _SIGNALS = (signal.SIGINT, signal.SIGTERM)

    def __init__(self, server):
        super(ShutdownHandler, self).__init__()
        self.server = server
        self.shutdown_event = threading.Event()
        # Handlers replaced in __enter__, keyed by signal number, so that
        # __exit__ can put them back.
        self._prev_handlers = {}

    def __enter__(self):
        # Install the handlers before starting the thread: if installation
        # fails (e.g. not called from the main thread) no thread is left
        # blocked forever on shutdown_event.
        for signum in self._SIGNALS:
            self._prev_handlers[signum] = signal.signal(signum,
                                                        self._sighandler)
        self.start()
        return self

    def __exit__(self, *_):
        # Call for shutdown just in case it wasn't done already
        self.shutdown_event.set()
        # Wait for thread, and therefore also the server, to finish
        self.join()
        # Restore the signal handlers that were active before __enter__.
        for signum, handler in self._prev_handlers.items():
            # signal.signal() returns None when the previous handler was not
            # installed from Python; fall back to the default action then.
            signal.signal(signum,
                          signal.SIG_DFL if handler is None else handler)
        self._prev_handlers.clear()
        # Delete any temporary files created by the server during its run
        tmpfiles = self.server.tmpfiles
        log.info("Deleting %d temporary files", len(tmpfiles))
        for path in tmpfiles:
            try:
                os.unlink(path)
            except OSError as exc:
                # Keep going: one missing file must not leave the rest behind.
                log.warning("Could not delete temporary file %s: %s",
                            path, exc)

    def _sighandler(self, _signum, _frame):
        # Wake up the cleanup task
        self.shutdown_event.set()

    def run(self):
        # Wait for shutdown signal
        self.shutdown_event.wait()
        # Notify the server to shut down
        self.server.shutdown()
105
106
def smbserver(options):
    """Start up a TCP SMB server that serves forever.

    Writes options.pidfile (if set) and builds an impacket configuration with
    two read-only shares:
      * SERVER, used to check that the server is running;
      * TESTS, whose files are generated from the test data in
        <srcdir>/data.
    It then blocks in serve_forever() until SIGINT/SIGTERM stops the server.

    :param options: parsed command line options (see get_options()).
    :returns: 0 once the server has shut down.
    :raises ScriptException: if --srcdir is missing or not a directory.
    """
    if options.pidfile:
        # see tests/server/util.c function write_pidfile
        pid_offset = 65536 if os.name == "nt" else 0
        with open(options.pidfile, "w") as pid_fh:
            pid_fh.write(str(os.getpid() + pid_offset))

    # Mini config for the server: global settings plus two shares whose
    # "paths" are magic names handled by TestSmbServer.
    config_layout = [
        ("global", [
            ("server_name", "SERVICE"),
            ("server_os", "UNIX"),
            ("server_domain", "WORKGROUP"),
            ("log_file", ""),
            ("credentials_file", ""),
        ]),
        # Share which allows us to test that the server is running
        ("SERVER", [
            ("comment", "server function"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", SERVER_MAGIC),
        ]),
        # Share for tests; its files are autogenerated from the test input
        ("TESTS", [
            ("comment", "tests"),
            ("read only", "yes"),
            ("share type", "0"),
            ("path", TESTS_MAGIC),
        ]),
    ]
    smb_config = configparser.ConfigParser()
    for section, entries in config_layout:
        smb_config.add_section(section)
        for key, value in entries:
            smb_config.set(section, key, value)

    srcdir = options.srcdir
    if not (srcdir and os.path.isdir(srcdir)):
        raise ScriptException("--srcdir is mandatory")

    smb_server = TestSmbServer((options.host, options.port),
                               config_parser=smb_config,
                               test_data_directory=os.path.join(srcdir,
                                                                "data"))
    log.info("[SMB] setting up SMB server on port %s", options.port)
    smb_server.processConfigFile()

    # A helper thread calls smb_server.shutdown() on a signal;
    # serve_forever() blocks until that happens.
    with ShutdownHandler(smb_server):
        smb_server.serve_forever()

    return 0
160
161
class TestSmbServer(imp_smbserver.SMBSERVER):
    """
    Test server for SMB which subclasses the impacket SMBSERVER and provides
    test functionality.

    Files on the SERVER_MAGIC and TESTS_MAGIC shares do not exist on disk;
    they are generated on demand into temporary files. The paths of those
    files are collected in self.tmpfiles so the owner of the server can
    delete them once it stops.
    """

    def __init__(self,
                 address,
                 config_parser=None,
                 test_data_directory=None):
        imp_smbserver.SMBSERVER.__init__(self,
                                         address,
                                         config_parser=config_parser)
        # Paths of temporary files handed out to clients.
        self.tmpfiles = []

        # Set up a test data object so we can get test data later.
        self.ctd = TestData(test_data_directory)

        # Override smbComNtCreateAndX so we can pretend to have files which
        # don't exist.
        self.hookSmbCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX,
                            self.create_and_x)

    def create_and_x(self, conn_id, smb_server, smb_command, recv_packet):
        """
        Our version of smbComNtCreateAndX looks for special test files and
        fools the rest of the framework into opening them as if they were
        normal files.

        Only FILE_OPEN (read) requests are accepted. Requests on shares other
        than SERVER_MAGIC/TESTS_MAGIC are handed to impacket's own
        smbComNtCreateAndX. For the magic shares, the requested file is
        generated into a temporary file, registered in the connection's
        OpenedFiles table and described in the response.

        :returns: ([response command], None, NT status code)
        """
        conn_data = smb_server.getConnectionData(conn_id)

        # Wrap processing in a try block which allows us to throw SmbException
        # to control the flow.
        try:
            ncax_parms = imp_smb.SMBNtCreateAndX_Parameters(
                smb_command["Parameters"])

            path = self.get_share_path(conn_data,
                                       ncax_parms["RootFid"],
                                       recv_packet["Tid"])
            log.info("[SMB] Requested share path: %s", path)

            disposition = ncax_parms["Disposition"]
            log.debug("[SMB] Requested disposition: %s", disposition)

            # Currently we only support reading files.
            if disposition != imp_smb.FILE_OPEN:
                raise SmbException(STATUS_ACCESS_DENIED,
                                   "Only support reading files")

            # Check to see if the path we were given is actually a
            # magic path which needs generating on the fly.
            if path not in (SERVER_MAGIC, TESTS_MAGIC):
                # Pass the command onto the original handler.
                return imp_smbserver.SMBCommands.smbComNtCreateAndX(
                    conn_id, smb_server, smb_command, recv_packet)

            flags2 = recv_packet["Flags2"]
            ncax_data = imp_smb.SMBNtCreateAndX_Data(flags=flags2,
                                                     data=smb_command["Data"])
            requested_file = imp_smbserver.decodeSMBString(
                flags2,
                ncax_data["FileName"])
            log.debug("[SMB] User requested file '%s'", requested_file)

            if path == SERVER_MAGIC:
                fid, full_path = self.get_server_path(requested_file)
            else:
                # Only TESTS_MAGIC is left after the membership check above.
                fid, full_path = self.get_test_path(requested_file)

            # Remember the file so it gets deleted when the server stops.
            self.tmpfiles.append(full_path)

            resp_parms = imp_smb.SMBNtCreateAndXResponse_Parameters()
            resp_data = ""

            # Allocate a fid larger than any currently open one. Dict key
            # views cannot be indexed on Python 3, and the last key is not
            # necessarily the largest, so use max().
            opened_files = conn_data["OpenedFiles"]
            fakefid = max(opened_files) + 1 if opened_files else 1
            resp_parms["Fid"] = fakefid
            resp_parms["CreateAction"] = disposition

            # Inspect the file actually being served, not the share name.
            if os.path.isdir(full_path):
                resp_parms[
                    "FileAttributes"] = imp_smb.SMB_FILE_ATTRIBUTE_DIRECTORY
                resp_parms["IsDirectory"] = 1
            else:
                resp_parms["IsDirectory"] = 0
                resp_parms["FileAttributes"] = ncax_parms["FileAttributes"]

            # Get this file's information
            resp_info, error_code = imp_smbserver.queryPathInformation(
                os.path.dirname(full_path), os.path.basename(full_path),
                level=imp_smb.SMB_QUERY_FILE_ALL_INFO)

            if error_code != STATUS_SUCCESS:
                raise SmbException(error_code, "Failed to query path info")

            resp_parms["CreateTime"] = resp_info["CreationTime"]
            resp_parms["LastAccessTime"] = resp_info["LastAccessTime"]
            resp_parms["LastWriteTime"] = resp_info["LastWriteTime"]
            resp_parms["LastChangeTime"] = resp_info["LastChangeTime"]
            resp_parms["FileAttributes"] = resp_info["ExtFileAttributes"]
            resp_parms["AllocationSize"] = resp_info["AllocationSize"]
            resp_parms["EndOfFile"] = resp_info["EndOfFile"]

            # Let's store the fid for the connection. FileName records the
            # magic share path; get_share_path() returns it for requests
            # that use this fid as their RootFid.
            conn_data["OpenedFiles"][fakefid] = {
                "FileHandle": fid,
                "FileName": path,
                "DeleteOnClose": False,
            }

        except SmbException as s:
            log.debug("[SMB] SmbException hit: %s", s)
            error_code = s.error_code
            resp_parms = ""
            resp_data = ""

        resp_cmd = imp_smb.SMBCommand(imp_smb.SMB.SMB_COM_NT_CREATE_ANDX)
        resp_cmd["Parameters"] = resp_parms
        resp_cmd["Data"] = resp_data
        smb_server.setConnectionData(conn_id, conn_data)

        return [resp_cmd], None, error_code

    def get_share_path(self, conn_data, root_fid, tid):
        """Return the path a create request is relative to.

        This is the FileName of the already-open *root_fid* when it is
        non-zero, otherwise the path of the share connected as *tid*.

        :raises SmbException: STATUS_SMB_BAD_TID for an unknown tid, or
            STATUS_ACCESS_DENIED when the connected share has no path.
        """
        conn_shares = conn_data["ConnectedShares"]

        if tid not in conn_shares:
            raise SmbException(imp_smbserver.STATUS_SMB_BAD_TID,
                               "TID was invalid")

        if root_fid > 0:
            # If we have a rootFid, the path is relative to that fid
            path = conn_data["OpenedFiles"][root_fid]["FileName"]
            log.debug("RootFid present %s!", path)
            return path

        if "path" not in conn_shares[tid]:
            raise SmbException(STATUS_ACCESS_DENIED,
                               "Connection share had no path")
        return conn_shares[tid]["path"]

    def get_server_path(self, requested_filename):
        """Create a temporary file answering a SERVER share request.

        Only VERIFIED_REQ is served. Its contents are VERIFIED_RSP filled with
        this process's pid, offset by 65536 on Windows.

        :returns: (fd, path) of the temporary file, rewound to the start.
        :raises SmbException: STATUS_NO_SUCH_FILE for any other file name.
        """
        log.debug("[SMB] Get server path '%s'", requested_filename)

        if requested_filename != VERIFIED_REQ:
            raise SmbException(STATUS_NO_SUCH_FILE, "Couldn't find the file")

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing '%s'",
                  filename, fid, requested_filename)

        log.debug("[SMB] Verifying server is alive")
        pid = os.getpid()
        # see tests/server/util.c function write_pidfile
        if os.name == "nt":
            pid += 65536
        contents = VERIFIED_RSP.format(pid=pid).encode('utf-8')

        try:
            self.write_to_fid(fid, contents)
        except Exception:
            # The file will not be served or tracked; don't leak it.
            self._discard_temp_file(fid, filename)
            raise
        return fid, filename

    def write_to_fid(self, fid, contents):
        """Write the bytes *contents* to file descriptor *fid*, fsync it and
        rewind so that a following read returns the data from the start."""
        # os.write() may write fewer bytes than requested; loop until done.
        offset = 0
        while offset < len(contents):
            written = os.write(fid, contents[offset:])
            if written <= 0:
                raise OSError("os.write() made no progress")
            offset += written
        os.fsync(fid)

        # Rewind the file to the beginning so a read gets us the contents
        os.lseek(fid, 0, os.SEEK_SET)

    def get_test_path(self, requested_filename):
        """Create a temporary file holding the test data for a TESTS share
        request.

        :returns: (fd, path) of the temporary file, rewound to the start.
        :raises SmbException: STATUS_NO_SUCH_FILE if the test data cannot be
            produced or written.
        """
        log.info("[SMB] Get reply data from 'test%s'", requested_filename)

        fid, filename = tempfile.mkstemp()
        log.debug("[SMB] Created %s (%d) for storing test '%s'",
                  filename, fid, requested_filename)

        try:
            contents = self.ctd.get_test_data(requested_filename).encode('utf-8')
            self.write_to_fid(fid, contents)
        except Exception:
            log.exception("Failed to make test file")
            # The file will not be served or tracked in self.tmpfiles, so
            # close and remove it here instead of leaking it.
            self._discard_temp_file(fid, filename)
            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")

        return fid, filename

    @staticmethod
    def _discard_temp_file(fid, filename):
        """Best-effort close and delete of a temporary file that won't be
        served."""
        try:
            os.close(fid)
        except OSError:
            log.debug("[SMB] Could not close fd %d", fid)
        try:
            os.unlink(filename)
        except OSError:
            log.warning("[SMB] Could not delete temporary file %s", filename)
357            self.write_to_fid(fid, contents)
358            return fid, filename
359
360        except Exception:
361            log.exception("Failed to make test file")
362            raise SmbException(STATUS_NO_SUCH_FILE, "Failed to make test file")
363
364
class SmbException(Exception):
    """Error that carries the NT status code to return to the SMB client.

    :param error_code: status value reported back to the client.
    :param error_message: human-readable description (the exception text).
    """

    def __init__(self, error_code, error_message):
        Exception.__init__(self, error_message)
        self.error_code = error_code
369
370
class ScriptRC(object):
    """Enum for script return codes (plain ints, passed to sys.exit)."""
    SUCCESS, FAILURE, EXCEPTION = 0, 1, 2
376
377
class ScriptException(Exception):
    """Fatal script-level error, e.g. a missing or invalid --srcdir."""
380
381
def get_options(argv=None):
    """Parse the command line options.

    :param argv: list of arguments to parse; when None, argparse reads
        sys.argv[1:] (the previous, and still default, behaviour).
    :returns: an argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument("--port", action="store", default=9017,
                        type=int, help="port to listen on")
    parser.add_argument("--host", action="store", default="127.0.0.1",
                        help="host to listen on")
    parser.add_argument("--verbose", action="store", type=int, default=0,
                        help="verbose output")
    parser.add_argument("--pidfile", action="store",
                        help="file name for the PID")
    parser.add_argument("--logfile", action="store",
                        help="file name for the log")
    parser.add_argument("--srcdir", action="store", help="test directory")
    # --id and --ipv4 are accepted but not read elsewhere in this file.
    parser.add_argument("--id", action="store", help="server ID")
    parser.add_argument("--ipv4", action="store_true", default=False,
                        help="IPv4 flag")

    return parser.parse_args(argv)
401
402
def setup_logging(options):
    """
    Configure the root logger from the command line options.

    Records go to options.logfile when one is given. They also go to stdout
    when no logfile is given or when options.verbose is set. Verbose mode
    sets the root level to DEBUG; otherwise it is INFO.
    """
    root_logger = logging.getLogger()
    formatter = logging.Formatter("%(asctime)s %(levelname)-5.5s %(message)s")

    def attach(handler):
        # Handlers pass everything through; the root level does the filtering.
        handler.setFormatter(formatter)
        handler.setLevel(logging.DEBUG)
        root_logger.addHandler(handler)

    # Write out to a logfile
    if options.logfile:
        attach(ClosingFileHandler(options.logfile))

    root_logger.setLevel(logging.DEBUG if options.verbose else logging.INFO)

    # stdout gets a handler when there is no logfile, or in verbose mode
    if options.verbose or not options.logfile:
        attach(logging.StreamHandler(sys.stdout))
434
435
if __name__ == '__main__':
    # Parse the command line and configure logging before anything else.
    options = get_options()
    setup_logging(options)

    # Run the server; an unexpected error is logged and turned into the
    # ScriptRC.EXCEPTION exit code.
    try:
        rc = smbserver(options)
    except Exception as e:
        log.exception(e)
        rc = ScriptRC.EXCEPTION

    # Remove the pid file written by smbserver(), if there is one.
    pidfile = options.pidfile
    if pidfile and os.path.isfile(pidfile):
        os.unlink(pidfile)

    log.info("[SMB] Returning %d", rc)
    sys.exit(rc)
455