diff --git a/src/wormhole/blocking/transit.py b/src/wormhole/blocking/transit.py index 3b39153..22f3343 100644 --- a/src/wormhole/blocking/transit.py +++ b/src/wormhole/blocking/transit.py @@ -55,6 +55,98 @@ TIMEOUT=10000 # 4: add relay # 5: accelerate shutdown of losing sockets + +class BadHandshake(Exception): + pass + +def connector(owner, hint, send_handshake, expected_handshake): + if isinstance(hint, type(u"")): + hint = hint.encode("ascii") + addr,port = hint.split(",") + skt = None + try: + skt = socket.create_connection((addr,port)) # timeout here + skt.settimeout(TIMEOUT) + #print("socket(%s) connected" % (hint,)) + skt.send(send_handshake) + got = b"" + while len(got) < len(expected_handshake): + got += skt.recv(1) + if expected_handshake[:len(got)] != got: + raise BadHandshake("got '%r' want '%r' on %s" % + (got, expected_handshake, hint)) + #print("connector ready %r" % (hint,)) + except Exception as e: + try: + if skt: + skt.shutdown(socket.SHUT_WR) + except socket.error: + pass + if skt: + skt.close() + # ignore socket errors, warn about coding errors + if not isinstance(e, (socket.error, socket.timeout, BadHandshake)): + raise + return + # owner is now responsible for the socket + owner._negotiation_finished(skt) # note thread + +def handle(skt, client_address, owner, send_handshake, expected_handshake): + try: + #print("handle %r" % (skt,)) + skt.settimeout(TIMEOUT) + skt.send(send_handshake) + got = b"" + # for the receiver, this includes the "go\n" + while len(got) < len(expected_handshake): + more = skt.recv(1) + if not more: + raise BadHandshake("disconnect after merely '%r'" % got) + got += more + if expected_handshake[:len(got)] != got: + raise BadHandshake("got '%r' want '%r'" % + (got, expected_handshake)) + #print("handler negotiation finished %r" % (client_address,)) + except Exception as e: + #print("handler failed %r" % (client_address,)) + try: + # this raises socket.err(EBADF) if the socket was already closed + skt.shutdown(socket.SHUT_WR) + except socket.error: + pass + skt.close() # this appears to be idempotent + # ignore socket errors, warn about coding errors + if not isinstance(e, (socket.error, socket.timeout, BadHandshake)): + raise + return + # owner is now responsible for the socket + owner._negotiation_finished(skt) # note thread + +class MyTCPServer(SocketServer.TCPServer): + allow_reuse_address = True + + def process_request(self, request, client_address): + kc = self.owner._have_transit_key + kc.acquire() + while not self.owner._transit_key: + kc.wait() + # owner._transit_key is either None or set to a value. We don't + # modify it from here, so we can release the condition lock before + # grabbing the key. + kc.release() + + # Once it is set, we can get handler_(send|receive)_handshake, which + # is what we actually care about. + t = threading.Thread(target=handle, + args=(request, client_address, + self.owner, + self.owner.handler_send_handshake, + self.owner.handler_expected_handshake)) + t.daemon = True + t.start() + + + class TransitSender: def __init__(self): self.key = os.urandom(32) @@ -113,96 +205,7 @@ class TransitSender: skt.send("nevermind\n") skt.close() -class BadHandshake(Exception): - pass -def connector(owner, hint, send_handshake, expected_handshake): - if isinstance(hint, type(u"")): - hint = hint.encode("ascii") - addr,port = hint.split(",") - skt = None - try: - skt = socket.create_connection((addr,port)) # timeout here - skt.settimeout(TIMEOUT) - #print("socket(%s) connected" % (hint,)) - skt.send(send_handshake) - got = b"" - while len(got) < len(expected_handshake): - got += skt.recv(1) - if expected_handshake[:len(got)] != got: - raise BadHandshake("got '%r' want '%r' on %s" % - (got, expected_handshake, hint)) - #print("connector ready %r" % (hint,)) - except Exception as e: - try: - if skt: - skt.shutdown(socket.SHUT_WR) - except socket.error: - pass - if skt: - skt.close() - # ignore socket errors, warn about coding errors - if not isinstance(e, (socket.error, socket.timeout, BadHandshake)): - raise - return - # owner is now responsible for the socket - owner._negotiation_finished(skt) # note thread - - - -def handle(skt, client_address, owner, send_handshake, expected_handshake): - try: - #print("handle %r" % (skt,)) - skt.settimeout(TIMEOUT) - skt.send(send_handshake) - got = b"" - # for the receiver, this includes the "go\n" - while len(got) < len(expected_handshake): - more = skt.recv(1) - if not more: - raise BadHandshake("disconnect after merely '%r'" % got) - got += more - if expected_handshake[:len(got)] != got: - raise BadHandshake("got '%r' want '%r'" % - (got, expected_handshake)) - #print("handler negotiation finished %r" % (client_address,)) - except Exception as e: - #print("handler failed %r" % (client_address,)) - try: - # this raises socket.err(EBADF) if the socket was already closed - skt.shutdown(socket.SHUT_WR) - except socket.error: - pass - skt.close() # this appears to be idempotent - # ignore socket errors, warn about coding errors - if not isinstance(e, (socket.error, socket.timeout, BadHandshake)): - raise - return - # owner is now responsible for the socket - owner._negotiation_finished(skt) # note thread - -class MyTCPServer(SocketServer.TCPServer): - allow_reuse_address = True - - def process_request(self, request, client_address): - kc = self.owner._have_transit_key - kc.acquire() - while not self.owner._transit_key: - kc.wait() - # owner._transit_key is either None or set to a value. We don't - # modify it from here, so we can release the condition lock before - # grabbing the key. - kc.release() - - # Once it is set, we can get handler_(send|receive)_handshake, which - # is what we actually care about. - t = threading.Thread(target=handle, - args=(request, client_address, - self.owner, - self.owner.handler_send_handshake, - self.owner.handler_expected_handshake)) - t.daemon = True - t.start() class TransitReceiver: def __init__(self):