rearrange transit.py in preparation for refactoring

This commit is contained in:
Brian Warner 2015-02-19 18:19:17 -08:00
parent 12845f191b
commit 66ad6fb272

View File

@ -55,6 +55,98 @@ TIMEOUT=10000
# 4: add relay
# 5: accelerate shutdown of losing sockets
class BadHandshake(Exception):
pass
def connector(owner, hint, send_handshake, expected_handshake):
if isinstance(hint, type(u"")):
hint = hint.encode("ascii")
addr,port = hint.split(",")
skt = None
try:
skt = socket.create_connection((addr,port)) # timeout here
skt.settimeout(TIMEOUT)
#print("socket(%s) connected" % (hint,))
skt.send(send_handshake)
got = b""
while len(got) < len(expected_handshake):
got += skt.recv(1)
if expected_handshake[:len(got)] != got:
raise BadHandshake("got '%r' want '%r' on %s" %
(got, expected_handshake, hint))
#print("connector ready %r" % (hint,))
except Exception as e:
try:
if skt:
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
if skt:
skt.close()
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
return
# owner is now responsible for the socket
owner._negotiation_finished(skt) # note thread
def handle(skt, client_address, owner, send_handshake, expected_handshake):
try:
#print("handle %r" % (skt,))
skt.settimeout(TIMEOUT)
skt.send(send_handshake)
got = b""
# for the receiver, this includes the "go\n"
while len(got) < len(expected_handshake):
more = skt.recv(1)
if not more:
raise BadHandshake("disconnect after merely '%r'" % got)
got += more
if expected_handshake[:len(got)] != got:
raise BadHandshake("got '%r' want '%r'" %
(got, expected_handshake))
#print("handler negotiation finished %r" % (client_address,))
except Exception as e:
#print("handler failed %r" % (client_address,))
try:
# this raises socket.err(EBADF) if the socket was already closed
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
skt.close() # this appears to be idempotent
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
return
# owner is now responsible for the socket
owner._negotiation_finished(skt) # note thread
class MyTCPServer(SocketServer.TCPServer):
allow_reuse_address = True
def process_request(self, request, client_address):
kc = self.owner._have_transit_key
kc.acquire()
while not self.owner._transit_key:
kc.wait()
# owner._transit_key is either None or set to a value. We don't
# modify it from here, so we can release the condition lock before
# grabbing the key.
kc.release()
# Once it is set, we can get handler_(send|receive)_handshake, which
# is what we actually care about.
t = threading.Thread(target=handle,
args=(request, client_address,
self.owner,
self.owner.handler_send_handshake,
self.owner.handler_expected_handshake))
t.daemon = True
t.start()
class TransitSender:
def __init__(self):
self.key = os.urandom(32)
@ -113,96 +205,7 @@ class TransitSender:
skt.send("nevermind\n")
skt.close()
class BadHandshake(Exception):
pass
def connector(owner, hint, send_handshake, expected_handshake):
if isinstance(hint, type(u"")):
hint = hint.encode("ascii")
addr,port = hint.split(",")
skt = None
try:
skt = socket.create_connection((addr,port)) # timeout here
skt.settimeout(TIMEOUT)
#print("socket(%s) connected" % (hint,))
skt.send(send_handshake)
got = b""
while len(got) < len(expected_handshake):
got += skt.recv(1)
if expected_handshake[:len(got)] != got:
raise BadHandshake("got '%r' want '%r' on %s" %
(got, expected_handshake, hint))
#print("connector ready %r" % (hint,))
except Exception as e:
try:
if skt:
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
if skt:
skt.close()
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
return
# owner is now responsible for the socket
owner._negotiation_finished(skt) # note thread
def handle(skt, client_address, owner, send_handshake, expected_handshake):
try:
#print("handle %r" % (skt,))
skt.settimeout(TIMEOUT)
skt.send(send_handshake)
got = b""
# for the receiver, this includes the "go\n"
while len(got) < len(expected_handshake):
more = skt.recv(1)
if not more:
raise BadHandshake("disconnect after merely '%r'" % got)
got += more
if expected_handshake[:len(got)] != got:
raise BadHandshake("got '%r' want '%r'" %
(got, expected_handshake))
#print("handler negotiation finished %r" % (client_address,))
except Exception as e:
#print("handler failed %r" % (client_address,))
try:
# this raises socket.err(EBADF) if the socket was already closed
skt.shutdown(socket.SHUT_WR)
except socket.error:
pass
skt.close() # this appears to be idempotent
# ignore socket errors, warn about coding errors
if not isinstance(e, (socket.error, socket.timeout, BadHandshake)):
raise
return
# owner is now responsible for the socket
owner._negotiation_finished(skt) # note thread
class MyTCPServer(SocketServer.TCPServer):
allow_reuse_address = True
def process_request(self, request, client_address):
kc = self.owner._have_transit_key
kc.acquire()
while not self.owner._transit_key:
kc.wait()
# owner._transit_key is either None or set to a value. We don't
# modify it from here, so we can release the condition lock before
# grabbing the key.
kc.release()
# Once it is set, we can get handler_(send|receive)_handshake, which
# is what we actually care about.
t = threading.Thread(target=handle,
args=(request, client_address,
self.owner,
self.owner.handler_send_handshake,
self.owner.handler_expected_handshake))
t.daemon = True
t.start()
class TransitReceiver:
def __init__(self):