This commit is contained in:
meejah 2021-04-02 14:58:31 -06:00
parent 99c71112b6
commit f18edc89f9
2 changed files with 17 additions and 9 deletions

View File

@ -594,7 +594,7 @@ class New(unittest.TestCase):
server_protocol = server_factory.buildProtocol(('127.0.0.1', 0)) server_protocol = server_factory.buildProtocol(('127.0.0.1', 0))
class ClientProtocol(protocol.Protocol): class ClientProtocol(protocol.Protocol):
def sendMessage(self, data): def send(self, data):
self.transport.write(data) self.transport.write(data)
def disconnect(self): def disconnect(self):
@ -623,8 +623,13 @@ class New(unittest.TestCase):
ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0)) ws_protocol = ws_factory.buildProtocol(('127.0.0.1', 0))
from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol from autobahn.twisted.websocket import WebSocketClientFactory, WebSocketClientProtocol
class TransitWebSocketClientProtocol(WebSocketClientProtocol):
def send(self, data):
self.sendMessage(data, True)
client_factory = WebSocketClientFactory() client_factory = WebSocketClientFactory()
client_factory.protocol = WebSocketClientProtocol client_factory.protocol = TransitWebSocketClientProtocol
client_factory.protocols = ["binary"] client_factory.protocols = ["binary"]
client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337)) client_protocol = client_factory.buildProtocol(('127.0.0.1', 31337))
client_protocol.disconnect = client_protocol.dropConnection client_protocol.disconnect = client_protocol.dropConnection
@ -639,9 +644,10 @@ class New(unittest.TestCase):
return client_protocol return client_protocol
def test_short(self): def test_short(self):
# XXX this test only makes sense for TCP
p1 = self.new_protocol() p1 = self.new_protocol()
# hang up before sending a complete handshake # hang up before sending a complete handshake
# p1.sendMessage(b"short") # <-- only makes sense for TCP p1.send(b"short")
p1.disconnect() p1.disconnect()
self.flush() self.flush()
@ -659,21 +665,19 @@ class New(unittest.TestCase):
from twisted.internet import reactor from twisted.internet import reactor
print("p1 data") print("p1 data")
p1.sendMessage(handshake(token1, side=side1), True) p1.send(handshake(token1, side=side1))
print("p2 data") print("p2 data")
p2.sendMessage(handshake(token1, side=side2), True) p2.send(handshake(token1, side=side2))
self.flush() self.flush()
print("shouldn't be events yet") print("shouldn't be events yet")
self.assertEqual(self._usage.events, []) # no events yet self.assertEqual(self._usage.events, []) # no events yet
print("p1 moar") print("p1 moar")
for x in range(13): p1.send(b"\x00" * 13)
p1.sendMessage(b"\x00", True)
##p1.sendMessage(b"\x00" * 13)
self.flush() self.flush()
print("p2 moar") print("p2 moar")
p2.sendMessage(b"\xff" * 7, True) p2.send(b"\xff" * 7)
self.flush() self.flush()
print("p1 lose") print("p1 lose")

View File

@ -277,6 +277,10 @@ class WebSocketTransitConnection(WebSocketServerProtocol):
""" """
We may have a 'handshake' on our hands or we may just have some bytes to relay We may have a 'handshake' on our hands or we may just have some bytes to relay
""" """
if not isBinary:
raise ValueError(
"All messages must be binary"
)
# print("onMessage isBinary={}: {}".format(isBinary, payload)) # print("onMessage isBinary={}: {}".format(isBinary, payload))
if self._first_message: if self._first_message:
self._first_message = False self._first_message = False