make SubChannel IDs integers, not 4-bytes

I'm sure I had a good reason for avoiding integers, but it makes logging and
testing more difficult, and both sides are using integers to generate them
anyways (so one side can pick the odd ones, and the other can pick the even
ones).
This commit is contained in:
Brian Warner 2019-07-06 01:10:34 -07:00
parent a74cc99e6a
commit 8043e508fa
10 changed files with 53 additions and 50 deletions

View File

@ -254,7 +254,7 @@ Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
Pong = namedtuple("Pong", ["ping_id"])
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
Data = namedtuple("Data", ["seqnum", "scid", "data"])
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is arbitrary 4-byte value
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
Handshake_or_Records = (Handshake,) + Records
@ -279,16 +279,16 @@ def parse_record(plaintext):
ping_id = plaintext[1:5]
return Pong(ping_id)
if msgtype == T_OPEN:
scid = plaintext[1:5]
scid = from_be4(plaintext[1:5])
seqnum = from_be4(plaintext[5:9])
return Open(seqnum, scid)
if msgtype == T_DATA:
scid = plaintext[1:5]
scid = from_be4(plaintext[1:5])
seqnum = from_be4(plaintext[5:9])
data = plaintext[9:]
return Data(seqnum, scid, data)
if msgtype == T_CLOSE:
scid = plaintext[1:5]
scid = from_be4(plaintext[1:5])
seqnum = from_be4(plaintext[5:9])
return Close(seqnum, scid)
if msgtype == T_ACK:
@ -306,20 +306,17 @@ def encode_record(r):
if isinstance(r, Pong):
return b"\x02" + r.ping_id
if isinstance(r, Open):
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.seqnum, six.integer_types)
return b"\x03" + r.scid + to_be4(r.seqnum)
return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
if isinstance(r, Data):
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.seqnum, six.integer_types)
return b"\x04" + r.scid + to_be4(r.seqnum) + r.data
return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
if isinstance(r, Close):
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.seqnum, six.integer_types)
return b"\x05" + r.scid + to_be4(r.seqnum)
return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
if isinstance(r, Ack):
assert isinstance(r.resp_seqnum, six.integer_types)
return b"\x06" + to_be4(r.resp_seqnum)

View File

@ -65,6 +65,7 @@ class Inbound(object):
seqnum)
def handle_open(self, scid):
log.msg("inbound.handle_open", scid)
if scid in self._open_subchannels:
log.err(DuplicateOpenError(
"received duplicate OPEN for {}".format(scid)))
@ -75,6 +76,7 @@ class Inbound(object):
self._listener_endpoint._got_open(sc, peer_addr)
def handle_data(self, scid, data):
log.msg("inbound.handle_data", scid, len(data))
sc = self._open_subchannels.get(scid)
if sc is None:
log.err(DataForMissingSubchannelError(
@ -83,6 +85,7 @@ class Inbound(object):
sc.remote_data(data)
def handle_close(self, scid):
log.msg("inbound.handle_close", scid)
sc = self._open_subchannels.get(scid)
if sc is None:
log.err(CloseForMissingSubchannelError(

View File

@ -1,4 +1,5 @@
from __future__ import print_function, unicode_literals
import six
import os
from collections import deque
from attr import attrs, attrib
@ -12,7 +13,6 @@ from .._interfaces import IDilator, IDilationManager, ISend, ITerminator
from ..util import dict_to_bytes, bytes_to_dict, bytes_to_hexstr
from ..observer import OneShotObserver
from .._key import derive_key
from .encode import to_be4
from .subchannel import (SubChannel, _SubchannelAddress, _WormholeAddress,
ControlEndpoint, SubchannelConnectorEndpoint,
SubchannelListenerEndpoint)
@ -166,15 +166,15 @@ class Manager(object):
self._outbound.subchannel_unregisterProducer(sc)
def send_open(self, scid):
assert isinstance(scid, bytes)
assert isinstance(scid, six.integer_types)
self._queue_and_send(Open, scid)
def send_data(self, scid, data):
assert isinstance(scid, bytes)
assert isinstance(scid, six.integer_types)
self._queue_and_send(Data, scid, data)
def send_close(self, scid):
assert isinstance(scid, bytes)
assert isinstance(scid, six.integer_types)
self._queue_and_send(Close, scid)
def _queue_and_send(self, record_type, *args):
@ -265,7 +265,7 @@ class Manager(object):
def allocate_subchannel_id(self):
scid_num = self._next_subchannel_id
self._next_subchannel_id += 2
return to_be4(scid_num)
return scid_num
# state machine
@ -534,7 +534,7 @@ class Dilator(object):
# quickly once the connection is established. This subchannel may or
# may not ever get revealed to the caller, since the peer might not
# even be capable of dilation.
scid0 = to_be4(0)
scid0 = 0
peer_addr0 = _SubchannelAddress(scid0)
sc0 = SubChannel(scid0, self._manager, self._host_addr, peer_addr0)
self._manager.set_subchannel_zero(scid0, sc0)

View File

@ -1,3 +1,4 @@
import six
from attr import attrs, attrib
from attr.validators import instance_of, provides
from zope.interface import implementer
@ -55,7 +56,7 @@ class _WormholeAddress(object):
@implementer(IAddress)
@attrs
class _SubchannelAddress(object):
_scid = attrib(validator=instance_of(bytes))
_scid = attrib(validator=instance_of(six.integer_types))
@attrs
@ -64,7 +65,7 @@ class _SubchannelAddress(object):
@implementer(IConsumer)
@implementer(ISubChannel)
class SubChannel(object):
_id = attrib(validator=instance_of(bytes))
_scid = attrib(validator=instance_of(six.integer_types))
_manager = attrib(validator=provides(IDilationManager))
_host_addr = attrib(validator=instance_of(_WormholeAddress))
_peer_addr = attrib(validator=instance_of(_SubchannelAddress))
@ -111,11 +112,11 @@ class SubChannel(object):
@m.output()
def send_data(self, data):
self._manager.send_data(self._id, data)
self._manager.send_data(self._scid, data)
@m.output()
def send_close(self):
self._manager.send_close(self._id)
self._manager.send_close(self._scid)
@m.output()
def signal_dataReceived(self, data):

View File

@ -9,7 +9,6 @@ from ..._interfaces import IDilationConnector
from ..._dilation.roles import LEADER, FOLLOWER
from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
KCM, Open, Ack)
from ..._dilation.encode import to_be4
from .common import clear_mock_calls
@ -57,7 +56,7 @@ class Connection(unittest.TestCase):
def _test_no_relay(self, role):
c, n, connector, t, eq = make_con(role)
t_kcm = KCM()
t_open = Open(seqnum=1, scid=to_be4(0x11223344))
t_open = Open(seqnum=1, scid=0x11223344)
t_ack = Ack(resp_seqnum=2)
n.decrypt = mock.Mock(side_effect=[
encode_record(t_kcm),
@ -237,7 +236,7 @@ class Connection(unittest.TestCase):
def test_follower_combined(self):
c, n, connector, t, eq = make_con(FOLLOWER)
t_kcm = KCM()
t_open = Open(seqnum=1, scid=to_be4(0x11223344))
t_open = Open(seqnum=1, scid=0x11223344)
n.decrypt = mock.Mock(side_effect=[
encode_record(t_kcm),
encode_record(t_open),

View File

@ -14,7 +14,7 @@ from .common import mock_manager
class Endpoints(unittest.TestCase):
def test_control(self):
scid0 = b"scid0"
scid0 = 0
peeraddr = _SubchannelAddress(scid0)
ep = ControlEndpoint(peeraddr)
@ -43,9 +43,9 @@ class Endpoints(unittest.TestCase):
def test_connector(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(b"scid")
peeraddr = _SubchannelAddress(0)
ep = SubchannelConnectorEndpoint(m, hostaddr)
f = mock.Mock()
@ -57,13 +57,13 @@ class Endpoints(unittest.TestCase):
d = ep.connect(f)
self.assertIdentical(self.successResultOf(d), p)
self.assertEqual(f.buildProtocol.mock_calls, [mock.call(peeraddr)])
self.assertEqual(sc.mock_calls, [mock.call(b"scid", m, hostaddr, peeraddr)])
self.assertEqual(sc.mock_calls, [mock.call(0, m, hostaddr, peeraddr)])
self.assertEqual(t.mock_calls, [mock.call._set_protocol(p)])
self.assertEqual(p.mock_calls, [mock.call.makeConnection(t)])
def test_listener(self):
m = mock_manager()
m.allocate_subchannel_id = mock.Mock(return_value=b"scid")
m.allocate_subchannel_id = mock.Mock(return_value=0)
hostaddr = _WormholeAddress()
ep = SubchannelListenerEndpoint(m, hostaddr)
@ -75,7 +75,7 @@ class Endpoints(unittest.TestCase):
# OPEN that arrives before we ep.listen() should be queued
t1 = mock.Mock()
peeraddr1 = _SubchannelAddress(b"peer1")
peeraddr1 = _SubchannelAddress(1)
ep._got_open(t1, peeraddr1)
d = ep.listen(f)
@ -89,7 +89,7 @@ class Endpoints(unittest.TestCase):
self.assertEqual(p1.mock_calls, [mock.call.makeConnection(t1)])
t2 = mock.Mock()
peeraddr2 = _SubchannelAddress(b"peer2")
peeraddr2 = _SubchannelAddress(2)
ep._got_open(t2, peeraddr2)
self.assertEqual(t2.mock_calls, [mock.call._set_protocol(p2)])

View File

@ -35,7 +35,7 @@ class Full(ServerBase, unittest.TestCase):
yield self._setup_relay(None)
@inlineCallbacks
def test_full(self):
def test_control(self):
eq = EventualQueue(reactor)
w1 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
w2 = wormhole.create(APPID, self.relayurl, reactor, _enable_dilate=True)
@ -67,6 +67,11 @@ class Full(ServerBase, unittest.TestCase):
yield d1
yield d2
print("control endpoints connected")
# note: I'm making some horrible assumptions about one-to-one writes
# and reads across a TCP stack that isn't obligated to maintain such
# a relationship, but it's much easier than doing this properly. If
# the tests ever start failing, do the extra work, probably by
# using a twisted.protocols.basic.LineOnlyReceiver
data1 = yield f1.d
data2 = yield f2.d
self.assertEqual(data1, b"hello\n")
@ -74,8 +79,7 @@ class Full(ServerBase, unittest.TestCase):
yield w1.close()
yield w2.close()
test_full.timeout = 30
test_control.timeout = 30
class ReconP(Protocol):

View File

@ -9,7 +9,6 @@ from ...eventual import EventualQueue
from ..._interfaces import ISend, IDilationManager, ITerminator
from ...util import dict_to_bytes
from ..._dilation import roles
from ..._dilation.encode import to_be4
from ..._dilation.manager import (Dilator, Manager, make_side,
OldPeerCannotDilateError,
UnknownDilationMessageType,
@ -64,7 +63,7 @@ class TestDilator(unittest.TestCase):
sc = mock.Mock()
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
return_value=sc)
scid0 = b"\x00\x00\x00\x00"
scid0 = 0
m = mock.Mock()
alsoProvides(m, IDilationManager)
@ -178,7 +177,7 @@ class TestDilator(unittest.TestCase):
alsoProvides(m, IDilationManager)
m.when_first_connected.return_value = Deferred()
scid0 = b"\x00\x00\x00\x00"
scid0 = 0
sc = mock.Mock()
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
return_value=sc)
@ -205,7 +204,7 @@ class TestDilator(unittest.TestCase):
d1 = dil.dilate(transit_relay_location=relay)
self.assertNoResult(d1)
scid0 = b"\x00\x00\x00\x00"
scid0 = 0
sc = mock.Mock()
m_sc = mock.patch("wormhole._dilation.manager.SubChannel",
return_value=sc)
@ -338,7 +337,7 @@ class TestManager(unittest.TestCase):
h.eq.flush_sync()
self.successResultOf(wfc_d2)
scid0 = b"\x00\x00\x00\x00"
scid0 = 0
sc0 = mock.Mock()
m.set_subchannel_zero(scid0, sc0)
listen_ep = mock.Mock()
@ -350,7 +349,7 @@ class TestManager(unittest.TestCase):
clear_mock_calls(h.inbound)
# the Leader making a new outbound channel should get scid=1
scid1 = to_be4(1)
scid1 = 1
self.assertEqual(m.allocate_subchannel_id(), scid1)
r1 = Open(10, scid1) # seqnum=10
h.outbound.build_record = mock.Mock(return_value=r1)
@ -388,7 +387,7 @@ class TestManager(unittest.TestCase):
# test that inbound records get acked and routed to Inbound
h.inbound.is_record_old = mock.Mock(return_value=False)
scid2 = to_be4(2)
scid2 = 2
o200 = Open(200, scid2)
m.got_record(o200)
self.assertEqual(h.outbound.mock_calls, [

View File

@ -13,11 +13,11 @@ class Parse(unittest.TestCase):
self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
Pong(ping_id=b"\x55\x44\x33\x22"))
self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
Open(scid=b"\x00\x00\x02\x01", seqnum=256))
Open(scid=513, seqnum=256))
self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
Data(scid=b"\x00\x00\x02\x02", seqnum=257, data=b"dataaa"))
Data(scid=514, seqnum=257, data=b"dataaa"))
self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
Close(scid=b"\x00\x00\x02\x03", seqnum=258))
Close(scid=515, seqnum=258))
self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
Ack(resp_seqnum=259))
with mock.patch("wormhole._dilation.connection.log.err") as le:
@ -31,11 +31,11 @@ class Parse(unittest.TestCase):
self.assertEqual(encode_record(KCM()), b"\x00")
self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
self.assertEqual(encode_record(Open(scid=b"\x00\x01\x00\x00", seqnum=16)),
self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
self.assertEqual(encode_record(Data(scid=b"\x00\x01\x00\x01", seqnum=17, data=b"dataaa")),
self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
self.assertEqual(encode_record(Close(scid=b"\x00\x01\x00\x02", seqnum=18)),
self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
self.assertEqual(encode_record(Ack(resp_seqnum=19)),
b"\x06\x00\x00\x00\x13")

View File

@ -10,7 +10,7 @@ from .common import mock_manager
def make_sc(set_protocol=True):
scid = b"scid"
scid = 4
hostaddr = _WormholeAddress()
peeraddr = _SubchannelAddress(scid)
m = mock_manager()