diff --git a/src/wormhole/servers/cmd_server.py b/src/wormhole/servers/cmd_server.py index 3f4c0cb..0110cb1 100644 --- a/src/wormhole/servers/cmd_server.py +++ b/src/wormhole/servers/cmd_server.py @@ -8,7 +8,7 @@ class MyPlugin: def makeService(self, so): # delay this import as late as possible, to allow twistd's code to # accept --reactor= selection - from .relay import RelayServer + from .server import RelayServer return RelayServer(self.args.rendezvous, self.args.transit, self.args.advertise_version, "relay.sqlite") diff --git a/src/wormhole/servers/relay.py b/src/wormhole/servers/relay.py index 00a6605..d4aa598 100644 --- a/src/wormhole/servers/relay.py +++ b/src/wormhole/servers/relay.py @@ -1,12 +1,7 @@ from __future__ import print_function import re, json, time, random from twisted.python import log -from twisted.internet import reactor, protocol, endpoints -from twisted.application import service, internet -from twisted.web import server, static, resource, http -from ..util.endpoint_service import ServerEndpointService -from .. import __version__ -from ..database import get_db +from twisted.web import server, resource, http SECONDS = 1.0 MINUTE = 60*SECONDS @@ -15,7 +10,6 @@ DAY = 24*HOUR MB = 1000*1000 CHANNEL_EXPIRATION_TIME = 3*DAY -EXPIRATION_CHECK_PERIOD = 2*HOUR class EventsProtocol: def __init__(self, request): @@ -255,161 +249,3 @@ class Relay(resource.Resource): log.msg("expiring %d" % channel_id) self.free_child(channel_id) -class TransitConnection(protocol.Protocol): - def __init__(self): - self.got_token = False - self.token_buffer = b"" - self.sent_ok = False - self.buddy = None - self.total_sent = 0 - - def dataReceived(self, data): - if self.sent_ok: - # TODO: connect as producer/consumer - self.total_sent += len(data) - self.buddy.transport.write(data) - return - if self.got_token: # but not yet sent_ok - self.transport.write("impatient\n") - print("transit impatience failure") - return self.disconnect() # impatience yields failure - # else this should be (part of) the token - self.token_buffer += data - buf = self.token_buffer - wanted = len("please relay \n")+32*2 - if len(buf) < wanted-1 and "\n" in buf: - self.transport.write("bad handshake\n") - print("transit handshake early failure") - return self.disconnect() - if len(buf) < wanted: - return - if len(buf) > wanted: - self.transport.write("impatient\n") - print("transit impatience failure") - return self.disconnect() # impatience yields failure - mo = re.search(r"^please relay (\w{64})\n", buf, re.M) - if not mo: - self.transport.write("bad handshake\n") - print("transit handshake failure") - return self.disconnect() # incorrectness yields failure - token = mo.group(1) - - self.got_token = True - self.factory.connection_got_token(token, self) - - def buddy_connected(self, them): - self.buddy = them - self.transport.write(b"ok\n") - self.sent_ok = True - # TODO: connect as producer/consumer - - def buddy_disconnected(self): - print("buddy_disconnected %r" % self) - self.buddy = None - self.transport.loseConnection() - - def connectionLost(self, reason): - print("connectionLost %r %s" % (self, reason)) - if self.buddy: - self.buddy.buddy_disconnected() - self.factory.transitFinished(self, self.total_sent) - - def disconnect(self): - self.transport.loseConnection() - self.factory.transitFailed(self) - -class Transit(protocol.ServerFactory, service.MultiService): - # I manage pairs of simultaneous connections to a secondary TCP port, - # both forwarded to the other. Clients must begin each connection with - # "please relay TOKEN\n". I will send "ok\n" when the matching connection - # is established, or disconnect if no matching connection is made within - # MAX_WAIT_TIME seconds. I will disconnect if you send data before the - # "ok\n". All data you get after the "ok\n" will be from the other side. - # You will not receive "ok\n" until the other side has also connected and - # submitted a matching token. The token is the same for each side. - - # In addition, the connections will be dropped after MAXLENGTH bytes have - # been sent by either side, or MAXTIME seconds have elapsed after the - # matching connections were established. A future API will reveal these - # limits to clients instead of causing mysterious spontaneous failures. - - # These relay connections are not half-closeable (unlike full TCP - # connections, applications will not receive any data after half-closing - # their outgoing side). Applications must negotiate shutdown with their - # peer and not close the connection until all data has finished - # transferring in both directions. Applications which only need to send - # data in one direction can use close() as usual. - - MAX_WAIT_TIME = 30*SECONDS - MAXLENGTH = 10*MB - MAXTIME = 60*SECONDS - protocol = TransitConnection - - def __init__(self): - service.MultiService.__init__(self) - self.pending_requests = {} # token -> TransitConnection - self.active_connections = set() # TransitConnection - - def connection_got_token(self, token, p): - if token in self.pending_requests: - print("transit relay 2: %r" % token) - buddy = self.pending_requests.pop(token) - self.active_connections.add(p) - self.active_connections.add(buddy) - p.buddy_connected(buddy) - buddy.buddy_connected(p) - else: - self.pending_requests[token] = p - print("transit relay 1: %r" % token) - # TODO: timer - def transitFinished(self, p, total_sent): - print("transitFinished (%dB) %r" % (total_sent, p)) - for token,tc in self.pending_requests.items(): - if tc is p: - del self.pending_requests[token] - break - self.active_connections.discard(p) - - def transitFailed(self, p): - print("transitFailed %r" % p) - pass - - -class Root(resource.Resource): - # child_FOO is a nevow thing, not a twisted.web.resource thing - def __init__(self): - resource.Resource.__init__(self) - self.putChild(b"", static.Data(b"Wormhole Relay\n", "text/plain")) - -class RelayServer(service.MultiService): - def __init__(self, relayport, transitport, advertise_version, - db_url=":memory:"): - service.MultiService.__init__(self) - self.db = get_db(db_url) - welcome = { - "current_version": __version__, - # adding .motd will cause all clients to display the message, - # then keep running normally - #"motd": "Welcome to the public relay.\nPlease enjoy this service.", - # - # adding .error will cause all clients to fail, with this message - #"error": "This server has been disabled, see URL for details.", - } - if advertise_version: - welcome["current_version"] = advertise_version - self.root = Root() - site = server.Site(self.root) - r = endpoints.serverFromString(reactor, relayport) - self.relayport_service = ServerEndpointService(r, site) - self.relayport_service.setServiceParent(self) - self.relay = Relay(self.db, welcome) # accessible from tests - self.root.putChild(b"wormhole-relay", self.relay) - t = internet.TimerService(EXPIRATION_CHECK_PERIOD, - self.relay.prune_old_channels) - t.setServiceParent(self) - if transitport: - self.transit = Transit() - self.transit.setServiceParent(self) # for the timer - t = endpoints.serverFromString(reactor, transitport) - self.transport_service = ServerEndpointService(t, self.transit) - self.transport_service.setServiceParent(self) diff --git a/src/wormhole/servers/server.py b/src/wormhole/servers/server.py new file mode 100644 index 0000000..c6e660e --- /dev/null +++ b/src/wormhole/servers/server.py @@ -0,0 +1,53 @@ +from __future__ import print_function +from twisted.internet import reactor, endpoints +from twisted.application import service, internet +from twisted.web import server, static, resource +from ..util.endpoint_service import ServerEndpointService +from .. import __version__ +from ..database import get_db +from .relay import Relay +from .transit_server import Transit + +SECONDS = 1.0 +MINUTE = 60*SECONDS +HOUR = 60*MINUTE +EXPIRATION_CHECK_PERIOD = 2*HOUR + +class Root(resource.Resource): + # child_FOO is a nevow thing, not a twisted.web.resource thing + def __init__(self): + resource.Resource.__init__(self) + self.putChild(b"", static.Data(b"Wormhole Relay\n", "text/plain")) + +class RelayServer(service.MultiService): + def __init__(self, relayport, transitport, advertise_version, + db_url=":memory:"): + service.MultiService.__init__(self) + self.db = get_db(db_url) + welcome = { + "current_version": __version__, + # adding .motd will cause all clients to display the message, + # then keep running normally + #"motd": "Welcome to the public relay.\nPlease enjoy this service.", + # + # adding .error will cause all clients to fail, with this message + #"error": "This server has been disabled, see URL for details.", + } + if advertise_version: + welcome["current_version"] = advertise_version + self.root = Root() + site = server.Site(self.root) + r = endpoints.serverFromString(reactor, relayport) + self.relayport_service = ServerEndpointService(r, site) + self.relayport_service.setServiceParent(self) + self.relay = Relay(self.db, welcome) # accessible from tests + self.root.putChild(b"wormhole-relay", self.relay) + t = internet.TimerService(EXPIRATION_CHECK_PERIOD, + self.relay.prune_old_channels) + t.setServiceParent(self) + if transitport: + self.transit = Transit() + self.transit.setServiceParent(self) # for the timer + t = endpoints.serverFromString(reactor, transitport) + self.transport_service = ServerEndpointService(t, self.transit) + self.transport_service.setServiceParent(self) diff --git a/src/wormhole/servers/transit_server.py b/src/wormhole/servers/transit_server.py new file mode 100644 index 0000000..6600752 --- /dev/null +++ b/src/wormhole/servers/transit_server.py @@ -0,0 +1,129 @@ +from __future__ import print_function +import re +from twisted.internet import protocol +from twisted.application import service + +SECONDS = 1.0 +MINUTE = 60*SECONDS +HOUR = 60*MINUTE +DAY = 24*HOUR +MB = 1000*1000 + +class TransitConnection(protocol.Protocol): + def __init__(self): + self.got_token = False + self.token_buffer = b"" + self.sent_ok = False + self.buddy = None + self.total_sent = 0 + + def dataReceived(self, data): + if self.sent_ok: + # TODO: connect as producer/consumer + self.total_sent += len(data) + self.buddy.transport.write(data) + return + if self.got_token: # but not yet sent_ok + self.transport.write("impatient\n") + print("transit impatience failure") + return self.disconnect() # impatience yields failure + # else this should be (part of) the token + self.token_buffer += data + buf = self.token_buffer + wanted = len("please relay \n")+32*2 + if len(buf) < wanted-1 and "\n" in buf: + self.transport.write("bad handshake\n") + print("transit handshake early failure") + return self.disconnect() + if len(buf) < wanted: + return + if len(buf) > wanted: + self.transport.write("impatient\n") + print("transit impatience failure") + return self.disconnect() # impatience yields failure + mo = re.search(r"^please relay (\w{64})\n", buf, re.M) + if not mo: + self.transport.write("bad handshake\n") + print("transit handshake failure") + return self.disconnect() # incorrectness yields failure + token = mo.group(1) + + self.got_token = True + self.factory.connection_got_token(token, self) + + def buddy_connected(self, them): + self.buddy = them + self.transport.write(b"ok\n") + self.sent_ok = True + # TODO: connect as producer/consumer + + def buddy_disconnected(self): + print("buddy_disconnected %r" % self) + self.buddy = None + self.transport.loseConnection() + + def connectionLost(self, reason): + print("connectionLost %r %s" % (self, reason)) + if self.buddy: + self.buddy.buddy_disconnected() + self.factory.transitFinished(self, self.total_sent) + + def disconnect(self): + self.transport.loseConnection() + self.factory.transitFailed(self) + +class Transit(protocol.ServerFactory, service.MultiService): + # I manage pairs of simultaneous connections to a secondary TCP port, + # both forwarded to the other. Clients must begin each connection with + # "please relay TOKEN\n". I will send "ok\n" when the matching connection + # is established, or disconnect if no matching connection is made within + # MAX_WAIT_TIME seconds. I will disconnect if you send data before the + # "ok\n". All data you get after the "ok\n" will be from the other side. + # You will not receive "ok\n" until the other side has also connected and + # submitted a matching token. The token is the same for each side. + + # In addition, the connections will be dropped after MAXLENGTH bytes have + # been sent by either side, or MAXTIME seconds have elapsed after the + # matching connections were established. A future API will reveal these + # limits to clients instead of causing mysterious spontaneous failures. + + # These relay connections are not half-closeable (unlike full TCP + # connections, applications will not receive any data after half-closing + # their outgoing side). Applications must negotiate shutdown with their + # peer and not close the connection until all data has finished + # transferring in both directions. Applications which only need to send + # data in one direction can use close() as usual. + + MAX_WAIT_TIME = 30*SECONDS + MAXLENGTH = 10*MB + MAXTIME = 60*SECONDS + protocol = TransitConnection + + def __init__(self): + service.MultiService.__init__(self) + self.pending_requests = {} # token -> TransitConnection + self.active_connections = set() # TransitConnection + + def connection_got_token(self, token, p): + if token in self.pending_requests: + print("transit relay 2: %r" % token) + buddy = self.pending_requests.pop(token) + self.active_connections.add(p) + self.active_connections.add(buddy) + p.buddy_connected(buddy) + buddy.buddy_connected(p) + else: + self.pending_requests[token] = p + print("transit relay 1: %r" % token) + # TODO: timer + def transitFinished(self, p, total_sent): + print("transitFinished (%dB) %r" % (total_sent, p)) + for token,tc in self.pending_requests.items(): + if tc is p: + del self.pending_requests[token] + break + self.active_connections.discard(p) + + def transitFailed(self, p): + print("transitFailed %r" % p) + pass diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 98b2987..ef5a924 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -1,6 +1,6 @@ from twisted.application import service from ..twisted.util import allocate_ports -from ..servers.relay import RelayServer +from ..servers.server import RelayServer from .. import __version__ class ServerBase: