remove blocking implementation: it will return
It will return as a crochet-based wrapper around the Twisted implementation.
This commit is contained in:
parent
4dfa569769
commit
49785008bb
3
setup.py
3
setup.py
|
@ -14,7 +14,6 @@ setup(name="magic-wormhole",
|
|||
url="https://github.com/warner/magic-wormhole",
|
||||
package_dir={"": "src"},
|
||||
packages=["wormhole",
|
||||
"wormhole.blocking",
|
||||
"wormhole.cli",
|
||||
"wormhole.server",
|
||||
"wormhole.test",
|
||||
|
@ -25,7 +24,7 @@ setup(name="magic-wormhole",
|
|||
["wormhole = wormhole.cli.runner:entry",
|
||||
"wormhole-server = wormhole.server.runner:entry",
|
||||
]},
|
||||
install_requires=["spake2==0.3", "pynacl", "requests", "argparse",
|
||||
install_requires=["spake2==0.3", "pynacl", "argparse",
|
||||
"six", "twisted >= 16.1.0", "hkdf", "tqdm",
|
||||
"autobahn[twisted]", "pytrie",
|
||||
# autobahn seems to have a bug, and one plugin throws
|
||||
|
|
|
@ -1,49 +0,0 @@
|
|||
from __future__ import print_function, unicode_literals
|
||||
import requests
|
||||
|
||||
class EventSourceFollower:
|
||||
def __init__(self, url, timeout):
|
||||
self._resp = requests.get(url,
|
||||
headers={"accept": "text/event-stream"},
|
||||
stream=True,
|
||||
timeout=timeout)
|
||||
self._resp.raise_for_status()
|
||||
self._lines_iter = self._resp.iter_lines(chunk_size=1,
|
||||
decode_unicode=True)
|
||||
|
||||
def close(self):
|
||||
self._resp.close()
|
||||
|
||||
def iter_events(self):
|
||||
# I think Request.iter_lines and .iter_content use chunk_size= in a
|
||||
# funny way, and nothing happens until at least that much data has
|
||||
# arrived. So unless we set chunk_size=1, we won't hear about lines
|
||||
# for a long time. I'd prefer that chunk_size behaved like
|
||||
# read(size), and gave you 1<=x<=size bytes in response.
|
||||
eventtype = "message"
|
||||
current_lines = []
|
||||
for line in self._lines_iter:
|
||||
assert isinstance(line, type(u"")), type(line)
|
||||
if not line:
|
||||
# blank line ends the field: deliver event, reset for next
|
||||
yield (eventtype, "\n".join(current_lines))
|
||||
eventtype = "message"
|
||||
current_lines[:] = []
|
||||
continue
|
||||
if ":" in line:
|
||||
fieldname, data = line.split(":", 1)
|
||||
if data.startswith(" "):
|
||||
data = data[1:]
|
||||
else:
|
||||
fieldname = line
|
||||
data = ""
|
||||
if fieldname == "event":
|
||||
eventtype = data
|
||||
elif fieldname == "data":
|
||||
current_lines.append(data)
|
||||
elif fieldname in ("id", "retry"):
|
||||
# documented but unhandled
|
||||
pass
|
||||
else:
|
||||
#log.msg("weird fieldname", fieldname, data)
|
||||
pass
|
|
@ -1,413 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import os, sys, time, re, requests, json, unicodedata
|
||||
from six.moves.urllib_parse import urlencode
|
||||
from binascii import hexlify, unhexlify
|
||||
from spake2 import SPAKE2_Symmetric
|
||||
from nacl.secret import SecretBox
|
||||
from nacl.exceptions import CryptoError
|
||||
from nacl import utils
|
||||
from .eventsource import EventSourceFollower
|
||||
from .. import __version__
|
||||
from .. import codes
|
||||
from ..errors import ServerError, Timeout, WrongPasswordError, UsageError
|
||||
from ..timing import DebugTiming
|
||||
from hkdf import Hkdf
|
||||
from ..channel_monitor import monitor
|
||||
|
||||
def HKDF(skm, outlen, salt=None, CTXinfo=b""):
|
||||
return Hkdf(salt, skm).expand(CTXinfo, outlen)
|
||||
|
||||
SECOND = 1
|
||||
MINUTE = 60*SECOND
|
||||
|
||||
CONFMSG_NONCE_LENGTH = 128//8
|
||||
CONFMSG_MAC_LENGTH = 256//8
|
||||
def make_confmsg(confkey, nonce):
|
||||
return nonce+HKDF(confkey, CONFMSG_MAC_LENGTH, nonce)
|
||||
|
||||
def to_bytes(u):
|
||||
return unicodedata.normalize("NFC", u).encode("utf-8")
|
||||
|
||||
class Channel:
|
||||
def __init__(self, relay_url, appid, channelid, side, handle_welcome,
|
||||
wait, timeout, timing):
|
||||
self._relay_url = relay_url
|
||||
self._appid = appid
|
||||
self._channelid = channelid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
self._messages = set() # (phase,body) , body is bytes
|
||||
self._sent_messages = set() # (phase,body)
|
||||
self._started = time.time()
|
||||
self._wait = wait
|
||||
self._timeout = timeout
|
||||
self._timing = timing
|
||||
|
||||
def _add_inbound_messages(self, messages):
|
||||
for msg in messages:
|
||||
phase = msg["phase"]
|
||||
body = unhexlify(msg["body"].encode("ascii"))
|
||||
self._messages.add( (phase, body) )
|
||||
|
||||
def _find_inbound_message(self, phases):
|
||||
their_messages = self._messages - self._sent_messages
|
||||
for phase in phases:
|
||||
for (their_phase,body) in their_messages:
|
||||
if their_phase == phase:
|
||||
return (phase, body)
|
||||
return None
|
||||
|
||||
def send(self, phase, msg):
|
||||
# TODO: retry on failure, with exponential backoff. We're guarding
|
||||
# against the rendezvous server being temporarily offline.
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if not isinstance(msg, type(b"")): raise TypeError(type(msg))
|
||||
self._sent_messages.add( (phase,msg) )
|
||||
payload = {"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side,
|
||||
"phase": phase,
|
||||
"body": hexlify(msg).decode("ascii")}
|
||||
data = json.dumps(payload).encode("utf-8")
|
||||
with self._timing.add("send %s" % phase):
|
||||
r = requests.post(self._relay_url+"add", data=data,
|
||||
timeout=self._timeout)
|
||||
r.raise_for_status()
|
||||
resp = r.json()
|
||||
if "welcome" in resp:
|
||||
self._handle_welcome(resp["welcome"])
|
||||
self._add_inbound_messages(resp["messages"])
|
||||
|
||||
def get_first_of(self, phases):
|
||||
if not isinstance(phases, (list, set)): raise TypeError(type(phases))
|
||||
for phase in phases:
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
|
||||
# For now, server errors cause the client to fail. TODO: don't. This
|
||||
# will require changing the client to re-post messages when the
|
||||
# server comes back up.
|
||||
|
||||
# fire with a bytestring of the first message for any 'phase' that
|
||||
# wasn't one of our own messages. It will either come from
|
||||
# previously-received messages, or from an EventSource that we attach
|
||||
# to the corresponding URL
|
||||
with self._timing.add("get %s" % "/".join(sorted(phases))):
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
while phase_and_body is None:
|
||||
remaining = self._started + self._timeout - time.time()
|
||||
if remaining < 0:
|
||||
raise Timeout
|
||||
queryargs = urlencode([("appid", self._appid),
|
||||
("channelid", self._channelid)])
|
||||
f = EventSourceFollower(self._relay_url+"watch?%s" % queryargs,
|
||||
remaining)
|
||||
# we loop here until the connection is lost, or we see the
|
||||
# message we want
|
||||
for (eventtype, line) in f.iter_events():
|
||||
if eventtype == "welcome":
|
||||
self._handle_welcome(json.loads(line))
|
||||
if eventtype == "message":
|
||||
data = json.loads(line)
|
||||
self._add_inbound_messages([data])
|
||||
phase_and_body = self._find_inbound_message(phases)
|
||||
if phase_and_body:
|
||||
f.close()
|
||||
break
|
||||
if not phase_and_body:
|
||||
time.sleep(self._wait)
|
||||
return phase_and_body
|
||||
|
||||
def get(self, phase):
|
||||
(got_phase, body) = self.get_first_of([phase])
|
||||
assert got_phase == phase
|
||||
return body
|
||||
|
||||
def deallocate(self, mood=None):
|
||||
# only try once, no retries
|
||||
data = json.dumps({"appid": self._appid,
|
||||
"channelid": self._channelid,
|
||||
"side": self._side,
|
||||
"mood": mood}).encode("utf-8")
|
||||
try:
|
||||
# ignore POST failure, don't call r.raise_for_status(), set a
|
||||
# short timeout and ignore failures
|
||||
with self._timing.add("close"):
|
||||
r = requests.post(self._relay_url+"deallocate", data=data,
|
||||
timeout=5)
|
||||
r.json()
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
|
||||
class ChannelManager:
|
||||
def __init__(self, relay_url, appid, side, handle_welcome, timing=None,
|
||||
wait=0.5*SECOND, timeout=3*MINUTE):
|
||||
self._relay_url = relay_url
|
||||
self._appid = appid
|
||||
self._side = side
|
||||
self._handle_welcome = handle_welcome
|
||||
self._timing = timing or DebugTiming()
|
||||
self._wait = wait
|
||||
self._timeout = timeout
|
||||
|
||||
def list_channels(self):
|
||||
queryargs = urlencode([("appid", self._appid)])
|
||||
with self._timing.add("list"):
|
||||
r = requests.get(self._relay_url+"list?%s" % queryargs,
|
||||
timeout=self._timeout)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "welcome" in data:
|
||||
self._handle_welcome(data["welcome"])
|
||||
channelids = data["channelids"]
|
||||
return channelids
|
||||
|
||||
def allocate(self):
|
||||
data = json.dumps({"appid": self._appid,
|
||||
"side": self._side}).encode("utf-8")
|
||||
with self._timing.add("allocate"):
|
||||
r = requests.post(self._relay_url+"allocate", data=data,
|
||||
timeout=self._timeout)
|
||||
r.raise_for_status()
|
||||
data = r.json()
|
||||
if "welcome" in data:
|
||||
self._handle_welcome(data["welcome"])
|
||||
channelid = data["channelid"]
|
||||
return channelid
|
||||
|
||||
def connect(self, channelid):
|
||||
return Channel(self._relay_url, self._appid, channelid, self._side,
|
||||
self._handle_welcome, self._wait, self._timeout,
|
||||
self._timing)
|
||||
|
||||
def close_on_error(f): # method decorator
|
||||
# Clients report certain errors as "moods", so the server can make a
|
||||
# rough count failed connections (due to mismatched passwords, attacks,
|
||||
# or timeouts). We don't report precondition failures, as those are the
|
||||
# responsibility/fault of the local application code. We count
|
||||
# non-precondition errors in case they represent server-side problems.
|
||||
def _f(self, *args, **kwargs):
|
||||
try:
|
||||
return f(self, *args, **kwargs)
|
||||
except Timeout:
|
||||
self.close(u"lonely")
|
||||
raise
|
||||
except WrongPasswordError:
|
||||
self.close(u"scary")
|
||||
raise
|
||||
except (TypeError, UsageError):
|
||||
# preconditions don't warrant _close_with_error()
|
||||
raise
|
||||
except:
|
||||
self.close(u"errory")
|
||||
raise
|
||||
return _f
|
||||
|
||||
class Wormhole:
|
||||
motd_displayed = False
|
||||
version_warning_displayed = False
|
||||
_send_confirm = True
|
||||
|
||||
def __init__(self, appid, relay_url, wait=0.5*SECOND, timeout=3*MINUTE,
|
||||
timing=None):
|
||||
if not isinstance(appid, type(u"")): raise TypeError(type(appid))
|
||||
if not isinstance(relay_url, type(u"")):
|
||||
raise TypeError(type(relay_url))
|
||||
if not relay_url.endswith(u"/"): raise UsageError
|
||||
self._appid = appid
|
||||
self._relay_url = relay_url
|
||||
self._wait = wait
|
||||
self._timeout = timeout
|
||||
self._timing = timing or DebugTiming()
|
||||
side = hexlify(os.urandom(5)).decode("ascii")
|
||||
self._channel_manager = ChannelManager(relay_url, appid, side,
|
||||
self.handle_welcome,
|
||||
self._timing,
|
||||
self._wait, self._timeout)
|
||||
self._channel = None
|
||||
self.code = None
|
||||
self.key = None
|
||||
self.verifier = None
|
||||
self._sent_data = set() # phases
|
||||
self._got_data = set()
|
||||
self._got_confirmation = False
|
||||
self._closed = False
|
||||
self._timing_started = self._timing.add("wormhole")
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.close()
|
||||
return False
|
||||
|
||||
def handle_welcome(self, welcome):
|
||||
if ("motd" in welcome and
|
||||
not self.motd_displayed):
|
||||
motd_lines = welcome["motd"].splitlines()
|
||||
motd_formatted = "\n ".join(motd_lines)
|
||||
print("Server (at %s) says:\n %s" % (self._relay_url, motd_formatted),
|
||||
file=sys.stderr)
|
||||
self.motd_displayed = True
|
||||
|
||||
# Only warn if we're running a release version (e.g. 0.0.6, not
|
||||
# 0.0.6-DISTANCE-gHASH). Only warn once.
|
||||
if ("-" not in __version__ and
|
||||
not self.version_warning_displayed and
|
||||
welcome["current_version"] != __version__):
|
||||
print("Warning: errors may occur unless both sides are running the same version", file=sys.stderr)
|
||||
print("Server claims %s is current, but ours is %s"
|
||||
% (welcome["current_version"], __version__), file=sys.stderr)
|
||||
self.version_warning_displayed = True
|
||||
|
||||
if "error" in welcome:
|
||||
raise ServerError(welcome["error"], self._relay_url)
|
||||
|
||||
def get_code(self, code_length=2):
|
||||
if self.code is not None: raise UsageError
|
||||
channelid = self._channel_manager.allocate()
|
||||
code = codes.make_code(channelid, code_length)
|
||||
assert isinstance(code, type(u"")), type(code)
|
||||
self._set_code_and_channelid(code)
|
||||
self._start()
|
||||
return code
|
||||
|
||||
def input_code(self, prompt="Enter wormhole code: ", code_length=2):
|
||||
lister = self._channel_manager.list_channels
|
||||
# fetch the list of channels ahead of time, to give us a chance to
|
||||
# discover the welcome message (and warn the user about an obsolete
|
||||
# client)
|
||||
initial_channelids = lister()
|
||||
with self._timing.add("input code", waiting="user"):
|
||||
code = codes.input_code_with_completion(prompt,
|
||||
initial_channelids, lister,
|
||||
code_length)
|
||||
return code
|
||||
|
||||
def set_code(self, code): # used for human-made pre-generated codes
|
||||
if not isinstance(code, type(u"")): raise TypeError(type(code))
|
||||
if self.code is not None: raise UsageError
|
||||
self._set_code_and_channelid(code)
|
||||
self._start()
|
||||
|
||||
def _set_code_and_channelid(self, code):
|
||||
if self.code is not None: raise UsageError
|
||||
self._timing.add("code established")
|
||||
mo = re.search(r'^(\d+)-', code)
|
||||
if not mo:
|
||||
raise ValueError("code (%s) must start with NN-" % code)
|
||||
self.code = code
|
||||
channelid = int(mo.group(1))
|
||||
self._channel = self._channel_manager.connect(channelid)
|
||||
monitor.add(self._channel)
|
||||
|
||||
def _start(self):
|
||||
# allocate the rest now too, so it can be serialized
|
||||
self.sp = SPAKE2_Symmetric(to_bytes(self.code),
|
||||
idSymmetric=to_bytes(self._appid))
|
||||
self.msg1 = self.sp.start()
|
||||
|
||||
def derive_key(self, purpose, length=SecretBox.KEY_SIZE):
|
||||
if not isinstance(purpose, type(u"")): raise TypeError(type(purpose))
|
||||
return HKDF(self.key, length, CTXinfo=to_bytes(purpose))
|
||||
|
||||
def _encrypt_data(self, key, data):
|
||||
assert isinstance(key, type(b"")), type(key)
|
||||
assert isinstance(data, type(b"")), type(data)
|
||||
assert len(key) == SecretBox.KEY_SIZE, len(key)
|
||||
box = SecretBox(key)
|
||||
nonce = utils.random(SecretBox.NONCE_SIZE)
|
||||
return box.encrypt(data, nonce)
|
||||
|
||||
def _decrypt_data(self, key, encrypted):
|
||||
assert isinstance(key, type(b"")), type(key)
|
||||
assert isinstance(encrypted, type(b"")), type(encrypted)
|
||||
assert len(key) == SecretBox.KEY_SIZE, len(key)
|
||||
box = SecretBox(key)
|
||||
data = box.decrypt(encrypted)
|
||||
return data
|
||||
|
||||
|
||||
def _get_key(self):
|
||||
if not self.key:
|
||||
self._channel.send(u"pake", self.msg1)
|
||||
pake_msg = self._channel.get(u"pake")
|
||||
|
||||
self.key = self.sp.finish(pake_msg)
|
||||
self.verifier = self.derive_key(u"wormhole:verifier")
|
||||
self._timing.add("key established")
|
||||
|
||||
if not self._send_confirm:
|
||||
return
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = os.urandom(CONFMSG_NONCE_LENGTH)
|
||||
confmsg = make_confmsg(confkey, nonce)
|
||||
self._channel.send(u"_confirm", confmsg)
|
||||
|
||||
@close_on_error
|
||||
def get_verifier(self):
|
||||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
self._get_key()
|
||||
return self.verifier
|
||||
|
||||
@close_on_error
|
||||
def send_data(self, outbound_data, phase=u"data"):
|
||||
if not isinstance(outbound_data, type(b"")):
|
||||
raise TypeError(type(outbound_data))
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if self._closed: raise UsageError
|
||||
if phase in self._sent_data: raise UsageError # only call this once
|
||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
with self._timing.add("API send data", phase=phase):
|
||||
# Without predefined roles, we can't derive predictably unique
|
||||
# keys for each side, so we use the same key for both. We use
|
||||
# random nonces to keep the messages distinct, and the Channel
|
||||
# automatically ignores reflections.
|
||||
self._sent_data.add(phase)
|
||||
self._get_key()
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
outbound_encrypted = self._encrypt_data(data_key, outbound_data)
|
||||
self._channel.send(phase, outbound_encrypted)
|
||||
|
||||
@close_on_error
|
||||
def get_data(self, phase=u"data"):
|
||||
if not isinstance(phase, type(u"")): raise TypeError(type(phase))
|
||||
if phase in self._got_data: raise UsageError # only call this once
|
||||
if phase.startswith(u"_"): raise UsageError # reserved for internals
|
||||
if self._closed: raise UsageError
|
||||
if self.code is None: raise UsageError
|
||||
if self._channel is None: raise UsageError
|
||||
with self._timing.add("API get data", phase=phase):
|
||||
self._got_data.add(phase)
|
||||
self._get_key()
|
||||
phases = []
|
||||
if not self._got_confirmation:
|
||||
phases.append(u"_confirm")
|
||||
phases.append(phase)
|
||||
(got_phase, body) = self._channel.get_first_of(phases)
|
||||
if got_phase == u"_confirm":
|
||||
confkey = self.derive_key(u"wormhole:confirmation")
|
||||
nonce = body[:CONFMSG_NONCE_LENGTH]
|
||||
if body != make_confmsg(confkey, nonce):
|
||||
raise WrongPasswordError
|
||||
self._got_confirmation = True
|
||||
(got_phase, body) = self._channel.get_first_of([phase])
|
||||
assert got_phase == phase
|
||||
try:
|
||||
data_key = self.derive_key(u"wormhole:phase:%s" % phase)
|
||||
inbound_data = self._decrypt_data(data_key, body)
|
||||
return inbound_data
|
||||
except CryptoError:
|
||||
raise WrongPasswordError
|
||||
|
||||
def close(self, mood=u"happy"):
|
||||
if not isinstance(mood, (type(None), type(u""))):
|
||||
raise TypeError(type(mood))
|
||||
self._closed = True
|
||||
if self._channel:
|
||||
self._timing_started.finish(mood=mood)
|
||||
c, self._channel = self._channel, None
|
||||
monitor.close(c)
|
||||
c.deallocate(mood)
|
|
@ -1,446 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import json
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import gatherResults, succeed
|
||||
from twisted.internet.threads import deferToThread
|
||||
from ..blocking.transcribe import (Wormhole, UsageError, ChannelManager,
|
||||
WrongPasswordError)
|
||||
from ..blocking.eventsource import EventSourceFollower
|
||||
from .common import ServerBase
|
||||
|
||||
APPID = u"appid"
|
||||
|
||||
class Channel(ServerBase, unittest.TestCase):
|
||||
def ignore(self, welcome):
|
||||
pass
|
||||
|
||||
def test_allocate(self):
|
||||
cm = ChannelManager(self.relayurl, APPID, u"side", self.ignore)
|
||||
d = deferToThread(cm.list_channels)
|
||||
def _got_channels(channels):
|
||||
self.failUnlessEqual(channels, [])
|
||||
d.addCallback(_got_channels)
|
||||
d.addCallback(lambda _: deferToThread(cm.allocate))
|
||||
def _allocated(channelid):
|
||||
self.failUnlessEqual(type(channelid), int)
|
||||
self._channelid = channelid
|
||||
d.addCallback(_allocated)
|
||||
d.addCallback(lambda _: deferToThread(cm.connect, self._channelid))
|
||||
def _connected(c):
|
||||
self._channel = c
|
||||
d.addCallback(_connected)
|
||||
d.addCallback(lambda _: deferToThread(self._channel.deallocate,
|
||||
u"happy"))
|
||||
return d
|
||||
|
||||
def test_messages(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
|
||||
d.addCallback(lambda _: deferToThread(c2.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1"))
|
||||
d.addCallback(lambda _: deferToThread(c2.send, u"phase1", b"msg2"))
|
||||
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# it's legal to fetch a phase multiple times, should be idempotent
|
||||
d.addCallback(lambda _: deferToThread(c1.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg2"))
|
||||
# deallocating one side is not enough to destroy the channel
|
||||
d.addCallback(lambda _: deferToThread(c2.deallocate))
|
||||
def _not_yet(_):
|
||||
self._rendezvous.prune()
|
||||
self.failUnlessEqual(len(self._rendezvous._apps), 1)
|
||||
d.addCallback(_not_yet)
|
||||
# but deallocating both will make the messages go away
|
||||
d.addCallback(lambda _: deferToThread(c1.deallocate, u"sad"))
|
||||
def _gone(_):
|
||||
self._rendezvous.prune()
|
||||
self.failUnlessEqual(len(self._rendezvous._apps), 0)
|
||||
d.addCallback(_gone)
|
||||
|
||||
return d
|
||||
|
||||
def test_get_multiple_phases(self):
|
||||
cm1 = ChannelManager(self.relayurl, APPID, u"side1", self.ignore)
|
||||
cm2 = ChannelManager(self.relayurl, APPID, u"side2", self.ignore)
|
||||
c1 = cm1.connect(1)
|
||||
c2 = cm2.connect(1)
|
||||
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, u"phase1")
|
||||
self.failUnlessRaises(TypeError, c2.get_first_of, [u"phase1", 7])
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: deferToThread(c1.send, u"phase1", b"msg1"))
|
||||
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
|
||||
u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
|
||||
u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
|
||||
d.addCallback(lambda _: deferToThread(c1.send, u"phase2", b"msg2"))
|
||||
d.addCallback(lambda _: deferToThread(c2.get, u"phase2"))
|
||||
|
||||
# if both are present, it should prefer the first one we asked for
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase1",
|
||||
u"phase2"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase1", b"msg1")))
|
||||
d.addCallback(lambda _: deferToThread(c2.get_first_of, [u"phase2",
|
||||
u"phase1"]))
|
||||
d.addCallback(lambda phase_and_body:
|
||||
self.failUnlessEqual(phase_and_body,
|
||||
(u"phase2", b"msg2")))
|
||||
|
||||
return d
|
||||
|
||||
def test_appid_independence(self):
|
||||
APPID_A = u"appid_A"
|
||||
APPID_B = u"appid_B"
|
||||
cm1a = ChannelManager(self.relayurl, APPID_A, u"side1", self.ignore)
|
||||
cm2a = ChannelManager(self.relayurl, APPID_A, u"side2", self.ignore)
|
||||
c1a = cm1a.connect(1)
|
||||
c2a = cm2a.connect(1)
|
||||
cm1b = ChannelManager(self.relayurl, APPID_B, u"side1", self.ignore)
|
||||
cm2b = ChannelManager(self.relayurl, APPID_B, u"side2", self.ignore)
|
||||
c1b = cm1b.connect(1)
|
||||
c2b = cm2b.connect(1)
|
||||
|
||||
d = succeed(None)
|
||||
d.addCallback(lambda _: deferToThread(c1a.send, u"phase1", b"msg1a"))
|
||||
d.addCallback(lambda _: deferToThread(c1b.send, u"phase1", b"msg1b"))
|
||||
d.addCallback(lambda _: deferToThread(c2a.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1a"))
|
||||
d.addCallback(lambda _: deferToThread(c2b.get, u"phase1"))
|
||||
d.addCallback(lambda msg: self.failUnlessEqual(msg, b"msg1b"))
|
||||
return d
|
||||
|
||||
class _DoBothMixin:
|
||||
def doBoth(self, call1, call2):
|
||||
f1 = call1[0]
|
||||
f1args = call1[1:]
|
||||
f2 = call2[0]
|
||||
f2args = call2[1:]
|
||||
return gatherResults([deferToThread(f1, *f1args),
|
||||
deferToThread(f2, *f2args)], True)
|
||||
|
||||
class Blocking(_DoBothMixin, ServerBase, unittest.TestCase):
|
||||
# we need Twisted to run the server, but we run the sender and receiver
|
||||
# with deferToThread()
|
||||
|
||||
def test_basic(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(w1.get_code)
|
||||
def _got_code(code):
|
||||
w2.set_code(code)
|
||||
return self.doBoth([w1.send_data, b"data1"],
|
||||
[w2.send_data, b"data2"])
|
||||
d.addCallback(_got_code)
|
||||
def _sent(res):
|
||||
return self.doBoth([w1.get_data], [w2.get_data])
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data2")
|
||||
self.assertEqual(dataY, b"data1")
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_same_message(self):
|
||||
# the two sides use random nonces for their messages, so it's ok for
|
||||
# both to try and send the same body: they'll result in distinct
|
||||
# encrypted messages
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(w1.get_code)
|
||||
def _got_code(code):
|
||||
w2.set_code(code)
|
||||
return self.doBoth([w1.send_data, b"data"],
|
||||
[w2.send_data, b"data"])
|
||||
d.addCallback(_got_code)
|
||||
def _sent(res):
|
||||
return self.doBoth([w1.get_data], [w2.get_data])
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data")
|
||||
self.assertEqual(dataY, b"data")
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_interleaved(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(w1.get_code)
|
||||
def _got_code(code):
|
||||
w2.set_code(code)
|
||||
return self.doBoth([w1.send_data, b"data1"],
|
||||
[w2.get_data])
|
||||
d.addCallback(_got_code)
|
||||
def _sent(res):
|
||||
(_, dataY) = res
|
||||
self.assertEqual(dataY, b"data1")
|
||||
return self.doBoth([w1.get_data], [w2.send_data, b"data2"])
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, _) = dl
|
||||
self.assertEqual(dataX, b"data2")
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_fixed_code(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
w1.set_code(u"123-purple-elephant")
|
||||
w2.set_code(u"123-purple-elephant")
|
||||
d = self.doBoth([w1.send_data, b"data1"], [w2.send_data, b"data2"])
|
||||
def _sent(res):
|
||||
return self.doBoth([w1.get_data], [w2.get_data])
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data2")
|
||||
self.assertEqual(dataY, b"data1")
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_phases(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
w1.set_code(u"123-purple-elephant")
|
||||
w2.set_code(u"123-purple-elephant")
|
||||
d = self.doBoth([w1.send_data, b"data1", u"p1"],
|
||||
[w2.send_data, b"data2", u"p1"])
|
||||
d.addCallback(lambda _:
|
||||
self.doBoth([w1.send_data, b"data3", u"p2"],
|
||||
[w2.send_data, b"data4", u"p2"]))
|
||||
d.addCallback(lambda _:
|
||||
self.doBoth([w1.get_data, u"p2"],
|
||||
[w2.get_data, u"p1"]))
|
||||
def _got_1(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data4")
|
||||
self.assertEqual(dataY, b"data1")
|
||||
return self.doBoth([w1.get_data, u"p1"],
|
||||
[w2.get_data, u"p2"])
|
||||
d.addCallback(_got_1)
|
||||
def _got_2(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data2")
|
||||
self.assertEqual(dataY, b"data3")
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_got_2)
|
||||
return d
|
||||
|
||||
def test_wrong_password(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
|
||||
# make sure we can detect WrongPasswordError even if one side only
|
||||
# does get_data() and not send_data(), like "wormhole receive" does
|
||||
d = deferToThread(w1.get_code)
|
||||
d.addCallback(lambda code: w2.set_code(code+"not"))
|
||||
|
||||
# w2 can't throw WrongPasswordError until it sees a CONFIRM message,
|
||||
# and w1 won't send CONFIRM until it sees a PAKE message, which w2
|
||||
# won't send until we call get_data. So we need both sides to be
|
||||
# running at the same time for this test.
|
||||
def _w1_sends():
|
||||
w1.send_data(b"data1")
|
||||
def _w2_gets():
|
||||
self.assertRaises(WrongPasswordError, w2.get_data)
|
||||
d.addCallback(lambda _: self.doBoth([_w1_sends], [_w2_gets]))
|
||||
|
||||
# and now w1 should have enough information to throw too
|
||||
d.addCallback(lambda _: deferToThread(self.assertRaises,
|
||||
WrongPasswordError, w1.get_data))
|
||||
def _done(_):
|
||||
# both sides are closed automatically upon error, but it's still
|
||||
# legal to call .close(), and should be idempotent
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_no_confirm(self):
|
||||
# newer versions (which check confirmations) should will work with
|
||||
# older versions (that don't send confirmations)
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w1._send_confirm = False
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
|
||||
d = deferToThread(w1.get_code)
|
||||
d.addCallback(lambda code: w2.set_code(code))
|
||||
d.addCallback(lambda _: self.doBoth([w1.send_data, b"data1"],
|
||||
[w2.get_data]))
|
||||
d.addCallback(lambda dl: self.assertEqual(dl[1], b"data1"))
|
||||
d.addCallback(lambda _: self.doBoth([w1.get_data],
|
||||
[w2.send_data, b"data2"]))
|
||||
d.addCallback(lambda dl: self.assertEqual(dl[0], b"data2"))
|
||||
d.addCallback(lambda _: self.doBoth([w1.close], [w2.close]))
|
||||
return d
|
||||
|
||||
def test_verifier(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(w1.get_code)
|
||||
def _got_code(code):
|
||||
w2.set_code(code)
|
||||
return self.doBoth([w1.get_verifier], [w2.get_verifier])
|
||||
d.addCallback(_got_code)
|
||||
def _check_verifier(res):
|
||||
v1, v2 = res
|
||||
self.failUnlessEqual(type(v1), type(b""))
|
||||
self.failUnlessEqual(v1, v2)
|
||||
return self.doBoth([w1.send_data, b"data1"],
|
||||
[w2.send_data, b"data2"])
|
||||
d.addCallback(_check_verifier)
|
||||
def _sent(res):
|
||||
return self.doBoth([w1.get_data], [w2.get_data])
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data2")
|
||||
self.assertEqual(dataY, b"data1")
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_verifier_mismatch(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(w1.get_code)
|
||||
def _got_code(code):
|
||||
w2.set_code(code+"not")
|
||||
return self.doBoth([w1.get_verifier], [w2.get_verifier])
|
||||
d.addCallback(_got_code)
|
||||
def _check_verifier(res):
|
||||
v1, v2 = res
|
||||
self.failUnlessEqual(type(v1), type(b""))
|
||||
self.failIfEqual(v1, v2)
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_check_verifier)
|
||||
return d
|
||||
|
||||
def test_errors(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
self.assertRaises(UsageError, w1.get_verifier)
|
||||
self.assertRaises(UsageError, w1.get_data)
|
||||
self.assertRaises(UsageError, w1.send_data, b"data")
|
||||
w1.set_code(u"123-purple-elephant")
|
||||
self.assertRaises(UsageError, w1.set_code, u"123-nope")
|
||||
self.assertRaises(UsageError, w1.get_code)
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(w2.get_code)
|
||||
def _done(code):
|
||||
self.assertRaises(UsageError, w2.get_code)
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_repeat_phases(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
w1.set_code(u"123-purple-elephant")
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
w2.set_code(u"123-purple-elephant")
|
||||
# we must let them establish a key before we can send data
|
||||
d = self.doBoth([w1.get_verifier], [w2.get_verifier])
|
||||
d.addCallback(lambda _:
|
||||
deferToThread(w1.send_data, b"data1", phase=u"1"))
|
||||
def _sent(res):
|
||||
# underscore-prefixed phases are reserved
|
||||
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"_1")
|
||||
self.assertRaises(UsageError, w1.get_data, phase=u"_1")
|
||||
# you can't send twice to the same phase
|
||||
self.assertRaises(UsageError, w1.send_data, b"data1", phase=u"1")
|
||||
# but you can send to a different one
|
||||
return deferToThread(w1.send_data, b"data2", phase=u"2")
|
||||
d.addCallback(_sent)
|
||||
d.addCallback(lambda _: deferToThread(w2.get_data, phase=u"1"))
|
||||
def _got1(res):
|
||||
self.failUnlessEqual(res, b"data1")
|
||||
# and you can't read twice from the same phase
|
||||
self.assertRaises(UsageError, w2.get_data, phase=u"1")
|
||||
# but you can read from a different one
|
||||
return deferToThread(w2.get_data, phase=u"2")
|
||||
d.addCallback(_got1)
|
||||
def _got2(res):
|
||||
self.failUnlessEqual(res, b"data2")
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_got2)
|
||||
return d
|
||||
|
||||
def test_serialize(self):
|
||||
w1 = Wormhole(APPID, self.relayurl)
|
||||
self.assertRaises(UsageError, w1.serialize) # too early
|
||||
w2 = Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(w1.get_code)
|
||||
def _got_code(code):
|
||||
self.assertRaises(UsageError, w2.serialize) # too early
|
||||
w2.set_code(code)
|
||||
w2.serialize() # ok
|
||||
s = w1.serialize()
|
||||
self.assertEqual(type(s), type(""))
|
||||
unpacked = json.loads(s) # this is supposed to be JSON
|
||||
self.assertEqual(type(unpacked), dict)
|
||||
self.new_w1 = Wormhole.from_serialized(s)
|
||||
return self.doBoth([self.new_w1.send_data, b"data1"],
|
||||
[w2.send_data, b"data2"])
|
||||
d.addCallback(_got_code)
|
||||
def _sent(res):
|
||||
return self.doBoth(self.new_w1.get_data(), w2.get_data())
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data2")
|
||||
self.assertEqual(dataY, b"data1")
|
||||
self.assertRaises(UsageError, w2.serialize) # too late
|
||||
return self.doBoth([w1.close], [w2.close])
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
test_serialize.skip = "not yet implemented for the blocking flavor"
|
||||
|
||||
data1 = u"""\
|
||||
event: welcome
|
||||
data: one and a
|
||||
data: two
|
||||
data:.
|
||||
|
||||
data: three
|
||||
|
||||
: this line is ignored
|
||||
event: e2
|
||||
: this line is ignored too
|
||||
i am a dataless field name
|
||||
data: four
|
||||
|
||||
"""
|
||||
|
||||
class NoNetworkESF(EventSourceFollower):
|
||||
def __init__(self, text):
|
||||
self._lines_iter = iter(text.splitlines())
|
||||
|
||||
class EventSourceClient(unittest.TestCase):
|
||||
def test_parser(self):
|
||||
events = []
|
||||
f = NoNetworkESF(data1)
|
||||
events = list(f.iter_events())
|
||||
self.failUnlessEqual(events,
|
||||
[(u"welcome", u"one and a\ntwo\n."),
|
||||
(u"message", u"three"),
|
||||
(u"e2", u"four"),
|
||||
])
|
|
@ -1,57 +0,0 @@
|
|||
from __future__ import print_function
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet.defer import gatherResults
|
||||
from twisted.internet.threads import deferToThread
|
||||
from ..twisted.transcribe import Wormhole as twisted_Wormhole
|
||||
from ..blocking.transcribe import Wormhole as blocking_Wormhole
|
||||
from .common import ServerBase
|
||||
|
||||
# make sure the two implementations (Twisted-style and blocking-style) can
|
||||
# interoperate
|
||||
|
||||
APPID = u"appid"
|
||||
|
||||
class Basic(ServerBase, unittest.TestCase):
|
||||
|
||||
def doBoth(self, call1, d2):
|
||||
f1 = call1[0]
|
||||
f1args = call1[1:]
|
||||
return gatherResults([deferToThread(f1, *f1args), d2], True)
|
||||
|
||||
def test_twisted_to_blocking(self):
|
||||
tw = twisted_Wormhole(APPID, self.relayurl)
|
||||
bw = blocking_Wormhole(APPID, self.relayurl)
|
||||
d = tw.get_code()
|
||||
def _got_code(code):
|
||||
bw.set_code(code)
|
||||
return self.doBoth([bw.send_data, b"data2"], tw.send_data(b"data1"))
|
||||
d.addCallback(_got_code)
|
||||
def _sent(res):
|
||||
return self.doBoth([bw.get_data], tw.get_data())
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data1")
|
||||
self.assertEqual(dataY, b"data2")
|
||||
return self.doBoth([bw.close], tw.close())
|
||||
d.addCallback(_done)
|
||||
return d
|
||||
|
||||
def test_blocking_to_twisted(self):
|
||||
bw = blocking_Wormhole(APPID, self.relayurl)
|
||||
tw = twisted_Wormhole(APPID, self.relayurl)
|
||||
d = deferToThread(bw.get_code)
|
||||
def _got_code(code):
|
||||
tw.set_code(code)
|
||||
return self.doBoth([bw.send_data, b"data1"], tw.send_data(b"data2"))
|
||||
d.addCallback(_got_code)
|
||||
def _sent(res):
|
||||
return self.doBoth([bw.get_data], tw.get_data())
|
||||
d.addCallback(_sent)
|
||||
def _done(dl):
|
||||
(dataX, dataY) = dl
|
||||
self.assertEqual(dataX, b"data2")
|
||||
self.assertEqual(dataY, b"data1")
|
||||
return self.doBoth([bw.close], tw.close())
|
||||
d.addCallback(_done)
|
||||
return d
|
Loading…
Reference in New Issue
Block a user