git » pymisc » master » tree

[master] / pickle_rpc.py

"""
pickle_rpc - RPC using pickle for serialization
Alberto Bertogli (albertito@blitiri.com.ar)
-----------------------------------------------

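This module implements a simple RPC mechanism that uses pickle for
serialization. Its interface is similar to (but not exactly like) XML-RPC's.

Since both ends unpickle whatever they receive, a peer can make the other side
run arbitrary code: only use it between hosts that trust each other.

It should work on Unix, Windows and Mac systems.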

To create an RPC server:

	srv = pickle_rpc.Server('localhost')
	srv.register(my_function)
	srv.loop()

To create an RPC client (we call it "Remote" because the object represents a
"remote execution server"):

	r = pickle_rpc.Remote('localhost')
	print r.my_function(1, a = 'b')

"""

import sys
import os
import traceback
import socket
import select
import errno

try:
	import cPickle as pickle
except:
	import pickle


# Default TCP port used by Server and Remote when no port is given.
default_port = 1642


# valid replies
class replies:
	"Reply codes; the server puts one first in every reply tuple."
	SUCCESS, UNKNOWN, EXCEPTION = range(3)


class SockFD (object):
	"""File-like wrapper for a connected socket.

	Implements .write(), .read() and .readline(), which is what pickle
	needs to work, plus .fileno() so instances can be passed straight
	to select()."""

	def __init__(self, sock):
		self.sock = sock
		# sock.makefile() instead of os.fdopen(sock.fileno()): on
		# Windows a socket's fileno() is not a C file descriptor and
		# cannot be given to os.fdopen(), and binary mode ('rb') keeps
		# pickle data safe from any newline translation.
		self.fd = sock.makefile('rb')

	def write(self, s):
		# send() may transmit only part of s; sendall() keeps sending
		# until everything is out (or raises socket.error).
		self.sock.sendall(s)

	def read(self, size = -1):
		# a negative size reads until EOF
		return self.fd.read(size)

	def readline(self, size = -1):
		return self.fd.readline(size)

	def fileno(self):
		return self.sock.fileno()

	def close(self):
		"""Close the file object and the socket. Calling it again is a
		no-op."""
		fd, sock = self.fd, self.sock
		self.fd = None
		self.sock = None
		try:
			if fd is not None:
				fd.close()
		finally:
			# closing the makefile() object does not close the
			# socket itself, so do it explicitly
			if sock is not None:
				sock.close()


class Server (object):
	"""Pickle-RPC server.

	Exposes plain functions (register()) and the methods of named
	objects (register_object()) to Remote clients, serving every
	connection from a single select() loop.

	WARNING: requests are decoded with pickle, which can execute
	arbitrary code while loading; only let trusted clients connect."""

	def __init__(self, ip, port = default_port):
		"""Create the server socket and bind it to (ip, port).

		Call loop() to start accepting connections."""
		self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		try:
			# SO_REUSEADDR must be set before bind() to take effect
			self.sock.setsockopt(socket.SOL_SOCKET,
					socket.SO_REUSEADDR, 1)
			self.sock.bind((ip, port))
		except Exception:
			# don't leak the socket if binding fails
			self.sock.close()
			raise
		# connected clients, as SockFD objects
		self.fds = []
		# name -> callable, filled by register()
		self.functions = {}
		# name -> object, filled by register_object()
		self.objects = {}

	def register(self, func, name = None):
		"""Expose func to clients under name (default: func.__name__)."""
		if not name:
			name = func.__name__
		self.functions[name] = func

	def register_object(self, name, obj):
		"""Expose obj so clients can call "name.method" (or
		"name.attr.method") on it."""
		self.objects[name] = obj

	def find_method(self, chain):
		"""Resolve a dotted name like "obj.attr.method" against the
		registered objects.

		Returns the last attribute of the chain, or None if the first
		name is not registered or an intermediate attribute is missing
		(or None)."""
		chain = chain.split('.')
		chain, fname = chain[:-1], chain[-1]
		if not chain:
			# no object part at all
			return None

		# the first object is a special case because it comes from
		# self.objects
		if chain[0] not in self.objects:
			return None
		obj = self.objects[chain[0]]

		# walk through the object names
		for oname in chain[1:]:
			obj = getattr(obj, oname, None)
			if obj is None:
				return None

		# finally, return the function (if there is one)
		return getattr(obj, fname, None)

	def loop(self):
		"""Listen and serve clients forever."""
		self.sock.listen(15)
		while True:
			watched = [self.sock] + self.fds
			try:
				readable = select.select(watched, [], [])[0]
			except select.error:
				# older Pythons do not retry select() when a
				# signal interrupts it; just try again
				err = sys.exc_info()[1]
				if err.args and err.args[0] == errno.EINTR:
					continue
				raise

			for fd in readable:
				if fd is self.sock:
					self.new_connection()
				else:
					self.recv(fd)

	def new_connection(self):
		"""Accept a pending connection and start watching it."""
		try:
			conn, addr = self.sock.accept()
		except socket.error:
			# the client may be gone already; that must not stop
			# the whole server
			return
		self.fds.append(SockFD(conn))

	def end_connection(self, fd):
		"""Stop watching fd and close it, ignoring close errors."""
		if fd in self.fds:
			self.fds.remove(fd)
			try:
				fd.close()
			except (IOError, OSError, socket.error):
				pass

	def _resolve(self, funcname):
		"Return the callable registered under funcname, or None."
		if '.' in funcname:
			# it's a method from (hopefully) one of our registered
			# objects
			return self.find_method(funcname)
		return self.functions.get(funcname, None)

	def recv(self, fd):
		"""Handle one request arriving on fd.

		A request is a (funcname, args, kwargs) tuple, where funcname is
		a registered function name or a dotted "object.method" path.
		The reply goes back on the same connection. Read errors and
		malformed requests close the connection."""
		try:
			req = pickle.load(fd)
		except (socket.error, EOFError):
			# the peer went away
			self.end_connection(fd)
			return
		except Exception:
			print('Error unpickling')
			traceback.print_exc()
			self.end_connection(fd)
			return

		# Check the shape before indexing, so a short or non-tuple
		# request drops this connection instead of raising out of
		# loop() and stopping the server.
		if not isinstance(req, tuple) or len(req) < 3:
			self.end_connection(fd)
			return
		funcname, params, kwparams = req[:3]

		try:
			func = self._resolve(funcname)
			if func is None:
				reply = (replies.UNKNOWN,)
			else:
				reply = (replies.SUCCESS,
						func(*params, **kwparams))
		except:
			# Deliberately broad: anything raised while looking up
			# or running the function (including a bad name type)
			# is shipped back to the client rather than stopping
			# the server.
			e, i = sys.exc_info()[:2]
			reply = (replies.EXCEPTION, e, i,
					traceback.format_exc())

		self.send(fd, reply)

	def send(self, fd, r):
		"""Pickle reply r and write it to fd.

		The reply is pickled completely before anything is written, so
		a failure never leaves a truncated pickle on the wire. If r
		cannot be pickled, the client gets an EXCEPTION reply
		describing the pickling error instead; if even that cannot be
		pickled, or writing fails, the connection is closed."""
		try:
			data = pickle.dumps(r, -1)
		except Exception:
			print('Error pickling')
			traceback.print_exc()
			e, i = sys.exc_info()[:2]
			strtb = traceback.format_exc()
			try:
				data = pickle.dumps(
					(replies.EXCEPTION, e, i, strtb), -1)
			except Exception:
				self.end_connection(fd)
				return

		try:
			fd.write(data)
		except (socket.error, IOError, OSError, EOFError):
			self.end_connection(fd)


class UnknownFunction (Exception):
	"""Raised by Remote when the server replies that it has no
	function or method by the requested name."""


class _ImObject (object):
	"Intermediate object returned by Remote.__getattr__()"
	def __init__(self, remote, name):
		self.__remote = remote
		self.__name = name

	def __call__(self, *args, **kwargs):
		return self.__remote._rpc(self.__name, args, kwargs)

	def __getattr__(self, name):
		return _ImObject(self.__remote, self.__name + '.' + name)


class Remote (object):
	"Pickle-RPC client"
	def __init__(self, ip, port = default_port):
		self.__sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
		self.__sock.connect((ip, port))
		self.__fd = SockFD(self.__sock)
		self.__last_tb_str = ''

	def _last_tb(self):
		"""Returns a string showing the server's last traceback,
		useful for debugging."""
		return self.__last_tb_str

	def __getattr__(self, name):
		return _ImObject(self, name)

	def _rpc(self, name, args, kwargs):
		msg = (name, args, kwargs)
		pickle.dump(msg, self.__fd, -1)
		rep = pickle.load(self.__fd)
		if rep[0] == replies.SUCCESS:
			return rep[1]
		elif rep[0] == replies.EXCEPTION:
			e, i, strtb = rep[1:]
			self.__last_tb_str = strtb
			raise e, i
		elif rep[0] == replies.UNKNOWN:
			raise UnknownFunction, name