• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""RPC Implementation, originally written for the Python Idle IDE
2
3For security reasons, GvR requested that Idle's Python execution server process
4connect to the Idle process, which listens for the connection.  Since Idle has
5only one client per server, this was not a limitation.
6
7   +---------------------------------+ +-------------+
8   | SocketServer.BaseRequestHandler | | SocketIO    |
9   +---------------------------------+ +-------------+
10                   ^                   | register()  |
11                   |                   | unregister()|
12                   |                   +-------------+
13                   |                      ^  ^
14                   |                      |  |
15                   | + -------------------+  |
16                   | |                       |
17   +-------------------------+        +-----------------+
18   | RPCHandler              |        | RPCClient       |
19   | [attribute of RPCServer]|        |                 |
20   +-------------------------+        +-----------------+
21
22The RPCServer handler class is expected to provide register/unregister methods.
23RPCHandler inherits the mix-in class SocketIO, which provides these methods.
24
25See the Idle run.main() docstring for further information on how this was
26accomplished in Idle.
27
28"""
29
30import sys
31import os
32import socket
33import select
34import SocketServer
35import struct
36import cPickle as pickle
37import threading
38import Queue
39import traceback
40import copy_reg
41import types
42import marshal
43
44
def unpickle_code(ms):
    """Rebuild a code object from its marshalled byte string *ms*.

    This is the reconstructor that pickle_code() hands to the pickle
    machinery.  The result of marshal.loads() must be a code object;
    anything else trips the assertion.
    """
    code = marshal.loads(ms)
    assert isinstance(code, types.CodeType)
    return code
49
def pickle_code(co):
    """Reduce the code object *co* for pickling.

    Returns the (callable, args) pair expected by copy_reg: calling
    unpickle_code() on the marshalled bytes recreates the code object.
    """
    assert isinstance(co, types.CodeType)
    return unpickle_code, (marshal.dumps(co),)
54
# XXX KBK 24Aug02 function pickling capability not used in Idle
#  def unpickle_function(ms):
#      return ms

#  def pickle_function(fn):
#      assert isinstance(fn, types.FunctionType)
#      return repr(fn)
62
# Teach pickle how to serialise code objects: pickle_code() reduces them to
# marshalled bytes and unpickle_code() rebuilds them on the receiving side.
copy_reg.pickle(types.CodeType, pickle_code, unpickle_code)
# Function pickling is intentionally left disabled (see the XXX note above).
# copy_reg.pickle(types.FunctionType, pickle_function, unpickle_function)
65
# Largest chunk, in bytes, passed to a single send() or recv() call.
BUFSIZE = 8 * 1024
# The only peer address RPCClient.accept() is willing to talk to.
LOCALHOST = '127.0.0.1'
68
class RPCServer(SocketServer.TCPServer):
    """TCPServer variant that connects out to its peer instead of listening.

    The usual bind/listen/accept sequence is replaced: server_bind() does
    nothing, server_activate() connects self.socket to the given address,
    and get_request() hands back that already-connected socket.
    """

    def __init__(self, addr, handlerclass=None):
        """Create the server for *addr*; *handlerclass* defaults to RPCHandler."""
        if handlerclass is None:
            handlerclass = RPCHandler
        SocketServer.TCPServer.__init__(self, addr, handlerclass)

    def server_bind(self):
        "Override TCPServer method, no bind() phase for connecting entity"
        return None

    def server_activate(self):
        """Override TCPServer method, connect() instead of listen()

        Because the connection is reversed, self.server_address holds the
        address of the Idle client that this server connects to.

        """
        self.socket.connect(self.server_address)

    def get_request(self):
        "Override TCPServer method, return already connected socket"
        return (self.socket, self.server_address)

    def handle_error(self, request, client_address):
        """Override TCPServer method

        The report is written to __stderr__.  SystemExit is passed through
        untouched; any other exception that reaches this point is reported
        and then terminates the process through os._exit.

        """
        try:
            # Re-raise the exception currently being handled so it can be
            # classified by the except clauses below.
            raise
        except SystemExit:
            raise
        except:
            out = sys.__stderr__
            rule = '-'*40
            print>>out, '\n' + rule
            print>>out, 'Unhandled server exception!'
            print>>out, 'Thread: %s' % threading.currentThread().getName()
            print>>out, 'Client Address: ', client_address
            print>>out, 'Request: ', repr(request)
            traceback.print_exc(file=out)
            print>>out, '\n*** Unrecoverable, server exiting!'
            print>>out, rule
            os._exit(0)
116
117#----------------- end class RPCServer --------------------
118
# Default registry mapping object ids to the local objects that remote
# peers may call; SocketIO uses it when no table is passed in, and
# remoteref() stores objects here.
objecttable = {}
# 'QUEUE' requests are placed here by SocketIO.localcall() as
# (seq, (method, args, kwargs)) tuples.
request_queue = Queue.Queue(0)
# (seq, response) tuples placed here are sent to the peer by
# SocketIO.pollresponse().
response_queue = Queue.Queue(0)
122
123
124class SocketIO(object):
125
126    nextseq = 0
127
128    def __init__(self, sock, objtable=None, debugging=None):
129        self.sockthread = threading.currentThread()
130        if debugging is not None:
131            self.debugging = debugging
132        self.sock = sock
133        if objtable is None:
134            objtable = objecttable
135        self.objtable = objtable
136        self.responses = {}
137        self.cvars = {}
138
139    def close(self):
140        sock = self.sock
141        self.sock = None
142        if sock is not None:
143            sock.close()
144
145    def exithook(self):
146        "override for specific exit action"
147        os._exit(0)
148
149    def debug(self, *args):
150        if not self.debugging:
151            return
152        s = self.location + " " + str(threading.currentThread().getName())
153        for a in args:
154            s = s + " " + str(a)
155        print>>sys.__stderr__, s
156
157    def register(self, oid, object):
158        self.objtable[oid] = object
159
160    def unregister(self, oid):
161        try:
162            del self.objtable[oid]
163        except KeyError:
164            pass
165
166    def localcall(self, seq, request):
167        self.debug("localcall:", request)
168        try:
169            how, (oid, methodname, args, kwargs) = request
170        except TypeError:
171            return ("ERROR", "Bad request format")
172        if oid not in self.objtable:
173            return ("ERROR", "Unknown object id: %r" % (oid,))
174        obj = self.objtable[oid]
175        if methodname == "__methods__":
176            methods = {}
177            _getmethods(obj, methods)
178            return ("OK", methods)
179        if methodname == "__attributes__":
180            attributes = {}
181            _getattributes(obj, attributes)
182            return ("OK", attributes)
183        if not hasattr(obj, methodname):
184            return ("ERROR", "Unsupported method name: %r" % (methodname,))
185        method = getattr(obj, methodname)
186        try:
187            if how == 'CALL':
188                ret = method(*args, **kwargs)
189                if isinstance(ret, RemoteObject):
190                    ret = remoteref(ret)
191                return ("OK", ret)
192            elif how == 'QUEUE':
193                request_queue.put((seq, (method, args, kwargs)))
194                return("QUEUED", None)
195            else:
196                return ("ERROR", "Unsupported message type: %s" % how)
197        except SystemExit:
198            raise
199        except socket.error:
200            raise
201        except:
202            msg = "*** Internal Error: rpc.py:SocketIO.localcall()\n\n"\
203                  " Object: %s \n Method: %s \n Args: %s\n"
204            print>>sys.__stderr__, msg % (oid, method, args)
205            traceback.print_exc(file=sys.__stderr__)
206            return ("EXCEPTION", None)
207
208    def remotecall(self, oid, methodname, args, kwargs):
209        self.debug("remotecall:asynccall: ", oid, methodname)
210        seq = self.asynccall(oid, methodname, args, kwargs)
211        return self.asyncreturn(seq)
212
213    def remotequeue(self, oid, methodname, args, kwargs):
214        self.debug("remotequeue:asyncqueue: ", oid, methodname)
215        seq = self.asyncqueue(oid, methodname, args, kwargs)
216        return self.asyncreturn(seq)
217
218    def asynccall(self, oid, methodname, args, kwargs):
219        request = ("CALL", (oid, methodname, args, kwargs))
220        seq = self.newseq()
221        if threading.currentThread() != self.sockthread:
222            cvar = threading.Condition()
223            self.cvars[seq] = cvar
224        self.debug(("asynccall:%d:" % seq), oid, methodname, args, kwargs)
225        self.putmessage((seq, request))
226        return seq
227
228    def asyncqueue(self, oid, methodname, args, kwargs):
229        request = ("QUEUE", (oid, methodname, args, kwargs))
230        seq = self.newseq()
231        if threading.currentThread() != self.sockthread:
232            cvar = threading.Condition()
233            self.cvars[seq] = cvar
234        self.debug(("asyncqueue:%d:" % seq), oid, methodname, args, kwargs)
235        self.putmessage((seq, request))
236        return seq
237
238    def asyncreturn(self, seq):
239        self.debug("asyncreturn:%d:call getresponse(): " % seq)
240        response = self.getresponse(seq, wait=0.05)
241        self.debug(("asyncreturn:%d:response: " % seq), response)
242        return self.decoderesponse(response)
243
244    def decoderesponse(self, response):
245        how, what = response
246        if how == "OK":
247            return what
248        if how == "QUEUED":
249            return None
250        if how == "EXCEPTION":
251            self.debug("decoderesponse: EXCEPTION")
252            return None
253        if how == "EOF":
254            self.debug("decoderesponse: EOF")
255            self.decode_interrupthook()
256            return None
257        if how == "ERROR":
258            self.debug("decoderesponse: Internal ERROR:", what)
259            raise RuntimeError, what
260        raise SystemError, (how, what)
261
262    def decode_interrupthook(self):
263        ""
264        raise EOFError
265
266    def mainloop(self):
267        """Listen on socket until I/O not ready or EOF
268
269        pollresponse() will loop looking for seq number None, which
270        never comes, and exit on EOFError.
271
272        """
273        try:
274            self.getresponse(myseq=None, wait=0.05)
275        except EOFError:
276            self.debug("mainloop:return")
277            return
278
279    def getresponse(self, myseq, wait):
280        response = self._getresponse(myseq, wait)
281        if response is not None:
282            how, what = response
283            if how == "OK":
284                response = how, self._proxify(what)
285        return response
286
287    def _proxify(self, obj):
288        if isinstance(obj, RemoteProxy):
289            return RPCProxy(self, obj.oid)
290        if isinstance(obj, types.ListType):
291            return map(self._proxify, obj)
292        # XXX Check for other types -- not currently needed
293        return obj
294
295    def _getresponse(self, myseq, wait):
296        self.debug("_getresponse:myseq:", myseq)
297        if threading.currentThread() is self.sockthread:
298            # this thread does all reading of requests or responses
299            while 1:
300                response = self.pollresponse(myseq, wait)
301                if response is not None:
302                    return response
303        else:
304            # wait for notification from socket handling thread
305            cvar = self.cvars[myseq]
306            cvar.acquire()
307            while myseq not in self.responses:
308                cvar.wait()
309            response = self.responses[myseq]
310            self.debug("_getresponse:%s: thread woke up: response: %s" %
311                       (myseq, response))
312            del self.responses[myseq]
313            del self.cvars[myseq]
314            cvar.release()
315            return response
316
317    def newseq(self):
318        self.nextseq = seq = self.nextseq + 2
319        return seq
320
321    def putmessage(self, message):
322        self.debug("putmessage:%d:" % message[0])
323        try:
324            s = pickle.dumps(message)
325        except pickle.PicklingError:
326            print >>sys.__stderr__, "Cannot pickle:", repr(message)
327            raise
328        s = struct.pack("<i", len(s)) + s
329        while len(s) > 0:
330            try:
331                r, w, x = select.select([], [self.sock], [])
332                n = self.sock.send(s[:BUFSIZE])
333            except (AttributeError, TypeError):
334                raise IOError, "socket no longer exists"
335            s = s[n:]
336
337    buffer = ""
338    bufneed = 4
339    bufstate = 0 # meaning: 0 => reading count; 1 => reading data
340
341    def pollpacket(self, wait):
342        self._stage0()
343        if len(self.buffer) < self.bufneed:
344            r, w, x = select.select([self.sock.fileno()], [], [], wait)
345            if len(r) == 0:
346                return None
347            try:
348                s = self.sock.recv(BUFSIZE)
349            except socket.error:
350                raise EOFError
351            if len(s) == 0:
352                raise EOFError
353            self.buffer += s
354            self._stage0()
355        return self._stage1()
356
357    def _stage0(self):
358        if self.bufstate == 0 and len(self.buffer) >= 4:
359            s = self.buffer[:4]
360            self.buffer = self.buffer[4:]
361            self.bufneed = struct.unpack("<i", s)[0]
362            self.bufstate = 1
363
364    def _stage1(self):
365        if self.bufstate == 1 and len(self.buffer) >= self.bufneed:
366            packet = self.buffer[:self.bufneed]
367            self.buffer = self.buffer[self.bufneed:]
368            self.bufneed = 4
369            self.bufstate = 0
370            return packet
371
372    def pollmessage(self, wait):
373        packet = self.pollpacket(wait)
374        if packet is None:
375            return None
376        try:
377            message = pickle.loads(packet)
378        except pickle.UnpicklingError:
379            print >>sys.__stderr__, "-----------------------"
380            print >>sys.__stderr__, "cannot unpickle packet:", repr(packet)
381            traceback.print_stack(file=sys.__stderr__)
382            print >>sys.__stderr__, "-----------------------"
383            raise
384        return message
385
386    def pollresponse(self, myseq, wait):
387        """Handle messages received on the socket.
388
389        Some messages received may be asynchronous 'call' or 'queue' requests,
390        and some may be responses for other threads.
391
392        'call' requests are passed to self.localcall() with the expectation of
393        immediate execution, during which time the socket is not serviced.
394
395        'queue' requests are used for tasks (which may block or hang) to be
396        processed in a different thread.  These requests are fed into
397        request_queue by self.localcall().  Responses to queued requests are
398        taken from response_queue and sent across the link with the associated
399        sequence numbers.  Messages in the queues are (sequence_number,
400        request/response) tuples and code using this module removing messages
401        from the request_queue is responsible for returning the correct
402        sequence number in the response_queue.
403
404        pollresponse() will loop until a response message with the myseq
405        sequence number is received, and will save other responses in
406        self.responses and notify the owning thread.
407
408        """
409        while 1:
410            # send queued response if there is one available
411            try:
412                qmsg = response_queue.get(0)
413            except Queue.Empty:
414                pass
415            else:
416                seq, response = qmsg
417                message = (seq, ('OK', response))
418                self.putmessage(message)
419            # poll for message on link
420            try:
421                message = self.pollmessage(wait)
422                if message is None:  # socket not ready
423                    return None
424            except EOFError:
425                self.handle_EOF()
426                return None
427            except AttributeError:
428                return None
429            seq, resq = message
430            how = resq[0]
431            self.debug("pollresponse:%d:myseq:%s" % (seq, myseq))
432            # process or queue a request
433            if how in ("CALL", "QUEUE"):
434                self.debug("pollresponse:%d:localcall:call:" % seq)
435                response = self.localcall(seq, resq)
436                self.debug("pollresponse:%d:localcall:response:%s"
437                           % (seq, response))
438                if how == "CALL":
439                    self.putmessage((seq, response))
440                elif how == "QUEUE":
441                    # don't acknowledge the 'queue' request!
442                    pass
443                continue
444            # return if completed message transaction
445            elif seq == myseq:
446                return resq
447            # must be a response for a different thread:
448            else:
449                cv = self.cvars.get(seq, None)
450                # response involving unknown sequence number is discarded,
451                # probably intended for prior incarnation of server
452                if cv is not None:
453                    cv.acquire()
454                    self.responses[seq] = resq
455                    cv.notify()
456                    cv.release()
457                continue
458
459    def handle_EOF(self):
460        "action taken upon link being closed by peer"
461        self.EOFhook()
462        self.debug("handle_EOF")
463        for key in self.cvars:
464            cv = self.cvars[key]
465            cv.acquire()
466            self.responses[key] = ('EOF', None)
467            cv.notify()
468            cv.release()
469        # call our (possibly overridden) exit function
470        self.exithook()
471
472    def EOFhook(self):
473        "Classes using rpc client/server can override to augment EOF action"
474        pass
475
476#----------------- end class SocketIO --------------------
477
class RemoteObject(object):
    """Token mix-in class.

    SocketIO.localcall() checks call results against this type and sends
    a remoteref() of a matching result instead of the object itself.
    """
481
def remoteref(obj):
    """Register *obj* in objecttable under id(obj); return a RemoteProxy for it."""
    key = id(obj)
    objecttable[key] = obj
    return RemoteProxy(key)
486
class RemoteProxy(object):
    """Picklable token that stands for a registered object.

    It carries only the object id; SocketIO._proxify() converts it into
    an RPCProxy when it arrives in a response.
    """

    def __init__(self, oid):
        "Remember the id of the object this token refers to"
        self.oid = oid
491
class RPCHandler(SocketServer.BaseRequestHandler, SocketIO):
    """Request handler that speaks the SocketIO protocol on the server side."""

    location = "#S"  # Server
    debugging = False

    def __init__(self, sock, addr, svr):
        """Set up the SocketIO state, then let the base handler take over.

        The instance is also published as svr.current_handler.
        """
        svr.current_handler = self ## cgt xxx
        # SocketIO state is initialised first, before the base handler's
        # constructor runs.
        SocketIO.__init__(self, sock)
        SocketServer.BaseRequestHandler.__init__(self, sock, addr, svr)

    def handle(self):
        "handle() method required by SocketServer"
        self.mainloop()

    def get_remote_proxy(self, oid):
        "Return an RPCProxy for the peer's object registered as *oid*"
        return RPCProxy(self, oid)
508
class RPCClient(SocketIO):
    """Listening end of the link.

    The constructor only opens the listening socket; the SocketIO state
    is set up by accept() once a connection from LOCALHOST has arrived.
    """

    debugging = False
    location = "#C"  # Client

    nextseq = 1 # Requests coming from the client are odd numbered

    def __init__(self, address, family=socket.AF_INET, type=socket.SOCK_STREAM):
        """Bind a listening socket to *address* with a backlog of one.

        Errors from bind() or listen() are propagated after the socket
        has been closed again.
        """
        listener = socket.socket(family, type)
        try:
            listener.bind(address)
            listener.listen(1)
        except:
            # do not leak the descriptor when the address is unusable
            listener.close()
            raise
        self.listening_sock = listener

    def accept(self):
        """Accept one connection and initialise the SocketIO state on it.

        Raises socket.error if the peer is not LOCALHOST; the rejected
        connection is closed first.
        """
        working_sock, address = self.listening_sock.accept()
        if self.debugging:
            print>>sys.__stderr__, "****** Connection request from ", address
        if address[0] == LOCALHOST:
            SocketIO.__init__(self, working_sock)
        else:
            print>>sys.__stderr__, "** Invalid host: ", address
            # close the unwanted connection instead of leaving it open
            working_sock.close()
            raise socket.error

    def get_remote_proxy(self, oid):
        "Return an RPCProxy for the peer's object registered as *oid*"
        return RPCProxy(self, oid)
533
class RPCProxy(object):
    """Local stand-in for the object registered as *oid* on the peer.

    Attribute access is resolved remotely: the peer's method and attribute
    name tables are fetched on first use and cached on the instance.
    """

    __methods = None
    __attributes = None

    def __init__(self, sockio, oid):
        "Remember the link (*sockio*) and the remote object id"
        self.sockio = sockio
        self.oid = oid

    def __getattr__(self, name):
        """Resolve *name* against the remote object.

        A remote method gives a MethodProxy, a remote data attribute gives
        its current value, and anything else raises AttributeError.
        """
        if self.__methods is None:
            self.__getmethods()
        if self.__methods.get(name):
            return MethodProxy(self.sockio, self.oid, name)
        if self.__attributes is None:
            self.__getattributes()
        if name not in self.__attributes:
            raise AttributeError(name)
        return self.sockio.remotecall(self.oid, '__getattribute__',
                                      (name,), {})

    def __getattributes(self):
        "Fetch and cache the peer's table of non-callable attribute names"
        table = self.sockio.remotecall(self.oid, "__attributes__", (), {})
        self.__attributes = table

    def __getmethods(self):
        "Fetch and cache the peer's table of callable attribute names"
        table = self.sockio.remotecall(self.oid, "__methods__", (), {})
        self.__methods = table
564
def _getmethods(obj, methods):
    """Record the names of the callable attributes of *obj* in *methods*.

    Each name becomes a key with the value 1.  For a classic instance the
    class is scanned as well, and for a classic class so are its bases.
    """
    for name in dir(obj):
        if hasattr(getattr(obj, name), '__call__'):
            methods[name] = 1
    kind = type(obj)
    if kind == types.InstanceType:
        _getmethods(obj.__class__, methods)
    if kind == types.ClassType:
        for base in obj.__bases__:
            _getmethods(base, methods)
577
def _getattributes(obj, attributes):
    """Record the names of the non-callable attributes of *obj*.

    Each such name from dir(obj) becomes a key of *attributes* with the
    value 1.
    """
    attributes.update(
        (name, 1)
        for name in dir(obj)
        if not hasattr(getattr(obj, name), '__call__')
    )
583
class MethodProxy(object):
    """Callable that forwards invocations of one remote method over the link."""

    def __init__(self, sockio, oid, name):
        "Bind the link, the remote object id and the method name"
        self.sockio = sockio
        self.oid = oid
        self.name = name

    def __call__(self, *args, **kwargs):
        "Invoke the remote method with the given arguments; return its result"
        return self.sockio.remotecall(self.oid, self.name, args, kwargs)
594
595
596# XXX KBK 09Sep03  We need a proper unit test for this module.  Previously
597#                  existing test code was removed at Rev 1.27 (r34098).
598