diff --git a/src/wormhole/test/common.py b/src/wormhole/test/common.py index 3b4d73b..b2230bd 100644 --- a/src/wormhole/test/common.py +++ b/src/wormhole/test/common.py @@ -8,27 +8,59 @@ from ..cli import cli from ..transit import allocate_tcp_port from wormhole_mailbox_server.server import make_server from wormhole_mailbox_server.web import make_web_server -from wormhole_mailbox_server.database import create_channel_db +from wormhole_mailbox_server.database import create_channel_db, create_usage_db from wormhole_transit_relay.transit_server import Transit -class ServerBase: - def setUp(self): - self._setup_relay(None) +class MyInternetService(service.Service, object): + # like StreamServerEndpointService, but you can retrieve the port + def __init__(self, endpoint, factory): + self.endpoint = endpoint + self.factory = factory + self._port_d = defer.Deferred() + self._lp = None + def startService(self): + super(MyInternetService, self).startService() + d = self.endpoint.listen(self.factory) + def good(lp): + self._lp = lp + self._port_d.callback(lp.getHost().port) + def bad(f): + log.err(f) + self._port_d.errback(f) + d.addCallbacks(good, bad) + + @defer.inlineCallbacks + def stopService(self): + if self._lp: + yield self._lp.stopListening() + + def getPort(self): # only call once! + return self._port_d + +class ServerBase: + @defer.inlineCallbacks + def setUp(self): + yield self._setup_relay(None) + + @defer.inlineCallbacks def _setup_relay(self, error, advertise_version=None): self.sp = service.MultiService() self.sp.startService() # need to talk to twisted team about only using unicode in # endpoints.serverFromString db = create_channel_db(":memory:") + self._usage_db = create_usage_db(":memory:") self._rendezvous = make_server(db, advertise_version=advertise_version, - signal_error=error) + signal_error=error, + usage_db=self._usage_db) ep = endpoints.TCP4ServerEndpoint(reactor, 0, interface="127.0.01") site = make_web_server(self._rendezvous, log_requests=False) - s = internet.StreamServerEndpointService(ep, site) + #self._lp = yield ep.listen(site) + s = MyInternetService(ep, site) s.setServiceParent(self.sp) - self.rdv_ws_port = s.__lp.getHost().port + self.rdv_ws_port = yield s.getPort() self._relay_server = s #self._rendezvous = s._rendezvous self.relayurl = u"ws://127.0.0.1:%d/v1" % self.rdv_ws_port @@ -43,6 +75,7 @@ class ServerBase: internet.StreamServerEndpointService(ep, f).setServiceParent(self.sp) self.transit = u"tcp:127.0.0.1:%d" % self.transitport + @defer.inlineCallbacks def tearDown(self): # Unit tests that spawn a (blocking) client in a thread might still # have threads running at this point, if one is stuck waiting for a @@ -54,34 +87,27 @@ class ServerBase: # XXX FIXME there's something in _noclobber test that's not # waiting for a close, I think -- was pretty relieably getting # unclean-reactor, but adding a slight pause here stops it... - from twisted.internet import reactor tp = reactor.getThreadPool() if not tp.working: - d = defer.succeed(None) - d.addCallback(lambda _: self.sp.stopService()) - d.addCallback(lambda _: task.deferLater(reactor, 0.1, lambda: None)) - return d - return self.sp.stopService() + yield self.sp.stopService() + yield task.deferLater(reactor, 0.1, lambda: None) + defer.returnValue(None) # disconnect all callers d = defer.maybeDeferred(self.sp.stopService) - wait_d = defer.Deferred() # wait a second, then check to see if it worked - reactor.callLater(1.0, wait_d.callback, None) - def _later(res): - if len(tp.working): - log.msg("wormhole.test.common.ServerBase.tearDown:" - " I was unable to convince all threads to exit.") - tp.dumpStats() - print("tearDown warning: threads are still active") - print("This test will probably hang until one of the" - " clients gives up of their own accord.") - else: - log.msg("wormhole.test.common.ServerBase.tearDown:" - " I convinced all threads to exit.") - return d - wait_d.addCallback(_later) - return wait_d + yield task.deferLater(reactor, 1.0, lambda: None) + if len(tp.working): + log.msg("wormhole.test.common.ServerBase.tearDown:" + " I was unable to convince all threads to exit.") + tp.dumpStats() + print("tearDown warning: threads are still active") + print("This test will probably hang until one of the" + " clients gives up of their own accord.") + else: + log.msg("wormhole.test.common.ServerBase.tearDown:" + " I convinced all threads to exit.") + yield d def config(*argv): r = CliRunner() diff --git a/src/wormhole/test/test_cli.py b/src/wormhole/test/test_cli.py index 644b71b..5cc8b60 100644 --- a/src/wormhole/test/test_cli.py +++ b/src/wormhole/test/test_cli.py @@ -562,7 +562,7 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): yield gatherResults([send_d, receive_d], True) if fake_tor: - expected_endpoints = [("127.0.0.1", self.relayport)] + expected_endpoints = [("127.0.0.1", self.rdv_ws_port)] if mode in ("file", "directory"): expected_endpoints.append(("127.0.0.1", self.transitport)) tx_timing = mtx_tm.call_args[1]["timing"] @@ -665,9 +665,6 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): self.failUnlessEqual(modes[i], stat.S_IMODE(os.stat(fn).st_mode)) - # check server stats - self._rendezvous.get_stats() - def test_text(self): return self._do_test() def test_text_subprocess(self): @@ -847,9 +844,6 @@ class PregeneratedCode(ServerBase, ScriptsBase, unittest.TestCase): with open(fn, "r") as f: self.failUnlessEqual(f.read(), PRESERVE) - # check server stats - self._rendezvous.get_stats() - def test_fail_file_noclobber(self): return self._do_test_fail("file", "noclobber") def test_fail_directory_noclobber(self): @@ -913,12 +907,10 @@ class ZeroMode(ServerBase, unittest.TestCase): self.assertEqual(receive_stdout, message+NL) self.assertEqual(receive_stderr, "") - # check server stats - self._rendezvous.get_stats() - class NotWelcome(ServerBase, unittest.TestCase): + @inlineCallbacks def setUp(self): - self._setup_relay(error="please upgrade XYZ") + yield self._setup_relay(error="please upgrade XYZ") self.cfg = cfg = config("send") cfg.hide_progress = True cfg.listen = False @@ -947,7 +939,7 @@ class NotWelcome(ServerBase, unittest.TestCase): class NoServer(ServerBase, unittest.TestCase): @inlineCallbacks def setUp(self): - self._setup_relay(None) + yield self._setup_relay(None) yield self._relay_server.disownServiceParent() @inlineCallbacks @@ -1091,8 +1083,9 @@ class ExtractFile(unittest.TestCase): self.assertIn("malicious zipfile", str(e)) class AppID(ServerBase, unittest.TestCase): + @inlineCallbacks def setUp(self): - d = super(AppID, self).setUp() + yield super(AppID, self).setUp() self.cfg = cfg = config("send") # common options for all tests in this suite cfg.hide_progress = True @@ -1100,7 +1093,6 @@ class AppID(ServerBase, unittest.TestCase): cfg.transit_helper = "" cfg.stdout = io.StringIO() cfg.stderr = io.StringIO() - return d @inlineCallbacks def test_override(self): @@ -1115,9 +1107,9 @@ class AppID(ServerBase, unittest.TestCase): yield send_d yield receive_d - used = self._rendezvous._db.execute("SELECT DISTINCT `app_id`" - " FROM `nameplate_usage`" - ).fetchall() + used = self._usage_db.execute("SELECT DISTINCT `app_id`" + " FROM `nameplates`" + ).fetchall() self.assertEqual(len(used), 1, used) self.assertEqual(used[0]["app_id"], u"appid2")