#!/usr/bin/python
"""
Client for file transfer services offered by RSS (Remote Shell Server).

@author: Michael Goldish (mgoldish@redhat.com)
@copyright: 2008-2010 Red Hat Inc.
"""

import socket, struct, time, sys, os, glob

# Globals
# Size of a file data chunk; also announced to the server after the handshake
CHUNKSIZE = 65536

# Protocol message constants
RSS_MAGIC = 0x525353
RSS_OK = 1
RSS_ERROR = 2
RSS_UPLOAD = 3
RSS_DOWNLOAD = 4
RSS_SET_PATH = 5
RSS_CREATE_FILE = 6
RSS_CREATE_DIR = 7
RSS_LEAVE_DIR = 8
RSS_DONE = 9

# See rss.cpp for protocol details.


class FileTransferError(Exception):
    """
    Base class of all errors raised by this module.

    @param msg: Human readable description of the failure
    @param e: Optional underlying error (e.g. a socket.error instance)
    @param filename: Optional name of the file being transferred
    """

    def __init__(self, msg, e=None, filename=None):
        Exception.__init__(self, msg, e, filename)
        self.msg = msg
        self.e = e
        self.filename = filename

    def __str__(self):
        s = self.msg
        if self.e and self.filename:
            s += "    (error: %s,    filename: %s)" % (self.e, self.filename)
        elif self.e:
            s += "    (%s)" % self.e
        elif self.filename:
            s += "    (filename: %s)" % self.filename
        return s


class FileTransferConnectError(FileTransferError):
    """Raised when connecting to, or shaking hands with, the server fails."""
    pass


class FileTransferTimeoutError(FileTransferError):
    """Raised when a send or receive does not complete in time."""
    pass


class FileTransferProtocolError(FileTransferError):
    """Raised when the peer closes the connection or sends unexpected data."""
    pass


class FileTransferSocketError(FileTransferError):
    """Raised when the socket layer reports an error."""
    pass


class FileTransferServerError(FileTransferError):
    """
    Raised when the server reports an error (RSS_ERROR followed by a message).

    @param errmsg: The error message received from the server; stored in
            the 'e' attribute
    """

    def __init__(self, errmsg):
        FileTransferError.__init__(self, None, errmsg)

    def __str__(self):
        s = "Server said: %r" % self.e
        if self.filename:
            s += "    (filename: %s)" % self.filename
        return s


class FileTransferNotFoundError(FileTransferError):
    """Raised when a pattern matches no files or directories."""
    pass


class FileTransferClient(object):
    """
    Connect to a RSS (remote shell server) and transfer files.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server.

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails
        """
        self._socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        try:
            self._socket.settimeout(timeout)
            try:
                self._socket.connect((address, port))
            except socket.error as e:
                raise FileTransferConnectError("Cannot connect to server at "
                                               "%s:%s" % (address, port), e)
            try:
                if self._receive_msg(timeout) != RSS_MAGIC:
                    raise FileTransferConnectError("Received wrong magic "
                                                   "number")
            except FileTransferTimeoutError:
                raise FileTransferConnectError("Timeout expired while "
                                               "waiting to receive magic "
                                               "number")
            self._send(struct.pack("=i", CHUNKSIZE))
        except:
            # Do not leave the socket open if the handshake failed
            self._socket.close()
            raise
        self._log_func = log_func
        self._last_time = time.time()
        self._last_transferred = 0
        self.transferred = 0


    def __del__(self):
        self.close()


    def close(self):
        """
        Close the connection.
        """
        # _socket is missing if socket creation itself failed in __init__
        sock = getattr(self, "_socket", None)
        if sock is not None:
            sock.close()


    def _send(self, str, timeout=60):
        """
        Send raw data to the server.

        @param str: The data to send
        @param timeout: Time duration in seconds allowed for the send
        @raise FileTransferTimeoutError: Raised if timeout is not positive or
                expires
        @raise FileTransferSocketError: Raised on any other socket error
        """
        try:
            if timeout <= 0:
                raise socket.timeout
            self._socket.settimeout(timeout)
            self._socket.sendall(str)
        except socket.timeout:
            raise FileTransferTimeoutError("Timeout expired while sending "
                                           "data to server")
        except socket.error as e:
            raise FileTransferSocketError("Could not send data to server", e)


    def _receive(self, size, timeout=60):
        """
        Receive exactly size bytes from the server.

        @param size: Number of bytes to receive
        @param timeout: Time duration in seconds allowed for receiving all
                of the data
        @return: The received data
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferProtocolError: Raised if the connection is closed
                before size bytes arrive
        @raise FileTransferSocketError: Raised on any other socket error
        """
        strs = []
        end_time = time.time() + timeout
        try:
            while size > 0:
                timeout = end_time - time.time()
                if timeout <= 0:
                    raise socket.timeout
                self._socket.settimeout(timeout)
                # size may come straight from a packet header sent by the
                # peer, so never ask recv() for more than one chunk at a time
                data = self._socket.recv(min(size, CHUNKSIZE))
                if not data:
                    raise FileTransferProtocolError("Connection closed "
                                                    "unexpectedly while "
                                                    "receiving data from "
                                                    "server")
                strs.append(data)
                size -= len(data)
        except socket.timeout:
            raise FileTransferTimeoutError("Timeout expired while receiving "
                                           "data from server")
        except socket.error as e:
            raise FileTransferSocketError("Error receiving data from server",
                                          e)
        return b"".join(strs)


    def _report_stats(self, str):
        """
        Pass transfer statistics to the log function, at most once a second.

        @param str: Prefix of the log line (e.g. "Sent" or "Received")
        """
        if self._log_func:
            dt = time.time() - self._last_time
            if dt >= 1:
                transferred = self.transferred / 1048576.
                speed = (self.transferred - self._last_transferred) / dt
                speed /= 1048576.
                self._log_func("%s %.3f MB (%.3f MB/sec)" %
                               (str, transferred, speed))
                self._last_time = time.time()
                self._last_transferred = self.transferred


    def _send_packet(self, str, timeout=60):
        """
        Send a packet: a 4-byte length header followed by the payload.

        @param str: The payload
        @param timeout: Time duration in seconds allowed for sending both
                the header and the payload
        """
        end_time = time.time() + timeout
        self._send(struct.pack("=I", len(str)), timeout)
        self._send(str, end_time - time.time())
        self.transferred += len(str) + 4
        self._report_stats("Sent")


    def _receive_packet(self, timeout=60):
        """
        Receive a packet: a 4-byte length header followed by the payload.

        @param timeout: Time duration in seconds allowed for receiving both
                the header and the payload
        @return: The payload
        """
        end_time = time.time() + timeout
        size = struct.unpack("=I", self._receive(4, timeout))[0]
        str = self._receive(size, end_time - time.time())
        self.transferred += len(str) + 4
        self._report_stats("Received")
        return str


    def _send_file_chunks(self, filename, timeout=60):
        """
        Send the contents of a local file as a sequence of packets.

        The file is sent in packets of CHUNKSIZE bytes; a packet shorter
        than CHUNKSIZE (possibly empty) marks the end of the file.

        @param filename: Path of the local file to send
        @param timeout: Time duration in seconds allowed for the whole file
        """
        if self._log_func:
            self._log_func("Sending file %s" % filename)
        f = open(filename, "rb")
        try:
            end_time = time.time() + timeout
            while True:
                data = f.read(CHUNKSIZE)
                self._send_packet(data, end_time - time.time())
                if len(data) < CHUNKSIZE:
                    break
        except FileTransferError as e:
            # Tell the caller which file the failure belongs to
            e.filename = filename
            raise
        finally:
            f.close()


    def _receive_file_chunks(self, filename, timeout=60):
        """
        Receive a sequence of packets and write them to a local file.

        Receiving stops after the first packet shorter than CHUNKSIZE.

        @param filename: Path of the local file to write
        @param timeout: Time duration in seconds allowed for the whole file
        """
        if self._log_func:
            self._log_func("Receiving file %s" % filename)
        f = open(filename, "wb")
        try:
            end_time = time.time() + timeout
            while True:
                data = self._receive_packet(end_time - time.time())
                f.write(data)
                if len(data) < CHUNKSIZE:
                    break
        except FileTransferError as e:
            # Tell the caller which file the failure belongs to
            e.filename = filename
            raise
        finally:
            f.close()


    def _send_msg(self, msg, timeout=60):
        """
        Send a protocol message (one of the RSS_* constants).

        @param msg: The message to send
        @param timeout: Time duration in seconds allowed for the send
        """
        self._send(struct.pack("=I", msg), timeout)


    def _receive_msg(self, timeout=60):
        """
        Receive a protocol message (a 4-byte unsigned integer).

        @param timeout: Time duration in seconds to wait for the message
        @return: The message received
        """
        s = self._receive(4, timeout)
        return struct.unpack("=I", s)[0]


    def _receive_server_error(self):
        """
        Try to read an error report from the server.

        This is deliberately a separate function from
        _handle_transfer_error(): on Python 2 the exception swallowed here
        would otherwise replace the one being handled by the caller; the
        previous exception state is restored only when this function returns.

        @return: The server's error message, or None if no message could be
                received or the message received is not RSS_ERROR
        """
        try:
            msg = self._receive_msg()
        except FileTransferError:
            return None
        if msg != RSS_ERROR:
            return None
        return self._receive_packet()


    def _handle_transfer_error(self):
        """
        Turn a transfer failure into the most informative exception.

        Must be called from within an except block.  If the server sent an
        error message, FileTransferServerError is raised with that message;
        otherwise the exception currently being handled is re-raised.
        """
        errmsg = self._receive_server_error()
        if errmsg is not None:
            raise FileTransferServerError(errmsg)
        # No error message -- re-raise original exception
        raise


class FileUploadClient(FileTransferClient):
    """
    Connect to a RSS (remote shell server) and upload files or directory trees.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server.

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails
        @raise FileTransferProtocolError: Raised if an incorrect magic number
                is received
        @raise FileTransferSocketError: Raised if the RSS_UPLOAD message cannot
                be sent to the server
        """
        super(FileUploadClient, self).__init__(address, port, log_func, timeout)
        try:
            self._send_msg(RSS_UPLOAD)
        except:
            self.close()
            raise


    def _upload_file(self, path, end_time):
        """
        Upload a single file, or a directory tree recursively.

        Paths that are neither regular files nor directories are skipped.

        @param path: Local path of the file or directory to upload
        @param end_time: Absolute time (as returned by time.time()) by which
                the whole upload must complete
        """
        if os.path.isfile(path):
            self._send_msg(RSS_CREATE_FILE, end_time - time.time())
            self._send_packet(os.path.basename(path), end_time - time.time())
            self._send_file_chunks(path, end_time - time.time())
        elif os.path.isdir(path):
            self._send_msg(RSS_CREATE_DIR, end_time - time.time())
            self._send_packet(os.path.basename(path), end_time - time.time())
            for filename in os.listdir(path):
                self._upload_file(os.path.join(path, filename), end_time)
            self._send_msg(RSS_LEAVE_DIR, end_time - time.time())


    def upload(self, src_pattern, dst_path, timeout=600):
        """
        Send files or directory trees to the server.
        The semantics of src_pattern and dst_path are similar to those of scp.
        For example, the following are OK:
            src_pattern='/tmp/foo.txt', dst_path='C:\\'
                (uploads a single file)
            src_pattern='/usr/', dst_path='C:\\Windows\\'
                (uploads a directory tree recursively)
            src_pattern='/usr/*', dst_path='C:\\Windows\\'
                (uploads all files and directory trees under /usr/)
        The following is not OK:
            src_pattern='/tmp/foo.txt', dst_path='C:\\Windows\\*'
                (wildcards are only allowed in src_pattern)

        @param src_pattern: A path or wildcard pattern specifying the files or
                directories to send to the server
        @param dst_path: A path in the server's filesystem where the files will
                be saved
        @param timeout: Time duration in seconds to wait for the transfer to
                complete
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferServerError: Raised if something goes wrong and the
                server sends an informative error message to the client
        @note: Other exceptions can be raised.
        """
        end_time = time.time() + timeout
        try:
            try:
                self._send_msg(RSS_SET_PATH, end_time - time.time())
                self._send_packet(dst_path, end_time - time.time())
                matches = glob.glob(src_pattern)
                for filename in matches:
                    self._upload_file(os.path.abspath(filename), end_time)
                self._send_msg(RSS_DONE, end_time - time.time())
            except FileTransferTimeoutError:
                raise
            except FileTransferError:
                self._handle_transfer_error()
            else:
                # If nothing was transferred, raise an exception
                if not matches:
                    raise FileTransferNotFoundError("Pattern %s does not "
                                                    "match any files or "
                                                    "directories" %
                                                    src_pattern)
                # Look for RSS_OK or RSS_ERROR
                msg = self._receive_msg(end_time - time.time())
                if msg == RSS_OK:
                    return
                elif msg == RSS_ERROR:
                    errmsg = self._receive_packet()
                    raise FileTransferServerError(errmsg)
                else:
                    # Neither RSS_OK nor RSS_ERROR found
                    raise FileTransferProtocolError("Received unexpected msg")
        except:
            # In any case, if the transfer failed, close the connection
            self.close()
            raise


class FileDownloadClient(FileTransferClient):
    """
    Connect to a RSS (remote shell server) and download files or directory trees.
    """

    def __init__(self, address, port, log_func=None, timeout=20):
        """
        Connect to a server.

        @param address: The server's address
        @param port: The server's port
        @param log_func: If provided, transfer stats will be passed to this
                function during the transfer
        @param timeout: Time duration to wait for connection to succeed
        @raise FileTransferConnectError: Raised if the connection fails
        @raise FileTransferProtocolError: Raised if an incorrect magic number
                is received
        @raise FileTransferSocketError: Raised if the RSS_DOWNLOAD message
                cannot be sent to the server
        """
        super(FileDownloadClient, self).__init__(address, port, log_func, timeout)
        try:
            self._send_msg(RSS_DOWNLOAD)
        except:
            self.close()
            raise


    def download(self, src_pattern, dst_path, timeout=600):
        """
        Receive files or directory trees from the server.
        The semantics of src_pattern and dst_path are similar to those of scp.
        For example, the following are OK:
            src_pattern='C:\\foo.txt', dst_path='/tmp'
                (downloads a single file)
            src_pattern='C:\\Windows', dst_path='/tmp'
                (downloads a directory tree recursively)
            src_pattern='C:\\Windows\\*', dst_path='/tmp'
                (downloads all files and directory trees under C:\\Windows)
        The following is not OK:
            src_pattern='C:\\Windows', dst_path='/tmp/*'
                (wildcards are only allowed in src_pattern)

        @param src_pattern: A path or wildcard pattern specifying the files or
                directories, in the server's filesystem, that will be sent to
                the client
        @param dst_path: A path in the local filesystem where the files will
                be saved
        @param timeout: Time duration in seconds to wait for the transfer to
                complete
        @raise FileTransferTimeoutError: Raised if timeout expires
        @raise FileTransferServerError: Raised if something goes wrong and the
                server sends an informative error message to the client
        @note: Other exceptions can be raised.
        """
        dst_path = os.path.abspath(dst_path)
        end_time = time.time() + timeout
        file_count = 0
        dir_count = 0
        try:
            try:
                self._send_msg(RSS_SET_PATH, end_time - time.time())
                self._send_packet(src_pattern, end_time - time.time())
            except FileTransferError:
                self._handle_transfer_error()
            while True:
                msg = self._receive_msg(end_time - time.time())
                if msg == RSS_CREATE_FILE:
                    # Receive filename and file contents
                    # NOTE(review): the name comes from the server and is
                    # joined to dst_path unchecked; a name containing path
                    # separators or '..' would be written outside dst_path.
                    # Consider validating it.
                    filename = self._receive_packet(end_time - time.time())
                    if os.path.isdir(dst_path):
                        dst_path = os.path.join(dst_path, filename)
                    self._receive_file_chunks(dst_path, end_time - time.time())
                    dst_path = os.path.dirname(dst_path)
                    file_count += 1
                elif msg == RSS_CREATE_DIR:
                    # Receive dirname and create the directory
                    # NOTE(review): same unchecked join as for file names
                    dirname = self._receive_packet(end_time - time.time())
                    if os.path.isdir(dst_path):
                        dst_path = os.path.join(dst_path, dirname)
                    if not os.path.isdir(dst_path):
                        os.mkdir(dst_path)
                    dir_count += 1
                elif msg == RSS_LEAVE_DIR:
                    # Return to parent dir
                    dst_path = os.path.dirname(dst_path)
                elif msg == RSS_DONE:
                    # Transfer complete
                    if not file_count and not dir_count:
                        raise FileTransferNotFoundError("Pattern %s does not "
                                                        "match any files or "
                                                        "directories that "
                                                        "could be downloaded" %
                                                        src_pattern)
                    break
                elif msg == RSS_ERROR:
                    # Receive error message and abort
                    errmsg = self._receive_packet(end_time - time.time())
                    raise FileTransferServerError(errmsg)
                else:
                    # Unexpected msg
                    raise FileTransferProtocolError("Received unexpected msg")
        except:
            # In any case, if the transfer failed, close the connection
            self.close()
            raise


def upload(address, port, src_pattern, dst_path, log_func=None, timeout=60,
           connect_timeout=20):
    """
    Connect to server and upload files.

    @see: FileUploadClient
    """
    client = FileUploadClient(address, port, log_func, connect_timeout)
    try:
        client.upload(src_pattern, dst_path, timeout)
    finally:
        client.close()


def download(address, port, src_pattern, dst_path, log_func=None, timeout=60,
             connect_timeout=20):
    """
    Connect to server and download files.

    @see: FileDownloadClient
    """
    client = FileDownloadClient(address, port, log_func, connect_timeout)
    try:
        client.download(src_pattern, dst_path, timeout)
    finally:
        client.close()


def main():
    """
    Command line entry point: upload or download according to -u / -d.
    """
    import optparse

    usage = "usage: %prog [options] address port src_pattern dst_path"
    parser = optparse.OptionParser(usage=usage)
    parser.add_option("-d", "--download",
                      action="store_true", dest="download",
                      help="download files from server")
    parser.add_option("-u", "--upload",
                      action="store_true", dest="upload",
                      help="upload files to server")
    parser.add_option("-v", "--verbose",
                      action="store_true", dest="verbose",
                      help="be verbose")
    parser.add_option("-t", "--timeout",
                      type="int", dest="timeout", default=3600,
                      help="transfer timeout")
    options, args = parser.parse_args()
    # Exactly one of -d and -u must be given
    if options.download == options.upload:
        parser.error("you must specify either -d or -u")
    if len(args) != 4:
        parser.error("incorrect number of arguments")
    address, port, src_pattern, dst_path = args
    try:
        port = int(port)
    except ValueError:
        parser.error("port must be an integer")

    logger = None
    if options.verbose:
        def p(s):
            print(s)
        logger = p

    if options.download:
        download(address, port, src_pattern, dst_path, logger, options.timeout)
    elif options.upload:
        upload(address, port, src_pattern, dst_path, logger, options.timeout)


if __name__ == "__main__":
    main()