From 18f7ab93086015a2759d67fb740ea3a24600bbdd Mon Sep 17 00:00:00 2001 From: Brian Warner Date: Sun, 18 Dec 2016 21:20:26 -0800 Subject: [PATCH] more state-machine work --- src/wormhole/_connection.py | 413 +++++++++++++++++++++++++----------- 1 file changed, 295 insertions(+), 118 deletions(-) diff --git a/src/wormhole/_connection.py b/src/wormhole/_connection.py index c3c73e6..e50b5a8 100644 --- a/src/wormhole/_connection.py +++ b/src/wormhole/_connection.py @@ -1,7 +1,6 @@ from six.moves.urllib_parse import urlparse from attr import attrs, attrib -from twisted.internet import protocol, reactor from twisted.internet import defer, endpoints #, error from autobahn.twisted import websocket from automat import MethodicalMachine @@ -45,30 +44,26 @@ class WSFactory(websocket.WebSocketClientFactory): return proto -class Dummy(protocol.Protocol): - def connectionMade(self): - print("connectionMade") - reactor.callLater(1.0, self.factory.cm.onConnect, "fake ws") - reactor.callLater(2.0, self.transport.loseConnection) - def connectionLost(self, why): - self.factory.cm.onClose(why) - # pip install (path to automat checkout)[visualize] # automat-visualize wormhole._connection -class WebSocketMachine(object): +# We have one WSRelayClient for each wsurl we know about, and it lasts +# as long as its parent Wormhole does. + +@attrs +class WSRelayClient(object): + _wormhole = attrib() + _ws_url = attrib() + _reactor = attrib() + m = MethodicalMachine() ALLOW_CLOSE = True - def __init__(self, ws_url, reactor): - self._reactor = reactor - self._f = f = WSFactory(ws_url) + def __init__(self): + self._f = f = WSFactory(self._ws_url) f.setProtocolOptions(autoPingInterval=60, autoPingTimeout=600) f.connection_machine = self # calls onOpen and onClose - #self._f = protocol.ClientFactory() - #self._f.cm = self - #self._f.protocol = Dummy - p = urlparse(ws_url) + p = urlparse(self._ws_url) self._ep = self._make_endpoint(p.hostname, p.port or 80) self._connector = None self._done_d = defer.Deferred() @@ -105,16 +100,16 @@ class WebSocketMachine(object): @m.input() def d_errback(self, f): pass ; print("in d_errback", f) @m.input() - def d_cancel(self): pass + def d_cancel(self, f): pass # XXX remove f @m.input() def onOpen(self, ws): pass ; print("in onOpen") @m.input() - def onClose(self, f): pass + def onClose(self, f): pass # XXX maybe remove f @m.input() def expire(self): pass if ALLOW_CLOSE: @m.input() - def close(self): pass + def close(self, f): pass @m.output() def ep_connect(self): @@ -123,20 +118,26 @@ class WebSocketMachine(object): self._d = self._ep.connect(self._f) self._d.addCallbacks(self.d_callback, self.d_errback) @m.output() - def handle_connection(self, ws): - print("handle_connection", ws) - #self._wormhole.new_connection(Connection(ws)) + def add_connection(self, ws): + print("add_connection", ws) + self._connection = WSConnection(ws, self._wormhole.appid, + self._wormhole.side, self) + self._wormhole.add_connection(self._connection) @m.output() - def start_timer(self, f): + def remove_connection(self, f): # XXX remove f + self._wormhole.remove_connection(self._connection) + self._connection = None + @m.output() + def start_timer(self, f): # XXX remove f print("start_timer") self._t = self._reactor.callLater(3.0, self.expire) @m.output() - def cancel_timer(self): + def cancel_timer(self, f): # XXX remove f print("cancel_timer") self._t.cancel() self._t = None @m.output() - def dropConnection(self): + def dropConnection(self, f): # XXX remove f print("dropConnection") self._ws.dropConnection() @m.output() @@ -149,17 +150,19 @@ class WebSocketMachine(object): first_time_connecting.upon(d_errback, enter=failed, outputs=[notify_fail]) first_time_connecting.upon(onClose, enter=failed, outputs=[notify_fail]) if ALLOW_CLOSE: - first_time_connecting.upon(close, enter=disconnecting2, outputs=[d_cancel]) + first_time_connecting.upon(close, enter=disconnecting2, + outputs=[d_cancel]) disconnecting2.upon(d_errback, enter=closed, outputs=[]) - negotiating.upon(onOpen, enter=open, outputs=[handle_connection]) + negotiating.upon(onOpen, enter=open, outputs=[add_connection]) if ALLOW_CLOSE: negotiating.upon(close, enter=disconnecting, outputs=[dropConnection]) negotiating.upon(onClose, enter=failed, outputs=[notify_fail]) - open.upon(onClose, enter=waiting, outputs=[start_timer]) + open.upon(onClose, enter=waiting, outputs=[remove_connection, start_timer]) if ALLOW_CLOSE: - open.upon(close, enter=disconnecting, outputs=[dropConnection]) + open.upon(close, enter=disconnecting, + outputs=[dropConnection, remove_connection]) connecting.upon(d_callback, enter=negotiating, outputs=[]) connecting.upon(d_errback, enter=waiting, outputs=[start_timer]) connecting.upon(onClose, enter=waiting, outputs=[start_timer]) @@ -172,7 +175,7 @@ class WebSocketMachine(object): disconnecting.upon(onClose, enter=closed, outputs=[]) def tryit(reactor): - cm = WebSocketMachine("ws://127.0.0.1:4000/v1", reactor) + cm = WSRelayClient(None, "ws://127.0.0.1:4000/v1", reactor) print("_ConnectionMachine created") print("start:", cm.start()) print("waiting on _done_d to finish") @@ -202,12 +205,14 @@ if __name__ == "__main__": from twisted.internet.task import react react(tryit) +# a new WSConnection is created each time the WSRelayClient gets through +# negotiation @attrs -class Connection(object): +class WSConnection(object): _ws = attrib() _appid = attrib() _side = attrib() - _ws_machine = attrib() + _wsrc = attrib() m = MethodicalMachine() @m.state(initial=True) @@ -232,13 +237,13 @@ class Connection(object): @m.input() def ack_bind(self): pass @m.input() - def c_set_nameplate(self): pass + def wsc_set_nameplate(self): pass @m.input() - def c_set_mailbox(self, mailbox): pass + def wsc_set_mailbox(self, mailbox): pass @m.input() - def c_remove_nameplate(self): pass + def wsc_release_nameplate(self): pass @m.input() - def c_remove_mailbox(self): pass + def wsc_release_mailbox(self): pass @m.input() def ack_close(self): pass @@ -248,34 +253,32 @@ class Connection(object): @m.output() def notify_bound(self): self._nameplate_machine.bound() + self._connection.make_listing_machine() @m.output() def m_set_mailbox(self, mailbox): self._mailbox_machine.m_set_mailbox(mailbox) @m.output() def request_close(self): - self._ws_machine.close() + self._wsrc.close() @m.output() def notify_close(self): pass unbound.upon(bind, enter=binding, outputs=[send_bind]) binding.upon(ack_bind, enter=neither, outputs=[notify_bound]) - neither.upon(c_set_nameplate, enter=has_nameplate, outputs=[]) - neither.upon(c_set_mailbox, enter=has_mailbox, outputs=[m_set_mailbox]) - has_nameplate.upon(c_set_mailbox, enter=has_both, outputs=[m_set_mailbox]) - has_nameplate.upon(c_remove_nameplate, enter=closing, outputs=[request_close]) - has_mailbox.upon(c_set_nameplate, enter=has_both, outputs=[]) - has_mailbox.upon(c_remove_mailbox, enter=closing, outputs=[request_close]) - has_both.upon(c_remove_nameplate, enter=has_mailbox, outputs=[]) - has_both.upon(c_remove_mailbox, enter=has_nameplate, outputs=[]) + neither.upon(wsc_set_nameplate, enter=has_nameplate, outputs=[]) + neither.upon(wsc_set_mailbox, enter=has_mailbox, outputs=[m_set_mailbox]) + has_nameplate.upon(wsc_set_mailbox, enter=has_both, outputs=[m_set_mailbox]) + has_nameplate.upon(wsc_release_nameplate, enter=closing, outputs=[request_close]) + has_mailbox.upon(wsc_set_nameplate, enter=has_both, outputs=[]) + has_mailbox.upon(wsc_release_mailbox, enter=closing, outputs=[request_close]) + has_both.upon(wsc_release_nameplate, enter=has_mailbox, outputs=[]) + has_both.upon(wsc_release_mailbox, enter=has_nameplate, outputs=[]) closing.upon(ack_close, enter=closed, outputs=[]) class NameplateMachine(object): m = MethodicalMachine() - def bound(self): - pass - @m.state(initial=True) def unclaimed(self): pass # but bound @m.state() @@ -284,112 +287,135 @@ class NameplateMachine(object): def claimed(self): pass @m.state() def releasing(self): pass + @m.state(terminal=True) + def done(self): pass - @m.input() - def list_nameplates(self): pass - @m.input() - def got_nameplates(self, nameplates): pass # response("nameplates") @m.input() def learned_nameplate(self, nameplate): """Call learned_nameplate() when you learn the nameplate: either through allocation or code entry""" pass @m.input() - def claim_acked(self, mailbox): pass # response("claimed") + def rx_claimed(self, mailbox): pass # response("claimed") @m.input() - def release(self): pass + def nm_release_nameplate(self): pass @m.input() def release_acked(self): pass # response("released") - @m.output() - def send_list_nameplates(self): - self._ws_send_command("list") - @m.output() - def notify_nameplates(self, nameplates): - # tell somebody - pass @m.output() def send_claim(self, nameplate): self._ws_send_command("claim", nameplate=nameplate) @m.output() - def c_set_nameplate(self, mailbox): - self._connection_machine.set_nameplate() + def wsc_set_nameplate(self, mailbox): + self._connection_machine.wsc_set_nameplate() @m.output() - def c_set_mailbox(self, mailbox): - self._connection_machine.set_mailbox() + def wsc_set_mailbox(self, mailbox): + self._connection_machine.wsc_set_mailbox() + @m.output() + def mm_set_mailbox(self, mailbox): + self._mm.mm_set_mailbox() @m.output() def send_release(self): self._ws_send_command("release") @m.output() - def notify_released(self): + def wsc_release_nameplate(self): # let someone know, when both the mailbox and the nameplate are # released, the websocket can be closed, and we're done - pass + self._wsc.wsc_release_nameplate() - unclaimed.upon(list_nameplates, enter=unclaimed, outputs=[send_list_nameplates]) - unclaimed.upon(got_nameplates, enter=unclaimed, outputs=[notify_nameplates]) unclaimed.upon(learned_nameplate, enter=claiming, outputs=[send_claim]) - claiming.upon(claim_acked, enter=claimed, outputs=[c_set_nameplate, - c_set_mailbox]) - claiming.upon(learned_nameplate, enter=claiming, outputs=[]) - claimed.upon(release, enter=releasing, outputs=[send_release]) - claimed.upon(learned_nameplate, enter=claimed, outputs=[]) - releasing.upon(release, enter=releasing, outputs=[]) - releasing.upon(release_acked, enter=unclaimed, outputs=[notify_released]) - releasing.upon(learned_nameplate, enter=releasing, outputs=[]) + claiming.upon(rx_claimed, enter=claimed, outputs=[wsc_set_nameplate, + mm_set_mailbox, + wsc_set_mailbox]) + #claiming.upon(learned_nameplate, enter=claiming, outputs=[]) + claimed.upon(nm_release_nameplate, enter=releasing, outputs=[send_release]) + #claimed.upon(learned_nameplate, enter=claimed, outputs=[]) + #releasing.upon(release, enter=releasing, outputs=[]) + releasing.upon(release_acked, enter=done, outputs=[wsc_release_nameplate]) + #releasing.upon(learned_nameplate, enter=releasing, outputs=[]) +class NameplateListingMachine(object): + m = MethodicalMachine() + def __init__(self): + self._list_nameplate_waiters = [] + # Ideally, each API request would spawn a new "list_nameplates" message + # to the server, so the response would be maximally fresh, but that would + # require correlating server request+response messages, and the protocol + # is intended to be less stateful than that. So we offer a weaker + # freshness property: if no server requests are in flight, then a new API + # request will provoke a new server request, and the result will be + # fresh. But if a server request is already in flight when a second API + # request arrives, both requests will be satisfied by the same response. + + @m.state(initial=True) + def idle(self): pass + @m.state() + def requesting(self): pass + + @m.input() + def list_nameplates(self): pass # returns Deferred + @m.input() + def response(self, message): pass + + @m.output() + def add_deferred(self): + d = defer.Deferred() + self._list_nameplate_waiters.append(d) + return d + @m.output() + def send_request(self): + self._connection.send_command("list") + @m.output() + def distribute_response(self, message): + nameplates = parse(message) + waiters = self._list_nameplate_waiters + self._list_nameplate_waiters = [] + for d in waiters: + d.callback(nameplates) + + idle.upon(list_nameplates, enter=requesting, + outputs=[add_deferred, send_request], + collector=lambda outs: outs[0]) + idle.upon(response, enter=idle, outputs=[]) + requesting.upon(list_nameplates, enter=requesting, + outputs=[add_deferred], + collector=lambda outs: outs[0]) + requesting.upon(response, enter=idle, outputs=[distribute_response]) + + # nlm._connection = c = Connection(ws) + # nlm.list_nameplates().addCallback(display_completions) + # c.register_dispatch("nameplates", nlm.response) class MailboxMachine(object): m = MethodicalMachine() @m.state() - def closed(initial=True): pass + def unknown(initial=True): pass @m.state() - def open(): pass + def mailbox_unused(): pass @m.state() - def key_established(): pass - @m.state() - def key_verified(): pass + def mailbox_used(): pass @m.input() - def m_set_code(self, code): pass + def mm_set_mailbox(self, mailbox): pass + @m.input() + def add_connection(self, connection): pass + @m.input() + def rx_message(self): pass - @m.input() - def m_set_mailbox(self, mailbox): - """Call m_set_mailbox() when you learn the mailbox id, either from - the response to claim_nameplate, or because we started from a - Wormhole Seed""" - pass - @m.input() - def message_pake(self, pake): pass # reponse["message"][phase=pake] - @m.input() - def message_version(self, version): # response["message"][phase=version] - pass - @m.input() - def message_app(self, msg): # response["message"][phase=\d+] - pass @m.input() def close(self): pass @m.output() - def send_pake(self, pake): - self._ws_send_command("add", phase="pake", body=XXX(pake)) + def open_mailbox(self): + self._mm.mm_set_mailbox(self._mailbox) @m.output() - def send_version(self, pake): # XXX remove pake= - plaintext = dict_to_bytes(self._my_versions) - phase = "version" - data_key = self._derive_phase_key(self._side, phase) - encrypted = self._encrypt_data(data_key, plaintext) - self._msg_send(phase, encrypted) + def nm_release_nameplate(self): + self._nm.nm_release_nameplate() @m.output() - def c_remove_mailbox(self): - self._connection.c_remove_mailbox() - - # decrypt, deliver up to app - - - + def wsc_release_mailbox(self): + self._wsc.wsc_release_mailbox() @m.output() def open_mailbox(self, mailbox): self._ws_send_command("open", mailbox=mailbox) @@ -398,8 +424,159 @@ class MailboxMachine(object): def close_mailbox(self, mood): self._ws_send_command("close", mood=mood) - closed.upon(m_set_mailbox, enter=open, outputs=[open_mailbox]) - open.upon(message_pake, enter=key_established, outputs=[send_pake, - send_version]) - key_established.upon(message_version, enter=key_verified, outputs=[]) - key_verified.upon(close, enter=closed, outputs=[c_remove_mailbox]) + unknown.upon(mm_set_mailbox, enter=mailbox_unused, outputs=[open_mailbox]) + mailbox_unused.upon(rx_message, enter=mailbox_used, + outputs=[nm_release_nameplate]) + #open.upon(message_pake, enter=key_established, outputs=[send_pake, + # send_version]) + #key_established.upon(message_version, enter=key_verified, outputs=[]) + #key_verified.upon(close, enter=closed, outputs=[wsc_release_mailbox]) + +class Wormhole: + m = MethodicalMachine() + + def __init__(self, ws_url, reactor): + self._relay_client = WSRelayClient(self, ws_url, reactor) + # This records all the messages we want the relay to have. Each time + # we establish a connection, we'll send them all (and the relay + # server will filter out duplicates). If we add any while a + # connection is established, we'll send the new ones. + self._outbound_messages = [] + + def start(self): + self._relay_client.start() + + @m.state() + def closed(initial=True): pass + @m.state() + def know_code_not_mailbox(): pass + @m.state() + def know_code_and_mailbox(): pass # no longer need nameplate + @m.state() + def waiting_to_verify(): pass # key is established, want any message + @m.state() + def open(): pass # key is verified, can post app messages + @m.state(terminal=True) + def failed(): pass + + @m.input() + def deliver_message(self, message): pass + + def w_set_seed(self, code, mailbox): + """Call w_set_seed when we sprout a Wormhole Seed, which + contains both the code and the mailbox""" + self.w_set_code(code) + self.w_set_mailbox(mailbox) + + @m.input() + def w_set_code(self, code): + """Call w_set_code when you learn the code, probably because the user + typed it in.""" + @m.input() + def w_set_mailbox(self, mailbox): + """Call w_set_mailbox() when you learn the mailbox id, from the + response to claim_nameplate""" + pass + + + @m.input() + def rx_pake(self, pake): pass # reponse["message"][phase=pake] + def rx_version(self, version): # response["message"][phase=version] + their_verifier = com + if OK: + self.verify_good(verifier) + else: + self.verify_bad(f) + pass + + @m.input() + def verify_good(self, verifier): pass + @m.input() + def verify_bad(self, f): pass + + @m.output() + def compute_and_post_pake(self, code): + self._code = code + self._pake = compute(code) + self._post(pake=self._pake) + self._ws_send_command("add", phase="pake", body=XXX(pake)) + @m.output() + def set_mailbox(self, mailbox): + self._mailbox = mailbox + @m.output() + def set_seed(self, code, mailbox): + self._code = code + self._mailbox = mailbox + + @m.output() + def deliver_message(self, message): + self._qc.deliver_message(message) + + @m.output() + def notify_verified(self, verifier): + for d in self._verify_waiters: + d.callback(verifier) + @m.output() + def notify_failed(self, f): + for d in self._verify_waiters: + d.errback(f) + + @m.output() + def compute_key_and_post_version(self, pake): + self._key = x + self._verifier = x + plaintext = dict_to_bytes(self._my_versions) + phase = "version" + data_key = self._derive_phase_key(self._side, phase) + encrypted = self._encrypt_data(data_key, plaintext) + self._msg_send(phase, encrypted) + + closed.upon(w_set_code, enter=know_code_not_mailbox, + outputs=[compute_and_post_pake]) + know_code_not_mailbox.upon(w_set_mailbox, enter=know_code_and_mailbox, + outputs=[set_mailbox]) + know_code_and_mailbox.upon(rx_pake, enter=waiting_to_verify, + outputs=[compute_key_and_post_version]) + waiting_to_verify.upon(verify_good, enter=open, outputs=[notify_verified]) + waiting_to_verify.upon(verify_bad, enter=failed, outputs=[notify_failed]) + +class QueueConnect: + m = MethodicalMachine() + def __init__(self): + self._outbound_messages = [] + self._connection = None + @m.state() + def disconnected(): pass + @m.state() + def connected(): pass + + @m.input() + def deliver_message(self, message): pass + @m.input() + def connect(self, connection): pass + @m.input() + def disconnect(self): pass + + @m.output() + def remember_connection(self, connection): + self._connection = connection + @m.output() + def forget_connection(self): + self._connection = None + @m.output() + def queue_message(self, message): + self._outbound_messages.append(message) + @m.output() + def send_message(self, message): + self._connection.send(message) + @m.output() + def send_queued_messages(self, connection): + for m in self._outbound_messages: + connection.send(m) + + disconnected.upon(deliver_message, enter=disconnected, outputs=[queue_message]) + disconnected.upon(connect, enter=connected, outputs=[remember_connection, + send_queued_messages]) + connected.upon(deliver_message, enter=connected, + outputs=[queue_message, send_message]) + connected.upon(disconnect, enter=disconnected, outputs=[forget_connection])