diff --git a/src/wormhole/_boss.py b/src/wormhole/_boss.py index 530d7ae..a8cf5e4 100644 --- a/src/wormhole/_boss.py +++ b/src/wormhole/_boss.py @@ -17,7 +17,8 @@ from ._rendezvous import RendezvousConnector from ._lister import Lister from ._code import Code from ._terminator import Terminator -from .errors import ServerError, LonelyError, WrongPasswordError, KeyFormatError +from .errors import (ServerError, LonelyError, WrongPasswordError, + KeyFormatError, OnlyOneCodeError) from .util import bytes_to_dict @attrs @@ -62,6 +63,7 @@ class Boss(object): self._C.wire(self, self._RC, self._L) self._T.wire(self, self._RC, self._N, self._M) + self._did_start_code = False self._next_tx_phase = 0 self._next_rx_phase = 0 self._rx_phases = {} # phase -> plaintext @@ -113,12 +115,21 @@ class Boss(object): # Wormhole only knows about this Boss instance, and everything else is # hidden away). def input_code(self, stdio): + if self._did_start_code: + raise OnlyOneCodeError() + self._did_start_code = True self._C.input_code(stdio) def allocate_code(self, code_length): + if self._did_start_code: + raise OnlyOneCodeError() + self._did_start_code = True self._C.allocate_code(code_length) def set_code(self, code): if ' ' in code: raise KeyFormatError("code (%s) contains spaces." % code) + if self._did_start_code: + raise OnlyOneCodeError() + self._did_start_code = True self._C.set_code(code) @m.input() diff --git a/src/wormhole/errors.py b/src/wormhole/errors.py index b6ba419..81fd989 100644 --- a/src/wormhole/errors.py +++ b/src/wormhole/errors.py @@ -54,6 +54,9 @@ class NoTorError(WormholeError): class NoKeyError(WormholeError): """w.derive_key() was called before got_verifier() fired""" +class OnlyOneCodeError(WormholeError): + """Only one w.generate_code/w.set_code/w.type_code may be called""" + class WormholeClosed(Exception): """Deferred-returning API calls errback with WormholeClosed if the wormhole was already closed, or if it closes before a real result can be diff --git a/src/wormhole/test/test_wormhole.py b/src/wormhole/test/test_wormhole.py index b632f84..3033903 100644 --- a/src/wormhole/test/test_wormhole.py +++ b/src/wormhole/test/test_wormhole.py @@ -8,7 +8,8 @@ from twisted.internet.defer import Deferred, gatherResults, inlineCallbacks from .common import ServerBase from .. import wormhole, _rendezvous from ..errors import (WrongPasswordError, WelcomeError, InternalError, - KeyFormatError, WormholeClosed, LonelyError) + KeyFormatError, WormholeClosed, LonelyError, + NoKeyError, OnlyOneCodeError) from spake2 import SPAKE2_Symmetric from ..timing import DebugTiming from ..util import (bytes_to_dict, dict_to_bytes, @@ -979,23 +980,24 @@ class MessageDoubler(_rendezvous.RendezvousConnector): class Errors(ServerBase, unittest.TestCase): @inlineCallbacks - def test_codes_1(self): + def test_derive_key_early(self): w = wormhole.create(APPID, self.relayurl, reactor) # definitely too early - self.assertRaises(InternalError, w.derive_key, "purpose", 12) - - w.set_code("123-purple-elephant") - # code can only be set once - self.assertRaises(InternalError, w.set_code, "123-nope") - yield self.assertFailure(w.when_code(), InternalError) - yield self.assertFailure(w.input_code(), InternalError) - yield w.close() + self.assertRaises(NoKeyError, w.derive_key, "purpose", 12) + yield self.assertFailure(w.close(), LonelyError) @inlineCallbacks - def test_codes_2(self): + def test_multiple_set_code(self): w = wormhole.create(APPID, self.relayurl, reactor) + w.set_code("123-purple-elephant") + # code can only be set once + self.assertRaises(OnlyOneCodeError, w.set_code, "123-nope") + yield self.assertFailure(w.close(), LonelyError) + + @inlineCallbacks + def test_allocate_and_set_code(self): + w = wormhole.create(APPID, self.relayurl, reactor) + w.allocate_code() yield w.when_code() - self.assertRaises(InternalError, w.set_code, "123-nope") - yield self.assertFailure(w.when_code(), InternalError) - yield self.assertFailure(w.input_code(), InternalError) - yield w.close() + self.assertRaises(OnlyOneCodeError, w.set_code, "123-nope") + yield self.assertFailure(w.close(), LonelyError)