receive: fetch channel list before completion, to get welcome message

This commit is contained in:
Brian Warner 2015-11-15 10:53:13 -08:00
parent 7426097ba5
commit 6956f35e9a
2 changed files with 24 additions and 7 deletions

View File

@ -135,7 +135,10 @@ class ChannelManager:
r = requests.get(self._relay_url+"list?%s" % queryargs, r = requests.get(self._relay_url+"list?%s" % queryargs,
timeout=self._timeout) timeout=self._timeout)
r.raise_for_status() r.raise_for_status()
channelids = r.json()["channelids"] data = r.json()
if "welcome" in data:
self._handle_welcome(data["welcome"])
channelids = data["channelids"]
return channelids return channelids
def allocate(self): def allocate(self):
@ -241,7 +244,12 @@ class Wormhole:
def input_code(self, prompt="Enter wormhole code: ", code_length=2): def input_code(self, prompt="Enter wormhole code: ", code_length=2):
lister = self._channel_manager.list_channels lister = self._channel_manager.list_channels
code = codes.input_code_with_completion(prompt, lister, # fetch the list of channels ahead of time, to give us a chance to
# discover the welcome message (and warn the user about an obsolete
# client)
initial_channelids = lister()
code = codes.input_code_with_completion(prompt,
initial_channelids, lister,
code_length) code_length)
return code return code

View File

@ -21,12 +21,20 @@ import readline
#import sys #import sys
class CodeInputter: class CodeInputter:
def __init__(self, get_channel_ids, code_length): def __init__(self, initial_channelids, get_channel_ids, code_length):
self.get_channel_ids = get_channel_ids self._initial_channelids = initial_channelids
self._get_channel_ids = get_channel_ids
self.code_length = code_length self.code_length = code_length
self.last_text = None # memoize for a speedup self.last_text = None # memoize for a speedup
self.last_matches = None self.last_matches = None
def get_current_channel_ids(self):
if self._initial_channelids is not None:
channelids = self._initial_channelids
self._initial_channelids = None
return channelids
return self._get_channel_ids()
def wrap_completer(self, text, state): def wrap_completer(self, text, state):
try: try:
return self.completer(text, state) return self.completer(text, state)
@ -52,7 +60,7 @@ class CodeInputter:
#print(" old matches", len(matches), file=sys.stderr) #print(" old matches", len(matches), file=sys.stderr)
else: else:
if len(pieces) <= 1: if len(pieces) <= 1:
channel_ids = self.get_channel_ids() channel_ids = self.get_current_channel_ids()
matches = [str(channel_id) for channel_id in channel_ids matches = [str(channel_id) for channel_id in channel_ids
if str(channel_id).startswith(last)] if str(channel_id).startswith(last)]
else: else:
@ -76,8 +84,9 @@ class CodeInputter:
return match return match
def input_code_with_completion(prompt, get_channel_ids, code_length): def input_code_with_completion(prompt, initial_channelids, get_channel_ids,
c = CodeInputter(get_channel_ids, code_length) code_length):
c = CodeInputter(initial_channelids, get_channel_ids, code_length)
readline.parse_and_bind("tab: complete") readline.parse_and_bind("tab: complete")
readline.set_completer(c.wrap_completer) readline.set_completer(c.wrap_completer)
readline.set_completer_delims("") readline.set_completer_delims("")