diff --git a/src/wormhole/blocking/transcribe.py b/src/wormhole/blocking/transcribe.py index d66f5b4..88d7757 100644 --- a/src/wormhole/blocking/transcribe.py +++ b/src/wormhole/blocking/transcribe.py @@ -135,7 +135,10 @@ class ChannelManager: r = requests.get(self._relay_url+"list?%s" % queryargs, timeout=self._timeout) r.raise_for_status() - channelids = r.json()["channelids"] + data = r.json() + if "welcome" in data: + self._handle_welcome(data["welcome"]) + channelids = data["channelids"] return channelids def allocate(self): @@ -241,7 +244,12 @@ class Wormhole: def input_code(self, prompt="Enter wormhole code: ", code_length=2): lister = self._channel_manager.list_channels - code = codes.input_code_with_completion(prompt, lister, + # fetch the list of channels ahead of time, to give us a chance to + # discover the welcome message (and warn the user about an obsolete + # client) + initial_channelids = lister() + code = codes.input_code_with_completion(prompt, + initial_channelids, lister, code_length) return code diff --git a/src/wormhole/codes.py b/src/wormhole/codes.py index 0a49890..3b1ff79 100644 --- a/src/wormhole/codes.py +++ b/src/wormhole/codes.py @@ -21,12 +21,20 @@ import readline #import sys class CodeInputter: - def __init__(self, get_channel_ids, code_length): - self.get_channel_ids = get_channel_ids + def __init__(self, initial_channelids, get_channel_ids, code_length): + self._initial_channelids = initial_channelids + self._get_channel_ids = get_channel_ids self.code_length = code_length self.last_text = None # memoize for a speedup self.last_matches = None + def get_current_channel_ids(self): + if self._initial_channelids is not None: + channelids = self._initial_channelids + self._initial_channelids = None + return channelids + return self._get_channel_ids() + def wrap_completer(self, text, state): try: return self.completer(text, state) @@ -52,7 +60,7 @@ class CodeInputter: #print(" old matches", len(matches), file=sys.stderr) else: if len(pieces) <= 1: - channel_ids = self.get_channel_ids() + channel_ids = self.get_current_channel_ids() matches = [str(channel_id) for channel_id in channel_ids if str(channel_id).startswith(last)] else: @@ -76,8 +84,9 @@ class CodeInputter: return match -def input_code_with_completion(prompt, get_channel_ids, code_length): - c = CodeInputter(get_channel_ids, code_length) +def input_code_with_completion(prompt, initial_channelids, get_channel_ids, + code_length): + c = CodeInputter(initial_channelids, get_channel_ids, code_length) readline.parse_and_bind("tab: complete") readline.set_completer(c.wrap_completer) readline.set_completer_delims("")