• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 """RPC Implementation, originally written for the Python Idle IDE
2 
3 For security reasons, GvR requested that Idle's Python execution server process
4 connect to the Idle process, which listens for the connection.  Since Idle has
5 only one client per server, this was not a limitation.
6 
7    +---------------------------------+ +-------------+
8    | socketserver.BaseRequestHandler | | SocketIO    |
9    +---------------------------------+ +-------------+
10                    ^                   | register()  |
11                    |                   | unregister()|
12                    |                   +-------------+
13                    |                      ^  ^
14                    |                      |  |
15                    | + -------------------+  |
16                    | |                       |
17    +-------------------------+        +-----------------+
18    | RPCHandler              |        | RPCClient       |
19    | [attribute of RPCServer]|        |                 |
20    +-------------------------+        +-----------------+
21 
22 The RPCServer handler class is expected to provide register/unregister methods.
23 RPCHandler inherits the mix-in class SocketIO, which provides these methods.
24 
25 See the Idle run.main() docstring for further information on how this was
26 accomplished in Idle.
27 
28 """
29 import builtins
30 import copyreg
31 import io
32 import marshal
33 import os
34 import pickle
35 import queue
36 import select
37 import socket
38 import socketserver
39 import struct
40 import sys
41 import threading
42 import traceback
43 import types
44 
def unpickle_code(ms):
    "Rebuild and return the code object marshalled in bytes ms."
    # Inverse of pickle_code(): pickle calls this with the marshal data.
    code = marshal.loads(ms)
    assert isinstance(code, types.CodeType)
    return code
50 
def pickle_code(co):
    """Reduce code object co for pickling.

    Returns the pair (unpickle_code, (marshal bytes of co,)), so that
    unpickling calls unpickle_code() on the marshalled data.
    """
    assert isinstance(co, types.CodeType)
    return unpickle_code, (marshal.dumps(co),)
56 
def dumps(obj, protocol=None):
    """Return obj serialized to bytes by CodePickler.

    CodePickler additionally handles code objects (via marshal).  IDLE
    passes protocol=None, which selects pickle.DEFAULT_PROTOCOL.
    """
    buf = io.BytesIO()
    CodePickler(buf, protocol).dump(obj)
    return buf.getvalue()
64 
65 
class CodePickler(pickle.Pickler):
    """Pickler that can also serialize code objects (see pickle_code)."""

    # Snapshot of copyreg's reducers taken at class creation, plus one for
    # code objects.  The copyreg entries are added last, so they take
    # precedence on any key clash.
    dispatch_table = {types.CodeType: pickle_code}
    dispatch_table.update(copyreg.dispatch_table)
68 
69 
BUFSIZE = 8192           # max bytes per socket send()/recv() call (8 KiB)
LOCALHOST = '127.0.0.1'  # the only peer address RPCClient.accept() allows
72 
class RPCServer(socketserver.TCPServer):
    """TCPServer that connects out instead of binding and listening.

    The connection is reversed: this "server" connects to the address it
    is given, and that single connected socket is the only request.
    """

    def __init__(self, addr, handlerclass=None):
        # RPCHandler is the default request handler class.
        socketserver.TCPServer.__init__(
            self, addr, RPCHandler if handlerclass is None else handlerclass)

    def server_bind(self):
        "Override TCPServer method: the connecting side has no bind() phase."
        return None

    def server_activate(self):
        """Override TCPServer method: connect() instead of listen().

        Because the connection is reversed, self.server_address is the
        address of the Idle client we connect to, not a local address.
        """
        self.socket.connect(self.server_address)

    def get_request(self):
        "Override TCPServer method: hand back the already connected socket."
        request = (self.socket, self.server_address)
        return request

    def handle_error(self, request, client_address):
        """Override TCPServer method: report the error, then exit.

        Must be called while an exception is being handled; the bare
        ``raise`` re-raises it.  SystemExit propagates.  Anything else is
        reported on sys.__stderr__ and the process ends via os._exit(0).
        """
        try:
            raise
        except SystemExit:
            raise
        except:
            out = sys.__stderr__
            rule = '-' * 40
            print('\n' + rule, file=out)
            print('Unhandled server exception!', file=out)
            print('Thread: %s' % threading.current_thread().name, file=out)
            print('Client Address: ', client_address, file=out)
            print('Request: ', repr(request), file=out)
            traceback.print_exc(file=out)
            print('\n*** Unrecoverable, server exiting!', file=out)
            print(rule, file=out)
            os._exit(0)
120 
121 #----------------- end class RPCServer --------------------
122 
# Module-wide registry oid -> object.  SocketIO uses it as its default
# objtable, and remoteref() registers objects in it.
objecttable = dict()
# Unbounded queues (maxsize 0) for 'QUEUE' requests.
# localcall() puts (seq, (method, args, kwargs)) items on request_queue.
# pollresponse() sends back (seq, response) items taken from response_queue.
request_queue = queue.Queue(maxsize=0)
response_queue = queue.Queue(maxsize=0)
126 
127 
class SocketIO:
    """RPC endpoint over a connected stream socket.

    Wire format: each message is a ``(seq, payload)`` tuple, pickled with
    dumps() and prefixed with its length as a 4-byte little-endian int.
    Request payloads are ``(how, (oid, methodname, args, kwargs))``, with
    how == "CALL" or "QUEUE".  Response payloads are ``(status, value)``
    tuples as produced by localcall().

    The thread that runs __init__ (self.sockthread) is the only one that
    reads the socket.  Any other thread that issues a request gets a
    Condition in self.cvars and waits on it.  It wakes when the socket
    thread stores the matching response in self.responses.

    Subclasses supply ``location`` (the tag used by debug()).  They may
    override ``debugging``, exithook(), EOFhook() and
    decode_interrupthook().
    """

    # Sequence numbers step by 2 (see newseq).  RPCClient starts from 1, so
    # its requests are odd numbered.
    nextseq = 0
    # Class-level default, so debug() works even when neither a subclass
    # nor __init__ sets it.
    debugging = False

    def __init__(self, sock, objtable=None, debugging=None):
        # The constructing thread becomes the socket reader (see
        # _getresponse).
        self.sockthread = threading.current_thread()
        if debugging is not None:
            self.debugging = debugging
        self.sock = sock
        # By default, share the module-wide table that remoteref() fills.
        self.objtable = objecttable if objtable is None else objtable
        self.responses = {}  # seq -> response handed to a waiting thread
        self.cvars = {}      # seq -> Condition the requesting thread waits on

    def close(self):
        """Close the socket (if still open) and drop the reference to it."""
        sock, self.sock = self.sock, None
        if sock is not None:
            sock.close()

    def exithook(self):
        "override for specific exit action"
        os._exit(0)

    def debug(self, *args):
        """Print args to sys.__stderr__ when self.debugging is true.

        The output is prefixed with self.location and the current thread's
        name.
        """
        if not self.debugging:
            return
        words = [self.location, str(threading.current_thread().name)]
        words.extend(str(a) for a in args)
        print(" ".join(words), file=sys.__stderr__)

    def register(self, oid, object):
        """Make object callable by the peer under id oid."""
        self.objtable[oid] = object

    def unregister(self, oid):
        """Forget oid.  Unknown ids are ignored."""
        try:
            del self.objtable[oid]
        except KeyError:
            pass

    def localcall(self, seq, request):
        """Carry out a request received from the peer.

        request is ``(how, (oid, methodname, args, kwargs))``.  Returns one
        of these response tuples:

        * ("OK", result) -- how == "CALL".  A RemoteObject result is
          replaced by a RemoteProxy via remoteref().  The pseudo-methods
          "__methods__" and "__attributes__" return dicts of names.
        * ("QUEUED", None) -- how == "QUEUE".  (seq, (method, args,
          kwargs)) was put on request_queue.
        * ("ERROR", message) -- malformed request, unknown oid or method
          name, or unsupported how.
        * ("CALLEXC", exc) -- the method raised an Exception.
        * ("EXCEPTION", None) -- the method raised some other BaseException.

        SystemExit, KeyboardInterrupt and OSError raised inside the call
        are re-raised.
        """
        self.debug("localcall:", request)
        try:
            how, (oid, methodname, args, kwargs) = request
        except (TypeError, ValueError):
            # TypeError: request or its payload is not iterable.
            # ValueError: one of them unpacks to the wrong number of items.
            return ("ERROR", "Bad request format")
        if oid not in self.objtable:
            return ("ERROR", "Unknown object id: %r" % (oid,))
        obj = self.objtable[oid]
        if methodname == "__methods__":
            methods = {}
            _getmethods(obj, methods)
            return ("OK", methods)
        if methodname == "__attributes__":
            attributes = {}
            _getattributes(obj, attributes)
            return ("OK", attributes)
        if not hasattr(obj, methodname):
            return ("ERROR", "Unsupported method name: %r" % (methodname,))
        method = getattr(obj, methodname)
        try:
            if how == 'CALL':
                ret = method(*args, **kwargs)
                if isinstance(ret, RemoteObject):
                    ret = remoteref(ret)
                return ("OK", ret)
            elif how == 'QUEUE':
                request_queue.put((seq, (method, args, kwargs)))
                return ("QUEUED", None)
            else:
                return ("ERROR", "Unsupported message type: %s" % how)
        except (SystemExit, KeyboardInterrupt, OSError):
            raise
        except Exception as ex:
            return ("CALLEXC", ex)
        except:
            # Other BaseExceptions: report locally.  The peer only learns
            # that the call failed.
            msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
                  " Object: %s \n Method: %s \n Args: %s\n"
            print(msg % (oid, method, args), file=sys.__stderr__)
            traceback.print_exc(file=sys.__stderr__)
            return ("EXCEPTION", None)

    def remotecall(self, oid, methodname, args, kwargs):
        """Call methodname on the peer's object oid and return the result."""
        self.debug("remotecall:asynccall: ", oid, methodname)
        seq = self.asynccall(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def remotequeue(self, oid, methodname, args, kwargs):
        """Queue methodname on the peer's object oid and wait for the
        response."""
        self.debug("remotequeue:asyncqueue: ", oid, methodname)
        seq = self.asyncqueue(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def asynccall(self, oid, methodname, args, kwargs):
        """Send a "CALL" request without waiting.  Return its seq number."""
        return self._asyncsend("CALL", "asynccall",
                               oid, methodname, args, kwargs)

    def asyncqueue(self, oid, methodname, args, kwargs):
        """Send a "QUEUE" request without waiting.  Return its seq number."""
        return self._asyncsend("QUEUE", "asyncqueue",
                               oid, methodname, args, kwargs)

    def _asyncsend(self, how, tag, oid, methodname, args, kwargs):
        """Shared body of asynccall/asyncqueue.  tag labels debug output."""
        request = (how, (oid, methodname, args, kwargs))
        seq = self.newseq()
        if threading.current_thread() != self.sockthread:
            # Register before sending, so the socket thread can hand over
            # the response however soon it arrives.
            self.cvars[seq] = threading.Condition()
        self.debug(("%s:%d:" % (tag, seq)), oid, methodname, args, kwargs)
        self.putmessage((seq, request))
        return seq

    def asyncreturn(self, seq):
        """Wait for the response to seq and decode it (see decoderesponse)."""
        self.debug("asyncreturn:%d:call getresponse(): " % seq)
        response = self.getresponse(seq, wait=0.05)
        self.debug(("asyncreturn:%d:response: " % seq), response)
        return self.decoderesponse(response)

    def decoderesponse(self, response):
        """Turn a (how, what) response into a return value or an exception.

        The mapping is:
        * "OK" -> what.
        * "QUEUED" and "EXCEPTION" -> None.
        * "EOF" -> calls decode_interrupthook(), which raises EOFError
          unless overridden, then returns None.
        * "ERROR" -> raises RuntimeError(what).
        * "CALLEXC" -> raises what.
        * Anything else -> raises SystemError.
        """
        how, what = response
        if how == "OK":
            return what
        if how == "QUEUED":
            return None
        if how == "EXCEPTION":
            self.debug("decoderesponse: EXCEPTION")
            return None
        if how == "EOF":
            self.debug("decoderesponse: EOF")
            self.decode_interrupthook()
            return None
        if how == "ERROR":
            self.debug("decoderesponse: Internal ERROR:", what)
            raise RuntimeError(what)
        if how == "CALLEXC":
            self.debug("decoderesponse: Call Exception:", what)
            raise what
        raise SystemError(how, what)

    def decode_interrupthook(self):
        "Called for an 'EOF' response.  The default raises EOFError."
        raise EOFError

    def mainloop(self):
        """Listen on socket until I/O not ready or EOF

        pollresponse() will loop looking for seq number None, which
        never comes, and exit on EOFError.

        """
        try:
            self.getresponse(myseq=None, wait=0.05)
        except EOFError:
            self.debug("mainloop:return")
            return

    def getresponse(self, myseq, wait):
        """Return the response for myseq.

        In an "OK" response, RemoteProxy values are replaced by RPCProxy
        objects (see _proxify).
        """
        response = self._getresponse(myseq, wait)
        if response is not None:
            how, what = response
            if how == "OK":
                response = how, self._proxify(what)
        return response

    def _proxify(self, obj):
        if isinstance(obj, RemoteProxy):
            return RPCProxy(self, obj.oid)
        if isinstance(obj, list):
            return list(map(self._proxify, obj))
        # XXX Check for other types -- not currently needed
        return obj

    def _getresponse(self, myseq, wait):
        self.debug("_getresponse:myseq:", myseq)
        if threading.current_thread() is self.sockthread:
            # This thread does all reading of requests and responses.
            while True:
                response = self.pollresponse(myseq, wait)
                if response is not None:
                    return response
        else:
            # Wait for the socket-handling thread to deliver our response.
            cvar = self.cvars[myseq]
            with cvar:
                while myseq not in self.responses:
                    cvar.wait()
                response = self.responses.pop(myseq)
                self.debug("_getresponse:%s: thread woke up: response: %s" %
                           (myseq, response))
                del self.cvars[myseq]
            return response

    def newseq(self):
        """Return the next sequence number (advances nextseq by 2)."""
        self.nextseq = seq = self.nextseq + 2
        return seq

    def putmessage(self, message):
        """Pickle message, frame it with a length prefix and send all of it.

        Raises OSError if the socket no longer exists (self.sock is None).
        """
        self.debug("putmessage:%d:" % message[0])
        try:
            s = dumps(message)
        except pickle.PicklingError:
            print("Cannot pickle:", repr(message), file=sys.__stderr__)
            raise
        # Slicing a memoryview avoids re-copying the unsent tail each time.
        pending = memoryview(struct.pack("<i", len(s)) + s)
        while len(pending) > 0:
            try:
                select.select([], [self.sock], [])
                n = self.sock.send(pending[:BUFSIZE])
            except (AttributeError, TypeError):
                raise OSError("socket no longer exists")
            pending = pending[n:]

    # Receive-side framing state.  While reading the 4-byte length prefix,
    # bufstate is 0 and bufneed is 4.  While reading a body, bufstate is 1
    # and bufneed is the body length.
    buff = b''
    bufneed = 4
    bufstate = 0 # meaning: 0 => reading count; 1 => reading data

    def pollpacket(self, wait):
        """Return the next complete packet body, or None if none is ready.

        Waits at most wait seconds for data and performs at most one
        recv().  Raises EOFError if the peer closed the link or recv()
        failed.
        """
        self._stage0()
        if len(self.buff) < self.bufneed:
            ready, _, _ = select.select([self.sock.fileno()], [], [], wait)
            if not ready:
                return None
            try:
                s = self.sock.recv(BUFSIZE)
            except OSError:
                raise EOFError
            if len(s) == 0:
                raise EOFError
            self.buff += s
            self._stage0()
        return self._stage1()

    def _stage0(self):
        """If a length prefix is expected and buffered, consume it."""
        if self.bufstate == 0 and len(self.buff) >= 4:
            s = self.buff[:4]
            self.buff = self.buff[4:]
            self.bufneed = struct.unpack("<i", s)[0]
            self.bufstate = 1

    def _stage1(self):
        """If a whole body is buffered, consume and return it, else None."""
        if self.bufstate == 1 and len(self.buff) >= self.bufneed:
            packet = self.buff[:self.bufneed]
            self.buff = self.buff[self.bufneed:]
            self.bufneed = 4
            self.bufstate = 0
            return packet

    def pollmessage(self, wait):
        """Return the next unpickled message, or None if none is ready."""
        packet = self.pollpacket(wait)
        if packet is None:
            return None
        # NOTE(review): pickle.loads can execute arbitrary code, so the peer
        # must be trusted.  RPCClient only accepts LOCALHOST connections.
        try:
            message = pickle.loads(packet)
        except pickle.UnpicklingError:
            print("-----------------------", file=sys.__stderr__)
            print("cannot unpickle packet:", repr(packet), file=sys.__stderr__)
            traceback.print_stack(file=sys.__stderr__)
            print("-----------------------", file=sys.__stderr__)
            raise
        return message

    def pollresponse(self, myseq, wait):
        """Handle messages received on the socket.

        Some messages received may be asynchronous 'call' or 'queue' requests,
        and some may be responses for other threads.

        'call' requests are passed to self.localcall() with the expectation of
        immediate execution, during which time the socket is not serviced.

        'queue' requests are used for tasks (which may block or hang) to be
        processed in a different thread.  These requests are fed into
        request_queue by self.localcall().  Responses to queued requests are
        taken from response_queue and sent across the link with the associated
        sequence numbers.  Messages in the queues are (sequence_number,
        request/response) tuples and code using this module removing messages
        from the request_queue is responsible for returning the correct
        sequence number in the response_queue.

        pollresponse() will loop until a response message with the myseq
        sequence number is received, and will save other responses in
        self.responses and notify the owning thread.

        """
        while True:
            # send queued response if there is one available
            try:
                qmsg = response_queue.get(0)
            except queue.Empty:
                pass
            else:
                seq, response = qmsg
                message = (seq, ('OK', response))
                self.putmessage(message)
            # poll for message on link
            try:
                message = self.pollmessage(wait)
                if message is None:  # socket not ready
                    return None
            except EOFError:
                self.handle_EOF()
                return None
            except AttributeError:
                return None
            seq, resq = message
            how = resq[0]
            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
            # process or queue a request
            if how in ("CALL", "QUEUE"):
                self.debug("pollresponse:%d:localcall:call:" % seq)
                response = self.localcall(seq, resq)
                self.debug("pollresponse:%d:localcall:response:%s"
                           % (seq, response))
                if how == "CALL":
                    self.putmessage((seq, response))
                # A 'queue' request is deliberately not acknowledged; its
                # response comes later through response_queue.
                continue
            # return if completed message transaction
            elif seq == myseq:
                return resq
            # must be a response for a different thread:
            else:
                cv = self.cvars.get(seq, None)
                # response involving unknown sequence number is discarded,
                # probably intended for prior incarnation of server
                if cv is not None:
                    with cv:
                        self.responses[seq] = resq
                        cv.notify()
                continue

    def handle_EOF(self):
        "action taken upon link being closed by peer"
        self.EOFhook()
        self.debug("handle_EOF")
        # Iterate over a snapshot: each woken thread deletes its own entry
        # from self.cvars (see _getresponse) while this loop may still run.
        for key, cv in list(self.cvars.items()):
            with cv:
                self.responses[key] = ('EOF', None)
                cv.notify()
        # call our (possibly overridden) exit function
        self.exithook()

    def EOFhook(self):
        "Classes using rpc client/server can override to augment EOF action"
        pass
486 
487 #----------------- end class SocketIO --------------------
488 
class RemoteObject:
    """Token mix-in class.

    localcall() returns an instance of this type by reference: it
    registers the object with remoteref() instead of pickling it.
    """
492 
493 
def remoteref(obj):
    """Register obj under id(obj) in objecttable.

    Returns a RemoteProxy carrying that id.
    """
    oid = id(obj)
    objecttable[oid] = obj
    proxy = RemoteProxy(oid)
    return proxy
498 
499 
class RemoteProxy:
    """Stand-in sent to the peer in place of a registered object.

    It carries only the object id.  The receiving SocketIO turns it into
    an RPCProxy (see SocketIO._proxify).
    """

    def __init__(self, oid):
        self.oid = oid
504 
505 
class RPCHandler(socketserver.BaseRequestHandler, SocketIO):
    """Request handler for RPCServer.

    It services the connected socket with SocketIO's message loop.
    """

    debugging = False
    location = "#S"  # Server

    def __init__(self, sock, addr, svr):
        # Order matters.  BaseRequestHandler.__init__ runs the request
        # (including handle()), so the server back-reference and the
        # SocketIO state must be set up before it is called.
        svr.current_handler = self
        SocketIO.__init__(self, sock)
        socketserver.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        "handle() method required by socketserver: run SocketIO.mainloop()."
        self.mainloop()

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for object oid on the other end of the link."""
        proxy = RPCProxy(self, oid)
        return proxy
522 
523 
class RPCClient(SocketIO):
    """Listening end of the reversed RPC link.

    The constructor binds and listens.  accept() waits for the peer and
    only then initializes the SocketIO state on the connected socket.
    Connections from any address other than LOCALHOST are rejected.
    """

    debugging = False
    location = "#C"  # Client

    nextseq = 1 # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        sock = socket.socket(family, type)
        try:
            sock.bind(address)
            sock.listen(1)
        except BaseException:
            # Don't leak the socket when bind() or listen() fails.
            sock.close()
            raise
        self.listening_sock = sock

    def accept(self):
        """Accept one connection and set up SocketIO on it.

        Raises OSError if the peer is not LOCALHOST.  The rejected
        connection is closed first.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print("****** Connection request from ", address, file=sys.__stderr__)
        if address[0] != LOCALHOST:
            print("** Invalid host: ", address, file=sys.__stderr__)
            working_sock.close()
            raise OSError("Invalid host: %r" % (address,))
        SocketIO.__init__(self, working_sock)

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for object oid on the other end of the link."""
        return RPCProxy(self, oid)
548 
549 
class RPCProxy:
    """Local stand-in for object oid living on the other end of sockio.

    The remote method-name and attribute-name sets are fetched once, on
    first use, through the "__methods__" and "__attributes__"
    pseudo-calls, and then cached.  A method name yields a MethodProxy.
    An attribute name is fetched from the peer on every access.
    """

    __methods = None
    __attributes = None

    def __init__(self, sockio, oid):
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name not in self.__attributes:
            raise AttributeError(name)
        return self.sockio.remotecall(self.oid, '__getattribute__',
                                      (name,), {})

    def __getattributes(self):
        """Fetch and cache the remote object's non-callable attribute names."""
        self.__attributes = self.sockio.remotecall(
            self.oid, "__attributes__", (), {})

    def __getmethods(self):
        """Fetch and cache the remote object's callable attribute names."""
        self.__methods = self.sockio.remotecall(
            self.oid, "__methods__", (), {})
580 
581 def _getmethods(obj, methods):
582     # Helper to get a list of methods from an object
583     # Adds names to dictionary argument 'methods'
584     for name in dir(obj):
585         attr = getattr(obj, name)
586         if callable(attr):
587             methods[name] = 1
588     if isinstance(obj, type):
589         for super in obj.__bases__:
590             _getmethods(super, methods)
591 
592 def _getattributes(obj, attributes):
593     for name in dir(obj):
594         attr = getattr(obj, name)
595         if not callable(attr):
596             attributes[name] = 1
597 
598 
class MethodProxy:
    """Callable that forwards each call to method name of remote object oid."""

    def __init__(self, sockio, oid, name):
        self.sockio, self.oid, self.name = sockio, oid, name

    def __call__(self, /, *args, **kwargs):
        # Positional-only self lets callers pass a keyword named 'self'.
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)
609 
610 
611 # XXX KBK 09Sep03  We need a proper unit test for this module.  Previously
612 #                  existing test code was removed at Rev 1.27 (r34098).
613 
def displayhook(value):
    """Override standard display hook to use non-locale encoding.

    Writes repr(value) and a newline to sys.stdout.  If the stream raises
    UnicodeEncodeError, the text is rewritten in ASCII with backslash
    escapes and written again.  None is not displayed.  Otherwise
    builtins._ is set to value.
    """
    if value is None:
        return
    # Reset '_' before calling repr(), to avoid recursion.
    builtins._ = None
    text = repr(value)
    try:
        sys.stdout.write(text)
    except UnicodeEncodeError:
        # Fall back to ASCII, escaping characters the stream can't encode.
        escaped = text.encode('ascii', 'backslashreplace').decode('ascii', 'strict')
        sys.stdout.write(escaped)
    sys.stdout.write("\n")
    builtins._ = value
631 
632 
if __name__ == '__main__':
    # Run this module's tests from idlelib's test suite.
    import unittest
    unittest.main('idlelib.idle_test.test_rpc', verbosity=2)
636