diff --git a/src/wormhole/_dilation/connection.py b/src/wormhole/_dilation/connection.py index 1601b25..d53cdd5 100644 --- a/src/wormhole/_dilation/connection.py +++ b/src/wormhole/_dilation/connection.py @@ -250,7 +250,7 @@ Ping = namedtuple("Ping", ["ping_id"]) # ping_id is arbitrary 4-byte value Pong = namedtuple("Pong", ["ping_id"]) Open = namedtuple("Open", ["seqnum", "scid"]) # seqnum is integer Data = namedtuple("Data", ["seqnum", "scid", "data"]) -Close = namedtuple("Close", ["seqnum", "scid"]) # scid is integer +Close = namedtuple("Close", ["seqnum", "scid"]) # scid is arbitrary 4-byte value Ack = namedtuple("Ack", ["resp_seqnum"]) # resp_seqnum is integer Records = (KCM, Ping, Pong, Open, Data, Close, Ack) Handshake_or_Records = (Handshake,) + Records @@ -275,16 +275,16 @@ def parse_record(plaintext): ping_id = plaintext[1:5] return Pong(ping_id) if msgtype == T_OPEN: - scid = from_be4(plaintext[1:5]) + scid = plaintext[1:5] seqnum = from_be4(plaintext[5:9]) return Open(seqnum, scid) if msgtype == T_DATA: - scid = from_be4(plaintext[1:5]) + scid = plaintext[1:5] seqnum = from_be4(plaintext[5:9]) data = plaintext[9:] return Data(seqnum, scid, data) if msgtype == T_CLOSE: - scid = from_be4(plaintext[1:5]) + scid = plaintext[1:5] seqnum = from_be4(plaintext[5:9]) return Close(seqnum, scid) if msgtype == T_ACK: @@ -302,17 +302,20 @@ def encode_record(r): if isinstance(r, Pong): return b"\x02" + r.ping_id if isinstance(r, Open): - assert isinstance(r.scid, six.integer_types) + assert isinstance(r.scid, bytes) + assert len(r.scid) == 4 assert isinstance(r.seqnum, six.integer_types) - return b"\x03" + to_be4(r.scid) + to_be4(r.seqnum) + return b"\x03" + r.scid + to_be4(r.seqnum) if isinstance(r, Data): - assert isinstance(r.scid, six.integer_types) + assert isinstance(r.scid, bytes) + assert len(r.scid) == 4 assert isinstance(r.seqnum, six.integer_types) - return b"\x04" + to_be4(r.scid) + to_be4(r.seqnum) + r.data + return b"\x04" + r.scid + to_be4(r.seqnum) + r.data if isinstance(r, Close): - assert isinstance(r.scid, six.integer_types) + assert isinstance(r.scid, bytes) + assert len(r.scid) == 4 assert isinstance(r.seqnum, six.integer_types) - return b"\x05" + to_be4(r.scid) + to_be4(r.seqnum) + return b"\x05" + r.scid + to_be4(r.seqnum) if isinstance(r, Ack): assert isinstance(r.resp_seqnum, six.integer_types) return b"\x06" + to_be4(r.resp_seqnum) diff --git a/src/wormhole/_dilation/manager.py b/src/wormhole/_dilation/manager.py index c9b2b47..7d50380 100644 --- a/src/wormhole/_dilation/manager.py +++ b/src/wormhole/_dilation/manager.py @@ -164,12 +164,15 @@ class Manager(object): self._outbound.subchannel_unregisterProducer(sc) def send_open(self, scid): + assert isinstance(scid, bytes) self._queue_and_send(Open, scid) def send_data(self, scid, data): + assert isinstance(scid, bytes) self._queue_and_send(Data, scid, data) def send_close(self, scid): + assert isinstance(scid, bytes) self._queue_and_send(Close, scid) def _queue_and_send(self, record_type, *args): @@ -528,7 +531,7 @@ class Dilator(object): yield self._manager.when_first_connected() # we can open subchannels as soon as we get our first connection - scid0 = b"\x00\x00\x00\x00" + scid0 = to_be4(0) self._host_addr = _WormholeAddress() # TODO: share with Manager peer_addr0 = _SubchannelAddress(scid0) control_ep = ControlEndpoint(peer_addr0) diff --git a/src/wormhole/_dilation/subchannel.py b/src/wormhole/_dilation/subchannel.py index abd1939..ddcc856 100644 --- a/src/wormhole/_dilation/subchannel.py +++ b/src/wormhole/_dilation/subchannel.py @@ -55,7 +55,7 @@ class _WormholeAddress(object): @implementer(IAddress) @attrs class _SubchannelAddress(object): - _scid = attrib() + _scid = attrib(validator=instance_of(bytes)) @attrs diff --git a/src/wormhole/test/dilate/test_connection.py b/src/wormhole/test/dilate/test_connection.py index 07eab68..45c71e1 100644 --- a/src/wormhole/test/dilate/test_connection.py +++ b/src/wormhole/test/dilate/test_connection.py @@ -9,6 +9,7 @@ from ..._interfaces import IDilationConnector from ..._dilation.roles import LEADER, FOLLOWER from ..._dilation.connection import (DilatedConnectionProtocol, encode_record, KCM, Open, Ack) +from ..._dilation.encode import to_be4 from .common import clear_mock_calls @@ -56,7 +57,7 @@ class Connection(unittest.TestCase): def _test_no_relay(self, role): c, n, connector, t, eq = make_con(role) t_kcm = KCM() - t_open = Open(seqnum=1, scid=0x11223344) + t_open = Open(seqnum=1, scid=to_be4(0x11223344)) t_ack = Ack(resp_seqnum=2) n.decrypt = mock.Mock(side_effect=[ encode_record(t_kcm), diff --git a/src/wormhole/test/dilate/test_parse.py b/src/wormhole/test/dilate/test_parse.py index f7276a6..f40c661 100644 --- a/src/wormhole/test/dilate/test_parse.py +++ b/src/wormhole/test/dilate/test_parse.py @@ -13,11 +13,11 @@ class Parse(unittest.TestCase): self.assertEqual(parse_record(b"\x02\x55\x44\x33\x22"), Pong(ping_id=b"\x55\x44\x33\x22")) self.assertEqual(parse_record(b"\x03\x00\x00\x02\x01\x00\x00\x01\x00"), - Open(scid=513, seqnum=256)) + Open(scid=b"\x00\x00\x02\x01", seqnum=256)) self.assertEqual(parse_record(b"\x04\x00\x00\x02\x02\x00\x00\x01\x01dataaa"), - Data(scid=514, seqnum=257, data=b"dataaa")) + Data(scid=b"\x00\x00\x02\x02", seqnum=257, data=b"dataaa")) self.assertEqual(parse_record(b"\x05\x00\x00\x02\x03\x00\x00\x01\x02"), - Close(scid=515, seqnum=258)) + Close(scid=b"\x00\x00\x02\x03", seqnum=258)) self.assertEqual(parse_record(b"\x06\x00\x00\x01\x03"), Ack(resp_seqnum=259)) with mock.patch("wormhole._dilation.connection.log.err") as le: @@ -31,11 +31,11 @@ class Parse(unittest.TestCase): self.assertEqual(encode_record(KCM()), b"\x00") self.assertEqual(encode_record(Ping(ping_id=b"ping")), b"\x01ping") self.assertEqual(encode_record(Pong(ping_id=b"pong")), b"\x02pong") - self.assertEqual(encode_record(Open(scid=65536, seqnum=16)), + self.assertEqual(encode_record(Open(scid=b"\x00\x01\x00\x00", seqnum=16)), b"\x03\x00\x01\x00\x00\x00\x00\x00\x10") - self.assertEqual(encode_record(Data(scid=65537, seqnum=17, data=b"dataaa")), + self.assertEqual(encode_record(Data(scid=b"\x00\x01\x00\x01", seqnum=17, data=b"dataaa")), b"\x04\x00\x01\x00\x01\x00\x00\x00\x11dataaa") - self.assertEqual(encode_record(Close(scid=65538, seqnum=18)), + self.assertEqual(encode_record(Close(scid=b"\x00\x01\x00\x02", seqnum=18)), b"\x05\x00\x01\x00\x02\x00\x00\x00\x12") self.assertEqual(encode_record(Ack(resp_seqnum=19)), b"\x06\x00\x00\x00\x13")