scids are four-byte strings, not integers

be consistent about it
This commit is contained in:
Brian Warner 2019-02-10 16:23:20 -08:00
parent 2ec7b8e662
commit 74c416517f
5 changed files with 26 additions and 19 deletions

View File

@ -250,7 +250,7 @@ Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value
Pong = namedtuple("Pong", ["ping_id"])
Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer
Data = namedtuple("Data", ["seqnum", "scid", "data"])
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer
Close = namedtuple("Close", ["seqnum", "scid"]) # scid is arbitrary 4-byte value
Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer
Records = (KCM, Ping, Pong, Open, Data, Close, Ack)
Handshake_or_Records = (Handshake,) + Records
@ -275,16 +275,16 @@ def parse_record(plaintext):
ping_id = plaintext[1:5]
return Pong(ping_id)
if msgtype == T_OPEN:
scid = from_be4(plaintext[1:5])
scid = plaintext[1:5]
seqnum = from_be4(plaintext[5:9])
return Open(seqnum, scid)
if msgtype == T_DATA:
scid = from_be4(plaintext[1:5])
scid = plaintext[1:5]
seqnum = from_be4(plaintext[5:9])
data = plaintext[9:]
return Data(seqnum, scid, data)
if msgtype == T_CLOSE:
scid = from_be4(plaintext[1:5])
scid = plaintext[1:5]
seqnum = from_be4(plaintext[5:9])
return Close(seqnum, scid)
if msgtype == T_ACK:
@ -302,17 +302,20 @@ def encode_record(r):
if isinstance(r, Pong):
return b"\x02" + r.ping_id
if isinstance(r, Open):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.seqnum, six.integer_types)
return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum)
return b"\x03" + r.scid + to_be4(r.seqnum)
if isinstance(r, Data):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.seqnum, six.integer_types)
return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data
return b"\x04" + r.scid + to_be4(r.seqnum) + r.data
if isinstance(r, Close):
assert isinstance(r.scid, six.integer_types)
assert isinstance(r.scid, bytes)
assert len(r.scid) == 4
assert isinstance(r.seqnum, six.integer_types)
return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum)
return b"\x05" + r.scid + to_be4(r.seqnum)
if isinstance(r, Ack):
assert isinstance(r.resp_seqnum, six.integer_types)
return b"\x06" + to_be4(r.resp_seqnum)

View File

@ -164,12 +164,15 @@ class Manager(object):
self._outbound.subchannel_unregisterProducer(sc)
def send_open(self, scid):
assert isinstance(scid, bytes)
self._queue_and_send(Open, scid)
def send_data(self, scid, data):
assert isinstance(scid, bytes)
self._queue_and_send(Data, scid, data)
def send_close(self, scid):
assert isinstance(scid, bytes)
self._queue_and_send(Close, scid)
def _queue_and_send(self, record_type, *args):
@ -528,7 +531,7 @@ class Dilator(object):
yield self._manager.when_first_connected()
# we can open subchannels as soon as we get our first connection
scid0 = b"\x00\x00\x00\x00"
scid0 = to_be4(0)
self._host_addr = _WormholeAddress() # TODO: share with Manager
peer_addr0 = _SubchannelAddress(scid0)
control_ep = ControlEndpoint(peer_addr0)

View File

@ -55,7 +55,7 @@ class _WormholeAddress(object):
@implementer(IAddress)
@attrs
class _SubchannelAddress(object):
_scid = attrib()
_scid = attrib(validator=instance_of(bytes))
@attrs

View File

@ -9,6 +9,7 @@ from ..._interfaces import IDilationConnector
from ..._dilation.roles import LEADER, FOLLOWER
from ..._dilation.connection import (DilatedConnectionProtocol, encode_record,
KCM, Open, Ack)
from ..._dilation.encode import to_be4
from .common import clear_mock_calls
@ -56,7 +57,7 @@ class Connection(unittest.TestCase):
def _test_no_relay(self, role):
c, n, connector, t, eq = make_con(role)
t_kcm = KCM()
t_open = Open(seqnum=1, scid=0x11223344)
t_open = Open(seqnum=1, scid=to_be4(0x11223344))
t_ack = Ack(resp_seqnum=2)
n.decrypt = mock.Mock(side_effect=[
encode_record(t_kcm),

View File

@ -13,11 +13,11 @@ class Parse(unittest.TestCase):
self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"),
Pong(ping_id=b"\x55\x44\x33\x22"))
self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"),
Open(scid=513, seqnum=256))
Open(scid=b"\x00\x00\x02\x01", seqnum=256))
self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"),
Data(scid=514, seqnum=257, data=b"dataaa"))
Data(scid=b"\x00\x00\x02\x02", seqnum=257, data=b"dataaa"))
self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"),
Close(scid=515, seqnum=258))
Close(scid=b"\x00\x00\x02\x03", seqnum=258))
self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"),
Ack(resp_seqnum=259))
with mock.patch("wormhole._dilation.connection.log.err") as le:
@ -31,11 +31,11 @@ class Parse(unittest.TestCase):
self.assertEqual(encode_record(KCM()), b"\x00")
self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping")
self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong")
self.assertEqual(encode_record(Open(scid=65536, seqnum=16)),
self.assertEqual(encode_record(Open(scid=b"\x00\x01\x00\x00", seqnum=16)),
b"\x03\x00\x01\x00\x00\x00\x00\x00\x10")
self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")),
self.assertEqual(encode_record(Data(scid=b"\x00\x01\x00\x01", seqnum=17, data=b"dataaa")),
b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa")
self.assertEqual(encode_record(Close(scid=65538, seqnum=18)),
self.assertEqual(encode_record(Close(scid=b"\x00\x01\x00\x02", seqnum=18)),
b"\x05\x00\x01\x00\x02\x00\x00\x00\x12")
self.assertEqual(encode_record(Ack(resp_seqnum=19)),
b"\x06\x00\x00\x00\x13")