Merge branch '280-input-threading'

closes #280
This commit is contained in:
Brian Warner 2018-02-14 02:12:04 -08:00
commit 593b359166
5 changed files with 42 additions and 7 deletions

View File

@ -108,14 +108,16 @@ class Boss(object):
def _set_trace(self, client_name, which, file): def _set_trace(self, client_name, which, file):
names = {"B": self, "N": self._N, "M": self._M, "S": self._S, names = {"B": self, "N": self._N, "M": self._M, "S": self._S,
"O": self._O, "K": self._K, "SK": self._K._SK, "R": self._R, "O": self._O, "K": self._K, "SK": self._K._SK, "R": self._R,
"RC": self._RC, "L": self._L, "C": self._C, "RC": self._RC, "L": self._L, "A": self._A, "I": self._I,
"T": self._T} "C": self._C, "T": self._T}
for machine in which.split(): for machine in which.split():
t = (lambda old_state, input, new_state, machine=machine: t = (lambda old_state, input, new_state, machine=machine:
self._print_trace(old_state, input, new_state, self._print_trace(old_state, input, new_state,
client_name=client_name, client_name=client_name,
machine=machine, file=file)) machine=machine, file=file))
names[machine].set_trace(t) names[machine].set_trace(t)
if machine == "I":
self._I.set_debug(t)
## def serialize(self): ## def serialize(self):
## raise NotImplemented ## raise NotImplemented

View File

@ -1,4 +1,8 @@
from __future__ import print_function, absolute_import, unicode_literals from __future__ import print_function, absolute_import, unicode_literals
# We use 'threading' defensively here, to detect if we're being called from a
# non-main thread. _rlcompleter.py is the only internal Wormhole code that
# deliberately creates a new thread.
import threading
from zope.interface import implementer from zope.interface import implementer
from attr import attrs, attrib from attr import attrs, attrib
from attr.validators import provides from attr.validators import provides
@ -22,6 +26,13 @@ class Input(object):
self._nameplate = None self._nameplate = None
self._wordlist = None self._wordlist = None
self._wordlist_waiters = [] self._wordlist_waiters = []
self._trace = None
def set_debug(self, f):
self._trace = f
def _debug(self, what):
if self._trace:
self._trace(old_state="", input=what, new_state="")
def wire(self, code, lister): def wire(self, code, lister):
self._C = _interfaces.ICode(code) self._C = _interfaces.ICode(code)
@ -233,15 +244,28 @@ class Input(object):
class Helper(object): class Helper(object):
_input = attrib() _input = attrib()
def __attrs_post_init__(self):
self._main_thread = threading.current_thread().ident
def refresh_nameplates(self): def refresh_nameplates(self):
assert threading.current_thread().ident == self._main_thread
self._input.refresh_nameplates() self._input.refresh_nameplates()
def get_nameplate_completions(self, prefix): def get_nameplate_completions(self, prefix):
assert threading.current_thread().ident == self._main_thread
return self._input.get_nameplate_completions(prefix) return self._input.get_nameplate_completions(prefix)
def choose_nameplate(self, nameplate): def choose_nameplate(self, nameplate):
assert threading.current_thread().ident == self._main_thread
self._input._debug("I.choose_nameplate")
self._input.choose_nameplate(nameplate) self._input.choose_nameplate(nameplate)
self._input._debug("I.choose_nameplate finished")
def when_wordlist_is_available(self): def when_wordlist_is_available(self):
assert threading.current_thread().ident == self._main_thread
return self._input.when_wordlist_is_available() return self._input.when_wordlist_is_available()
def get_word_completions(self, prefix): def get_word_completions(self, prefix):
assert threading.current_thread().ident == self._main_thread
return self._input.get_word_completions(prefix) return self._input.get_word_completions(prefix)
def choose_words(self, words): def choose_words(self, words):
assert threading.current_thread().ident == self._main_thread
self._input._debug("I.choose_words")
self._input.choose_words(words) self._input.choose_words(words)
self._input._debug("I.choose_words finished")

View File

@ -134,11 +134,13 @@ class CodeInputter(object):
raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate) raise AlreadyInputNameplateError("nameplate (%s-) already entered, cannot go back" % self._committed_nameplate)
else: else:
debug(" choose_nameplate(%s)" % nameplate) debug(" choose_nameplate(%s)" % nameplate)
self._input_helper.choose_nameplate(nameplate) self.bcft(self._input_helper.choose_nameplate, nameplate)
debug(" choose_words(%s)" % words) debug(" choose_words(%s)" % words)
self._input_helper.choose_words(words) self.bcft(self._input_helper.choose_words, words)
def _input_code_with_completion(prompt, input_helper, reactor): def _input_code_with_completion(prompt, input_helper, reactor):
# reminder: this all occurs in a separate thread. All calls to input_helper
# must go through blockingCallFromThread()
c = CodeInputter(input_helper, reactor) c = CodeInputter(input_helper, reactor)
if readline is not None: if readline is not None:
if readline.__doc__ and "libedit" in readline.__doc__: if readline.__doc__ and "libedit" in readline.__doc__:

View File

@ -147,11 +147,15 @@ def get_completions(c, prefix):
return completions return completions
completions.append(text) completions.append(text)
def fake_blockingCallFromThread(f, *a, **kw):
return f(*a, **kw)
class Completion(unittest.TestCase): class Completion(unittest.TestCase):
def test_simple(self): def test_simple(self):
# no actual completion # no actual completion
helper = mock.Mock() helper = mock.Mock()
c = CodeInputter(helper, "reactor") c = CodeInputter(helper, "reactor")
c.bcft = fake_blockingCallFromThread
c.finish("1-code-ghost") c.finish("1-code-ghost")
self.assertFalse(c.used_completion) self.assertFalse(c.used_completion)
self.assertEqual(helper.mock_calls, self.assertEqual(helper.mock_calls,
@ -164,6 +168,7 @@ class Completion(unittest.TestCase):
# check that it calls _commit_and_build_completions correctly # check that it calls _commit_and_build_completions correctly
helper = mock.Mock() helper = mock.Mock()
c = CodeInputter(helper, "reactor") c = CodeInputter(helper, "reactor")
c.bcft = fake_blockingCallFromThread
# pretend nameplates: 1, 12, 34 # pretend nameplates: 1, 12, 34
@ -304,12 +309,13 @@ class Completion(unittest.TestCase):
self.assertEqual(gwc.mock_calls, [mock.call("and-b")]) self.assertEqual(gwc.mock_calls, [mock.call("and-b")])
gwc.reset_mock() gwc.reset_mock()
c.finish("12-and-bat") yield deferToThread(c.finish, "12-and-bat")
self.assertEqual(cw.mock_calls, [mock.call("and-bat")]) self.assertEqual(cw.mock_calls, [mock.call("and-bat")])
def test_incomplete_code(self): def test_incomplete_code(self):
helper = mock.Mock() helper = mock.Mock()
c = CodeInputter(helper, "reactor") c = CodeInputter(helper, "reactor")
c.bcft = fake_blockingCallFromThread
with self.assertRaises(KeyFormatError) as e: with self.assertRaises(KeyFormatError) as e:
c.finish("1") c.finish("1")
self.assertEqual(str(e.exception), "incomplete wormhole code") self.assertEqual(str(e.exception), "incomplete wormhole code")
@ -349,7 +355,7 @@ class Completion(unittest.TestCase):
self.assertEqual(matches, ["1-code", "1-court"]) self.assertEqual(matches, ["1-code", "1-court"])
helper.reset_mock() helper.reset_mock()
with self.assertRaises(AlreadyInputNameplateError) as e: with self.assertRaises(AlreadyInputNameplateError) as e:
c.finish("2-code") yield deferToThread(c.finish, "2-code")
self.assertEqual(str(e.exception), self.assertEqual(str(e.exception),
"nameplate (1-) already entered, cannot go back") "nameplate (1-) already entered, cannot go back")
self.assertEqual(helper.mock_calls, []) self.assertEqual(helper.mock_calls, [])

View File

@ -214,7 +214,8 @@ class _DeferredWormhole(object):
self._boss.close() # only need to close if it wasn't already self._boss.close() # only need to close if it wasn't already
return d return d
def debug_set_trace(self, client_name, which="B N M S O K SK R RC L C T", def debug_set_trace(self, client_name,
which="B N M S O K SK R RC L A I C T",
file=sys.stderr): file=sys.stderr):
self._boss._set_trace(client_name, which, file) self._boss._set_trace(client_name, which, file)