• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""RPC Implementation, originally written for the Python Idle IDE
2
3For security reasons, GvR requested that Idle's Python execution server process
4connect to the Idle process, which listens for the connection.  Since Idle has
5only one client per server, this was not a limitation.
6
7   +---------------------------------+ +-------------+
8   | socketserver.BaseRequestHandler | | SocketIO    |
9   +---------------------------------+ +-------------+
10                   ^                   | register()  |
11                   |                   | unregister()|
12                   |                   +-------------+
13                   |                      ^  ^
14                   |                      |  |
15                   | + -------------------+  |
16                   | |                       |
17   +-------------------------+        +-----------------+
18   | RPCHandler              |        | RPCClient       |
19   | [attribute of RPCServer]|        |                 |
20   +-------------------------+        +-----------------+
21
22The RPCServer handler class is expected to provide register/unregister methods.
23RPCHandler inherits the mix-in class SocketIO, which provides these methods.
24
25See the Idle run.main() docstring for further information on how this was
26accomplished in Idle.
27
28"""
29import builtins
30import copyreg
31import io
32import marshal
33import os
34import pickle
35import queue
36import select
37import socket
38import socketserver
39import struct
40import sys
41import threading
42import traceback
43import types
44
def unpickle_code(ms):
    """Rebuild a code object from the marshal data made by pickle_code().

    *ms* is marshal-serialized bytes whose decoded value must be a code
    object (checked with an assert only).
    """
    # NOTE(review): marshal data arrives over the RPC link; the peer must be
    # trusted.
    code_obj = marshal.loads(ms)
    assert isinstance(code_obj, types.CodeType)
    return code_obj
49
def pickle_code(co):
    """Reduce function for code objects (used in CodePickler's dispatch table).

    Returns the ``(callable, args)`` pair pickle stores: unpickling will call
    unpickle_code() with the marshal serialization of *co*.
    """
    assert isinstance(co, types.CodeType)
    return unpickle_code, (marshal.dumps(co),)
54
def dumps(obj, protocol=None):
    """Serialize *obj* with CodePickler and return the resulting bytes.

    *protocol* is forwarded unchanged to the pickler (None means pickle's
    default protocol).
    """
    buffer = io.BytesIO()
    CodePickler(buffer, protocol).dump(obj)
    return buffer.getvalue()
60
61
class CodePickler(pickle.Pickler):
    """Pickler that can also serialize code objects (via marshal).

    The dispatch table is a snapshot of copyreg.dispatch_table taken when the
    class is created, plus a reducer for types.CodeType.
    """
    # Code-object reducer first; copyreg's entries come after it and so win
    # on any overlapping key, exactly as a later dict.update() would.
    dispatch_table = {types.CodeType: pickle_code, **copyreg.dispatch_table}
65
66
# Largest chunk handed to a single socket send() / recv() call (8 KiB).
BUFSIZE = 8192
# The only peer address RPCClient.accept() will accept a connection from.
LOCALHOST = '127.0.0.1'
69
class RPCServer(socketserver.TCPServer):
    """TCPServer whose socket connects out to the peer instead of listening.

    The connection is reversed: *addr* is the address of the listening Idle
    client that this "server" connects to.
    """

    def __init__(self, addr, handlerclass=None):
        # Fall back to RPCHandler only when no handler class was given.
        chosen_handler = RPCHandler if handlerclass is None else handlerclass
        socketserver.TCPServer.__init__(self, addr, chosen_handler)

    def server_bind(self):
        "Override TCPServer method, no bind() phase for connecting entity"

    def server_activate(self):
        """Override TCPServer method, connect() instead of listen()

        Due to the reversed connection, self.server_address is actually the
        address of the Idle Client to which we are connecting.
        """
        self.socket.connect(self.server_address)

    def get_request(self):
        "Override TCPServer method, return already connected socket"
        return self.socket, self.server_address

    def handle_error(self, request, client_address):
        """Override TCPServer method

        Must be called while an exception is being handled: it re-raises the
        active exception to classify it.  SystemExit propagates unchanged;
        anything else is reported on sys.__stderr__ (thread, client address,
        request and traceback) and the process ends via os._exit(0).
        """
        try:
            raise  # re-raise whatever exception is currently being handled
        except SystemExit:
            raise
        except BaseException:
            err_stream = sys.__stderr__
            rule = '-' * 40
            print('\n' + rule, file=err_stream)
            print('Unhandled server exception!', file=err_stream)
            print('Thread: %s' % threading.current_thread().name,
                  file=err_stream)
            print('Client Address: ', client_address, file=err_stream)
            print('Request: ', repr(request), file=err_stream)
            traceback.print_exc(file=err_stream)
            print('\n*** Unrecoverable, server exiting!', file=err_stream)
            print(rule, file=err_stream)
            os._exit(0)
117
118#----------------- end class RPCServer --------------------
119
# Module-wide registry mapping object ids to objects; used by SocketIO
# instances created without their own table, and filled by remoteref().
objecttable = {}
# Unbounded FIFO of (seq, (method, args, kwargs)) items produced for 'QUEUE'
# requests; some other thread is expected to service it.
request_queue = queue.Queue()
# Unbounded FIFO of (seq, response) items that pollresponse() sends back.
response_queue = queue.Queue()
123
124
class SocketIO(object):
    """Mix-in implementing the RPC message protocol over a connected socket.

    Each message is a pickled ``(seq, payload)`` tuple preceded by a 4-byte
    little-endian signed length.  Request payloads have the form
    ``(how, (oid, methodname, args, kwargs))`` with *how* "CALL" or "QUEUE";
    response payloads are ``(how, what)`` pairs ("OK", "QUEUED", "EXCEPTION",
    "EOF", "ERROR", "CALLEXC").

    The thread that constructed the instance (``sockthread``) does all socket
    reading.  Any other thread that issues a call waits on a per-sequence
    threading.Condition until the socket thread stores its response in
    ``self.responses``.
    """

    nextseq = 0
    # Class-level defaults so debug() works on a bare SocketIO; RPCHandler
    # and RPCClient override both.
    debugging = False
    location = ""

    def __init__(self, sock, objtable=None, debugging=None):
        self.sockthread = threading.current_thread()
        if debugging is not None:
            self.debugging = debugging
        self.sock = sock
        if objtable is None:
            objtable = objecttable
        self.objtable = objtable
        self.responses = {}  # seq -> response awaiting a non-socket thread
        self.cvars = {}      # seq -> Condition that thread is waiting on

    def close(self):
        """Close the socket once; later putmessage() calls raise OSError."""
        sock = self.sock
        self.sock = None
        if sock is not None:
            sock.close()

    def exithook(self):
        "override for specific exit action"
        os._exit(0)

    def debug(self, *args):
        """When debugging, print *args* to sys.__stderr__ tagged with
        this end's location and the current thread's name."""
        if not self.debugging:
            return
        parts = [self.location, str(threading.current_thread().name)]
        parts.extend(str(a) for a in args)
        print(" ".join(parts), file=sys.__stderr__)

    def register(self, oid, object):
        """Make *object* reachable by the peer under id *oid*."""
        self.objtable[oid] = object

    def unregister(self, oid):
        """Forget *oid*; unknown ids are ignored."""
        try:
            del self.objtable[oid]
        except KeyError:
            pass

    def localcall(self, seq, request):
        """Execute *request* against a registered local object.

        Returns a ``(how, what)`` response: ("OK", result); ("QUEUED", None)
        after putting a QUEUE request on request_queue; ("CALLEXC", exc) when
        the method raised an Exception; ("ERROR", message) for malformed or
        unsupported requests; ("EXCEPTION", None) for any other
        BaseException.  SystemExit, KeyboardInterrupt and OSError propagate.
        """
        self.debug("localcall:", request)
        try:
            how, (oid, methodname, args, kwargs) = request
        except (TypeError, ValueError):
            # TypeError: request or payload not iterable.  ValueError: wrong
            # number of elements.  Either way the request is malformed.
            return ("ERROR", "Bad request format")
        if oid not in self.objtable:
            return ("ERROR", "Unknown object id: %r" % (oid,))
        obj = self.objtable[oid]
        if methodname == "__methods__":
            methods = {}
            _getmethods(obj, methods)
            return ("OK", methods)
        if methodname == "__attributes__":
            attributes = {}
            _getattributes(obj, attributes)
            return ("OK", attributes)
        if not hasattr(obj, methodname):
            return ("ERROR", "Unsupported method name: %r" % (methodname,))
        method = getattr(obj, methodname)
        try:
            if how == 'CALL':
                ret = method(*args, **kwargs)
                if isinstance(ret, RemoteObject):
                    # Send a reference instead of the object itself.
                    ret = remoteref(ret)
                return ("OK", ret)
            elif how == 'QUEUE':
                # Run later by whichever thread services request_queue.
                request_queue.put((seq, (method, args, kwargs)))
                return ("QUEUED", None)
            else:
                return ("ERROR", "Unsupported message type: %s" % how)
        except SystemExit:
            raise
        except KeyboardInterrupt:
            raise
        except OSError:
            raise
        except Exception as ex:
            return ("CALLEXC", ex)
        except BaseException:
            msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
                  " Object: %s \n Method: %s \n Args: %s\n"
            print(msg % (oid, method, args), file=sys.__stderr__)
            traceback.print_exc(file=sys.__stderr__)
            return ("EXCEPTION", None)

    def remotecall(self, oid, methodname, args, kwargs):
        """Call a remote method and return its decoded result."""
        self.debug("remotecall:asynccall: ", oid, methodname)
        seq = self.asynccall(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def remotequeue(self, oid, methodname, args, kwargs):
        """Queue a remote method call and return its decoded result."""
        self.debug("remotequeue:asyncqueue: ", oid, methodname)
        seq = self.asyncqueue(oid, methodname, args, kwargs)
        return self.asyncreturn(seq)

    def _asyncsend(self, how, label, oid, methodname, args, kwargs):
        """Send a *how* ("CALL"/"QUEUE") request; return its sequence number."""
        request = (how, (oid, methodname, args, kwargs))
        seq = self.newseq()
        if threading.current_thread() != self.sockthread:
            # Non-socket threads block on this condition in _getresponse().
            self.cvars[seq] = threading.Condition()
        self.debug(("%s:%d:" % (label, seq)), oid, methodname, args, kwargs)
        self.putmessage((seq, request))
        return seq

    def asynccall(self, oid, methodname, args, kwargs):
        """Send a CALL request without waiting; return its sequence number."""
        return self._asyncsend("CALL", "asynccall",
                               oid, methodname, args, kwargs)

    def asyncqueue(self, oid, methodname, args, kwargs):
        """Send a QUEUE request without waiting; return its sequence number."""
        return self._asyncsend("QUEUE", "asyncqueue",
                               oid, methodname, args, kwargs)

    def asyncreturn(self, seq):
        """Wait for the response to *seq* and return its decoded value."""
        self.debug("asyncreturn:%d:call getresponse(): " % seq)
        response = self.getresponse(seq, wait=0.05)
        self.debug(("asyncreturn:%d:response: " % seq), response)
        return self.decoderesponse(response)

    def decoderesponse(self, response):
        """Turn a ``(how, what)`` response into a value or an exception."""
        how, what = response
        if how == "OK":
            return what
        if how == "QUEUED":
            return None
        if how == "EXCEPTION":
            self.debug("decoderesponse: EXCEPTION")
            return None
        if how == "EOF":
            self.debug("decoderesponse: EOF")
            self.decode_interrupthook()
            return None
        if how == "ERROR":
            self.debug("decoderesponse: Internal ERROR:", what)
            raise RuntimeError(what)
        if how == "CALLEXC":
            self.debug("decoderesponse: Call Exception:", what)
            raise what
        raise SystemError(how, what)

    def decode_interrupthook(self):
        """Called for an "EOF" response; raises EOFError by default."""
        raise EOFError

    def mainloop(self):
        """Listen on socket until I/O not ready or EOF

        pollresponse() will loop looking for seq number None, which
        never comes, and exit on EOFError.
        """
        try:
            self.getresponse(myseq=None, wait=0.05)
        except EOFError:
            self.debug("mainloop:return")
            return

    def getresponse(self, myseq, wait):
        """Return the response for *myseq*, wrapping proxies in RPCProxy."""
        response = self._getresponse(myseq, wait)
        if response is not None:
            how, what = response
            if how == "OK":
                response = how, self._proxify(what)
        return response

    def _proxify(self, obj):
        """Replace RemoteProxy values (also inside lists) with RPCProxy."""
        if isinstance(obj, RemoteProxy):
            return RPCProxy(self, obj.oid)
        if isinstance(obj, list):
            return list(map(self._proxify, obj))
        # XXX Check for other types -- not currently needed
        return obj

    def _getresponse(self, myseq, wait):
        """Block until the response for *myseq* is available and return it."""
        self.debug("_getresponse:myseq:", myseq)
        if threading.current_thread() is self.sockthread:
            # this thread does all reading of requests or responses
            while True:
                response = self.pollresponse(myseq, wait)
                if response is not None:
                    return response
        else:
            # wait for notification from socket handling thread; the with
            # statement releases the lock even if wait() is interrupted
            cvar = self.cvars[myseq]
            with cvar:
                while myseq not in self.responses:
                    cvar.wait()
                response = self.responses[myseq]
                self.debug("_getresponse:%s: thread woke up: response: %s" %
                           (myseq, response))
                del self.responses[myseq]
                del self.cvars[myseq]
            return response

    def newseq(self):
        """Return the next sequence number (steps of 2 keep the two ends'
        numbering disjoint: RPCClient starts odd, this class even)."""
        self.nextseq = seq = self.nextseq + 2
        return seq

    def putmessage(self, message):
        """Pickle *message* and send it as one length-prefixed frame.

        Raises OSError if the socket has been closed (self.sock is None).
        """
        self.debug("putmessage:%d:" % message[0])
        try:
            s = dumps(message)
        except pickle.PicklingError:
            print("Cannot pickle:", repr(message), file=sys.__stderr__)
            raise
        # A memoryview lets the loop advance without copying the unsent
        # remainder after every send() (which was quadratic in frame size).
        data = memoryview(struct.pack("<i", len(s)) + s)
        while len(data) > 0:
            try:
                r, w, x = select.select([], [self.sock], [])
                n = self.sock.send(data[:BUFSIZE])
            except (AttributeError, TypeError):
                raise OSError("socket no longer exists")
            data = data[n:]

    buff = b''
    bufneed = 4
    bufstate = 0 # meaning: 0 => reading count; 1 => reading data

    def pollpacket(self, wait):
        """Return one complete packet, or None if none arrives within *wait*.

        Raises EOFError when the peer has closed the connection.
        """
        self._stage0()
        if len(self.buff) < self.bufneed:
            r, w, x = select.select([self.sock.fileno()], [], [], wait)
            if len(r) == 0:
                return None
            try:
                s = self.sock.recv(BUFSIZE)
            except OSError:
                raise EOFError
            if len(s) == 0:
                raise EOFError
            self.buff += s
            self._stage0()
        return self._stage1()

    def _stage0(self):
        """Consume the 4-byte length header once enough bytes are buffered."""
        if self.bufstate == 0 and len(self.buff) >= 4:
            s = self.buff[:4]
            self.buff = self.buff[4:]
            self.bufneed = struct.unpack("<i", s)[0]
            self.bufstate = 1

    def _stage1(self):
        """Return the packet body once complete, else None."""
        if self.bufstate == 1 and len(self.buff) >= self.bufneed:
            packet = self.buff[:self.bufneed]
            self.buff = self.buff[self.bufneed:]
            self.bufneed = 4
            self.bufstate = 0
            return packet

    def pollmessage(self, wait):
        """Return the next unpickled message, or None if none is ready."""
        packet = self.pollpacket(wait)
        if packet is None:
            return None
        try:
            # NOTE(review): pickle.loads() can execute arbitrary code carried
            # in the packet; this is only safe while the peer is trusted.
            message = pickle.loads(packet)
        except pickle.UnpicklingError:
            print("-----------------------", file=sys.__stderr__)
            print("cannot unpickle packet:", repr(packet), file=sys.__stderr__)
            traceback.print_stack(file=sys.__stderr__)
            print("-----------------------", file=sys.__stderr__)
            raise
        return message

    def pollresponse(self, myseq, wait):
        """Handle messages received on the socket.

        Some messages received may be asynchronous 'call' or 'queue' requests,
        and some may be responses for other threads.

        'call' requests are passed to self.localcall() with the expectation of
        immediate execution, during which time the socket is not serviced.

        'queue' requests are used for tasks (which may block or hang) to be
        processed in a different thread.  These requests are fed into
        request_queue by self.localcall().  Responses to queued requests are
        taken from response_queue and sent across the link with the associated
        sequence numbers.  Messages in the queues are (sequence_number,
        request/response) tuples and code using this module removing messages
        from the request_queue is responsible for returning the correct
        sequence number in the response_queue.

        pollresponse() will loop until a response message with the myseq
        sequence number is received, and will save other responses in
        self.responses and notify the owning thread.
        """
        while True:
            # send queued response if there is one available
            try:
                qmsg = response_queue.get(0)
            except queue.Empty:
                pass
            else:
                seq, response = qmsg
                message = (seq, ('OK', response))
                self.putmessage(message)
            # poll for message on link
            try:
                message = self.pollmessage(wait)
                if message is None:  # socket not ready
                    return None
            except EOFError:
                self.handle_EOF()
                return None
            except AttributeError:
                # self.sock is None: the link has been closed locally
                return None
            seq, resq = message
            how = resq[0]
            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
            # process or queue a request
            if how in ("CALL", "QUEUE"):
                self.debug("pollresponse:%d:localcall:call:" % seq)
                response = self.localcall(seq, resq)
                self.debug("pollresponse:%d:localcall:response:%s"
                           % (seq, response))
                if how == "CALL":
                    self.putmessage((seq, response))
                # don't acknowledge the 'queue' request!
                continue
            # return if completed message transaction
            elif seq == myseq:
                return resq
            # must be a response for a different thread:
            else:
                cv = self.cvars.get(seq, None)
                # response involving unknown sequence number is discarded,
                # probably intended for prior incarnation of server
                if cv is not None:
                    with cv:
                        self.responses[seq] = resq
                        cv.notify()
                continue

    def handle_EOF(self):
        "action taken upon link being closed by peer"
        self.EOFhook()
        self.debug("handle_EOF")
        # Iterate over a snapshot: woken threads delete their own entries
        # from self.cvars (in _getresponse) while this loop is still running.
        for key, cv in list(self.cvars.items()):
            with cv:
                self.responses[key] = ('EOF', None)
                cv.notify()
        # call our (possibly overridden) exit function
        self.exithook()

    def EOFhook(self):
        "Classes using rpc client/server can override to augment EOF action"
        pass
483
484#----------------- end class SocketIO --------------------
485
class RemoteObject:
    """Token mix-in class.

    When SocketIO.localcall() gets an instance of this class as a call
    result, it sends a RemoteProxy reference (via remoteref()) instead of
    pickling the object by value.
    """
489
490
def remoteref(obj):
    """Register *obj* in the module objecttable and return a proxy for it.

    The key is id(obj).  NOTE(review): nothing here removes the entry, so the
    table keeps *obj* alive until something unregisters it.
    """
    key = id(obj)
    objecttable[key] = obj
    return RemoteProxy(key)
495
496
class RemoteProxy:
    """Lightweight reference to a registered object, sent across the link.

    Carries only the object id; the receiving SocketIO turns it into an
    RPCProxy.
    """

    def __init__(self, oid):
        self.oid = oid  # key of the object in the sender's object table
501
502
class RPCHandler(socketserver.BaseRequestHandler, SocketIO):
    """Handler class for RPCServer: the server ("#S") end of the RPC link."""

    debugging = False
    location = "#S"  # Server

    def __init__(self, sock, addr, svr):
        # Expose the active handler on the server object.
        svr.current_handler = self ## cgt xxx
        # SocketIO state has to exist first: BaseRequestHandler.__init__ runs
        # handle() -- and therefore mainloop() -- before it returns.
        SocketIO.__init__(self, sock)
        socketserver.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        "handle() method required by socketserver"
        self.mainloop()

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for object *oid* on the other end of the link."""
        return RPCProxy(self, oid)
519
520
class RPCClient(SocketIO):
    """Client ("#C") end of the RPC link.

    It listens on *address* and waits for the execution server to connect.
    SocketIO initialization is deferred to accept(), once a connected socket
    exists.
    """

    debugging = False
    location = "#C"  # Client

    nextseq = 1 # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        self.listening_sock = socket.socket(family, type)
        try:
            self.listening_sock.bind(address)
            self.listening_sock.listen(1)
        except BaseException:
            # Don't leak the socket if bind()/listen() fail.
            self.listening_sock.close()
            raise

    def accept(self):
        """Accept one connection and set up SocketIO on it.

        Only connections from LOCALHOST are accepted.  Any other peer is
        reported on sys.__stderr__, its socket is closed, and OSError is
        raised.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print("****** Connection request from ", address, file=sys.__stderr__)
        if address[0] == LOCALHOST:
            SocketIO.__init__(self, working_sock)
        else:
            print("** Invalid host: ", address, file=sys.__stderr__)
            # Close the rejected connection instead of leaking it.
            working_sock.close()
            raise OSError

    def get_remote_proxy(self, oid):
        """Return an RPCProxy for object *oid* on the other end of the link."""
        return RPCProxy(self, oid)
545
546
class RPCProxy(object):
    """Local stand-in for object *oid* living on the far side of *sockio*.

    Attribute access is forwarded lazily.  The remote object's method and
    attribute name sets are fetched on first use and cached on the instance.
    Method names yield a MethodProxy; attribute names are read remotely on
    each access.
    """

    __methods = None
    __attributes = None

    def __init__(self, sockio, oid):
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        # __getattr__ only runs for names not found normally.  If 'sockio' or
        # 'oid' is missing, the instance was created without __init__ (e.g.
        # by copy or unpickling).  Fail with AttributeError instead of
        # recursing forever through __getmethods() -> self.sockio -> here.
        if name in ('sockio', 'oid'):
            raise AttributeError(name)
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name in self.__attributes:
            return self.sockio.remotecall(self.oid, '__getattribute__',
                                          (name,), {})
        raise AttributeError(name)

    def __getattributes(self):
        """Fetch and cache the remote object's non-callable attribute names."""
        self.__attributes = self.sockio.remotecall(self.oid,
                                                "__attributes__", (), {})

    def __getmethods(self):
        """Fetch and cache the remote object's callable attribute names."""
        self.__methods = self.sockio.remotecall(self.oid,
                                                "__methods__", (), {})
577
578def _getmethods(obj, methods):
579    # Helper to get a list of methods from an object
580    # Adds names to dictionary argument 'methods'
581    for name in dir(obj):
582        attr = getattr(obj, name)
583        if callable(attr):
584            methods[name] = 1
585    if isinstance(obj, type):
586        for super in obj.__bases__:
587            _getmethods(super, methods)
588
589def _getattributes(obj, attributes):
590    for name in dir(obj):
591        attr = getattr(obj, name)
592        if not callable(attr):
593            attributes[name] = 1
594
595
class MethodProxy:
    """Callable that invokes method *name* of remote object *oid*."""

    def __init__(self, sockio, oid, name):
        self.sockio = sockio
        self.oid = oid
        self.name = name

    def __call__(self, *args, **kwargs):
        # sockio.remotecall() performs the round trip and returns its result.
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)
606
607
608# XXX KBK 09Sep03  We need a proper unit test for this module.  Previously
609#                  existing test code was removed at Rev 1.27 (r34098).
610
def displayhook(value):
    """Override standard display hook to use non-locale encoding

    Writes repr(value) plus a newline to sys.stdout and stores *value* in
    builtins._.  If the stream cannot encode the text, characters outside
    ASCII are written as backslash escapes instead.  None is ignored.
    """
    if value is not None:
        # Set '_' to None to avoid recursion
        builtins._ = None
        text = repr(value)
        try:
            sys.stdout.write(text)
        except UnicodeEncodeError:
            # Fall back to plain ASCII, escaping what it cannot represent.
            escaped = text.encode('ascii', 'backslashreplace')
            sys.stdout.write(escaped.decode('ascii', 'strict'))
        sys.stdout.write("\n")
        builtins._ = value
628