#!/usr/bin/python """ Client for file transfer services offered by RSS (Remote Shell Server). @author: Michael Goldish (mgoldish@redhat.com) @copyright: 2008-2010 Red Hat Inc. """ import socket, struct, time, sys, os, glob # Globals CHUNKSIZE = 65536 # Protocol message constants RSS_MAGIC = 0x525353 RSS_OK = 1 RSS_ERROR = 2 RSS_UPLOAD = 3 RSS_DOWNLOAD = 4 RSS_SET_PATH = 5 RSS_CREATE_FILE = 6 RSS_CREATE_DIR = 7 RSS_LEAVE_DIR = 8 RSS_DONE = 9 # See rss.cpp for protocol details. class FileTransferError(Exception): def __init__(self, msg, e=None, filename=None): Exception.__init__(self, msg, e, filename) self.msg = msg self.e = e self.filename = filename def __str__(self): s = self.msg if self.e and self.filename: s += " (error: %s, filename: %s)" % (self.e, self.filename) elif self.e: s += " (%s)" % self.e elif self.filename: s += " (filename: %s)" % self.filename return s class FileTransferConnectError(FileTransferError): pass class FileTransferTimeoutError(FileTransferError): pass class FileTransferProtocolError(FileTransferError): pass class FileTransferSocketError(FileTransferError): pass class FileTransferServerError(FileTransferError): def __init__(self, errmsg): FileTransferError.__init__(self, None, errmsg) def __str__(self): s = "Server said: %r" % self.e if self.filename: s += " (filename: %s)" % self.filename return s class FileTransferNotFoundError(FileTransferError): pass class FileTransferClient(object): """ Connect to a RSS (remote shell server) and transfer files. """ def __init__(self, address, port, log_func=None, timeout=20): """ Connect to a server. @param address: The server's address @param port: The server's port @param log_func: If provided, transfer stats will be passed to this function during the transfer @param timeout: Time duration to wait for connection to succeed @raise FileTransferConnectError: Raised if the connection fails """ self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self._socket.settimeout(timeout) try: self._socket.connect((address, port)) except socket.error, e: raise FileTransferConnectError("Cannot connect to server at " "%s:%s" % (address, port), e) try: if self._receive_msg(timeout) != RSS_MAGIC: raise FileTransferConnectError("Received wrong magic number") except FileTransferTimeoutError: raise FileTransferConnectError("Timeout expired while waiting to " "receive magic number") self._send(struct.pack("=i", CHUNKSIZE)) self._log_func = log_func self._last_time = time.time() self._last_transferred = 0 self.transferred = 0 def __del__(self): self.close() def close(self): """ Close the connection. """ self._socket.close() def _send(self, str, timeout=60): try: if timeout <= 0: raise socket.timeout self._socket.settimeout(timeout) self._socket.sendall(str) except socket.timeout: raise FileTransferTimeoutError("Timeout expired while sending " "data to server") except socket.error, e: raise FileTransferSocketError("Could not send data to server", e) def _receive(self, size, timeout=60): strs = [] end_time = time.time() + timeout try: while size > 0: timeout = end_time - time.time() if timeout <= 0: raise socket.timeout self._socket.settimeout(timeout) data = self._socket.recv(size) if not data: raise FileTransferProtocolError("Connection closed " "unexpectedly while " "receiving data from " "server") strs.append(data) size -= len(data) except socket.timeout: raise FileTransferTimeoutError("Timeout expired while receiving " "data from server") except socket.error, e: raise FileTransferSocketError("Error receiving data from server", e) return "".join(strs) def _report_stats(self, str): if self._log_func: dt = time.time() - self._last_time if dt >= 1: transferred = self.transferred / 1048576. speed = (self.transferred - self._last_transferred) / dt speed /= 1048576. self._log_func("%s %.3f MB (%.3f MB/sec)" % (str, transferred, speed)) self._last_time = time.time() self._last_transferred = self.transferred def _send_packet(self, str, timeout=60): self._send(struct.pack("=I", len(str))) self._send(str, timeout) self.transferred += len(str) + 4 self._report_stats("Sent") def _receive_packet(self, timeout=60): size = struct.unpack("=I", self._receive(4))[0] str = self._receive(size, timeout) self.transferred += len(str) + 4 self._report_stats("Received") return str def _send_file_chunks(self, filename, timeout=60): if self._log_func: self._log_func("Sending file %s" % filename) f = open(filename, "rb") try: try: end_time = time.time() + timeout while True: data = f.read(CHUNKSIZE) self._send_packet(data, end_time - time.time()) if len(data) < CHUNKSIZE: break except FileTransferError, e: e.filename = filename raise finally: f.close() def _receive_file_chunks(self, filename, timeout=60): if self._log_func: self._log_func("Receiving file %s" % filename) f = open(filename, "wb") try: try: end_time = time.time() + timeout while True: data = self._receive_packet(end_time - time.time()) f.write(data) if len(data) < CHUNKSIZE: break except FileTransferError, e: e.filename = filename raise finally: f.close() def _send_msg(self, msg, timeout=60): self._send(struct.pack("=I", msg)) def _receive_msg(self, timeout=60): s = self._receive(4, timeout) return struct.unpack("=I", s)[0] def _handle_transfer_error(self): # Save original exception e = sys.exc_info() try: # See if we can get an error message msg = self._receive_msg() except FileTransferError: # No error message -- re-raise original exception raise e[0], e[1], e[2] if msg == RSS_ERROR: errmsg = self._receive_packet() raise FileTransferServerError(errmsg) raise e[0], e[1], e[2] class FileUploadClient(FileTransferClient): """ Connect to a RSS (remote shell server) and upload files or directory trees. """ def __init__(self, address, port, log_func=None, timeout=20): """ Connect to a server. @param address: The server's address @param port: The server's port @param log_func: If provided, transfer stats will be passed to this function during the transfer @param timeout: Time duration to wait for connection to succeed @raise FileTransferConnectError: Raised if the connection fails @raise FileTransferProtocolError: Raised if an incorrect magic number is received @raise FileTransferSocketError: Raised if the RSS_UPLOAD message cannot be sent to the server """ super(FileUploadClient, self).__init__(address, port, log_func, timeout) self._send_msg(RSS_UPLOAD) def _upload_file(self, path, end_time): if os.path.isfile(path): self._send_msg(RSS_CREATE_FILE) self._send_packet(os.path.basename(path)) self._send_file_chunks(path, end_time - time.time()) elif os.path.isdir(path): self._send_msg(RSS_CREATE_DIR) self._send_packet(os.path.basename(path)) for filename in os.listdir(path): self._upload_file(os.path.join(path, filename), end_time) self._send_msg(RSS_LEAVE_DIR) def upload(self, src_pattern, dst_path, timeout=600): """ Send files or directory trees to the server. The semantics of src_pattern and dst_path are similar to those of scp. For example, the following are OK: src_pattern='/tmp/foo.txt', dst_path='C:\\' (uploads a single file) src_pattern='/usr/', dst_path='C:\\Windows\\' (uploads a directory tree recursively) src_pattern='/usr/*', dst_path='C:\\Windows\\' (uploads all files and directory trees under /usr/) The following is not OK: src_pattern='/tmp/foo.txt', dst_path='C:\\Windows\\*' (wildcards are only allowed in src_pattern) @param src_pattern: A path or wildcard pattern specifying the files or directories to send to the server @param dst_path: A path in the server's filesystem where the files will be saved @param timeout: Time duration in seconds to wait for the transfer to complete @raise FileTransferTimeoutError: Raised if timeout expires @raise FileTransferServerError: Raised if something goes wrong and the server sends an informative error message to the client @note: Other exceptions can be raised. """ end_time = time.time() + timeout try: try: self._send_msg(RSS_SET_PATH) self._send_packet(dst_path) matches = glob.glob(src_pattern) for filename in matches: self._upload_file(os.path.abspath(filename), end_time) self._send_msg(RSS_DONE) except FileTransferTimeoutError: raise except FileTransferError: self._handle_transfer_error() else: # If nothing was transferred, raise an exception if not matches: raise FileTransferNotFoundError("Pattern %s does not " "match any files or " "directories" % src_pattern) # Look for RSS_OK or RSS_ERROR msg = self._receive_msg(end_time - time.time()) if msg == RSS_OK: return elif msg == RSS_ERROR: errmsg = self._receive_packet() raise FileTransferServerError(errmsg) else: # Neither RSS_OK nor RSS_ERROR found raise FileTransferProtocolError("Received unexpected msg") except: # In any case, if the transfer failed, close the connection self.close() raise class FileDownloadClient(FileTransferClient): """ Connect to a RSS (remote shell server) and download files or directory trees. """ def __init__(self, address, port, log_func=None, timeout=20): """ Connect to a server. @param address: The server's address @param port: The server's port @param log_func: If provided, transfer stats will be passed to this function during the transfer @param timeout: Time duration to wait for connection to succeed @raise FileTransferConnectError: Raised if the connection fails @raise FileTransferProtocolError: Raised if an incorrect magic number is received @raise FileTransferSendError: Raised if the RSS_UPLOAD message cannot be sent to the server """ super(FileDownloadClient, self).__init__(address, port, log_func, timeout) self._send_msg(RSS_DOWNLOAD) def download(self, src_pattern, dst_path, timeout=600): """ Receive files or directory trees from the server. The semantics of src_pattern and dst_path are similar to those of scp. For example, the following are OK: src_pattern='C:\\foo.txt', dst_path='/tmp' (downloads a single file) src_pattern='C:\\Windows', dst_path='/tmp' (downloads a directory tree recursively) src_pattern='C:\\Windows\\*', dst_path='/tmp' (downloads all files and directory trees under C:\\Windows) The following is not OK: src_pattern='C:\\Windows', dst_path='/tmp/*' (wildcards are only allowed in src_pattern) @param src_pattern: A path or wildcard pattern specifying the files or directories, in the server's filesystem, that will be sent to the client @param dst_path: A path in the local filesystem where the files will be saved @param timeout: Time duration in seconds to wait for the transfer to complete @raise FileTransferTimeoutError: Raised if timeout expires @raise FileTransferServerError: Raised if something goes wrong and the server sends an informative error message to the client @note: Other exceptions can be raised. """ dst_path = os.path.abspath(dst_path) end_time = time.time() + timeout file_count = 0 dir_count = 0 try: try: self._send_msg(RSS_SET_PATH) self._send_packet(src_pattern) except FileTransferError: self._handle_transfer_error() while True: msg = self._receive_msg() if msg == RSS_CREATE_FILE: # Receive filename and file contents filename = self._receive_packet() if os.path.isdir(dst_path): dst_path = os.path.join(dst_path, filename) self._receive_file_chunks(dst_path, end_time - time.time()) dst_path = os.path.dirname(dst_path) file_count += 1 elif msg == RSS_CREATE_DIR: # Receive dirname and create the directory dirname = self._receive_packet() if os.path.isdir(dst_path): dst_path = os.path.join(dst_path, dirname) if not os.path.isdir(dst_path): os.mkdir(dst_path) dir_count += 1 elif msg == RSS_LEAVE_DIR: # Return to parent dir dst_path = os.path.dirname(dst_path) elif msg == RSS_DONE: # Transfer complete if not file_count and not dir_count: raise FileTransferNotFoundError("Pattern %s does not " "match any files or " "directories that " "could be downloaded" % src_pattern) break elif msg == RSS_ERROR: # Receive error message and abort errmsg = self._receive_packet() raise FileTransferServerError(errmsg) else: # Unexpected msg raise FileTransferProtocolError("Received unexpected msg") except: # In any case, if the transfer failed, close the connection self.close() raise def upload(address, port, src_pattern, dst_path, log_func=None, timeout=60, connect_timeout=20): """ Connect to server and upload files. @see: FileUploadClient """ client = FileUploadClient(address, port, log_func, connect_timeout) client.upload(src_pattern, dst_path, timeout) client.close() def download(address, port, src_pattern, dst_path, log_func=None, timeout=60, connect_timeout=20): """ Connect to server and upload files. @see: FileDownloadClient """ client = FileDownloadClient(address, port, log_func, connect_timeout) client.download(src_pattern, dst_path, timeout) client.close() def main(): import optparse usage = "usage: %prog [options] address port src_pattern dst_path" parser = optparse.OptionParser(usage=usage) parser.add_option("-d", "--download", action="store_true", dest="download", help="download files from server") parser.add_option("-u", "--upload", action="store_true", dest="upload", help="upload files to server") parser.add_option("-v", "--verbose", action="store_true", dest="verbose", help="be verbose") parser.add_option("-t", "--timeout", type="int", dest="timeout", default=3600, help="transfer timeout") options, args = parser.parse_args() if options.download == options.upload: parser.error("you must specify either -d or -u") if len(args) != 4: parser.error("incorrect number of arguments") address, port, src_pattern, dst_path = args port = int(port) logger = None if options.verbose: def p(s): print s logger = p if options.download: download(address, port, src_pattern, dst_path, logger, options.timeout) elif options.upload: upload(address, port, src_pattern, dst_path, logger, options.timeout) if __name__ == "__main__": main()