Skip to content
154 changes: 154 additions & 0 deletions test/asynchronous/test_async_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Async-only unit tests for network_layer.py."""

from __future__ import annotations

import asyncio
import struct
import sys
from unittest.mock import AsyncMock, MagicMock, patch

sys.path[0:0] = [""]

from test.asynchronous import AsyncUnitTest, unittest

from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError
from pymongo.network_layer import PyMongoProtocol, _async_socket_receive


def _make_protocol(timeout=None):
# PyMongoProtocol.__init__ calls asyncio.get_running_loop(), so this helper
# must be called from inside an async test method.
protocol = PyMongoProtocol(timeout=timeout)
mock_transport = MagicMock()
mock_transport.is_closing.return_value = False
protocol.transport = mock_transport
return protocol


def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)


class TestPyMongoProtocol(AsyncUnitTest):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly, the tests in this class that do not perform asynchronous I/O don't need to be async def either.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not addressed.

def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE):
protocol = _make_protocol()
protocol._max_message_size = max_size
protocol._header = memoryview(bytearray(header_bytes))
return protocol

def test_normal_op_msg(self):
header = _make_header(length=32, request_id=1, response_to=99, op_code=2013)
protocol = self._make_proto_with_header(header)
body_len, op_code, response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 16)
self.assertEqual(op_code, 2013)
self.assertEqual(response_to, 99)
self.assertFalse(expecting_compression)

def test_op_compressed(self):
# OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header
# (op code + uncompressed size + compressor id), then the 16-byte standard header.
# length=35 → after compression sub-header: 26 → body: 10
header = _make_header(length=35, request_id=1, response_to=0, op_code=2012)
protocol = self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 10)
self.assertEqual(op_code, 2012)
self.assertTrue(expecting_compression)

def test_op_compressed_length_too_small_raises(self):
header = _make_header(length=25, request_id=1, response_to=0, op_code=2012)
protocol = self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

def test_non_compressed_length_too_small_raises(self):
header = _make_header(length=16, request_id=1, response_to=0, op_code=2013)
protocol = self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

def test_length_exceeds_max_raises(self):
header = _make_header(
length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013
)
protocol = self._make_proto_with_header(header)
with self.assertRaises(ProtocolError):
protocol.process_header()

def test_op_reply_op_code(self):
header = _make_header(length=20, request_id=0, response_to=0, op_code=1)
protocol = self._make_proto_with_header(header)
body_len, op_code, _response_to, expecting_compression = protocol.process_header()
self.assertEqual(body_len, 4)
self.assertEqual(op_code, 1)
self.assertFalse(expecting_compression)

def test_compression_header_snappy_compressor_id(self):
protocol = _make_protocol()
# <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy)
data = struct.pack("<iiB", 2013, 0, 1)
protocol._compression_header = memoryview(bytearray(data))
op_code, compressor_id = protocol.process_compression_header()
self.assertEqual(op_code, 2013)
self.assertEqual(compressor_id, 1)

def test_compression_header_zlib_compressor_id(self):
protocol = _make_protocol()
data = struct.pack("<iiB", 2013, 0, 2)
protocol._compression_header = memoryview(bytearray(data))
_, compressor_id = protocol.process_compression_header()
self.assertEqual(compressor_id, 2)

def test_close_aborts_transport(self):
protocol = _make_protocol()
protocol.close()
self.assertTrue(protocol.transport.abort.called)

def test_connection_lost_twice_does_not_raise(self):
protocol = _make_protocol()
protocol.connection_lost(None)
protocol.connection_lost(None)

async def test_close_with_exception_propagates_to_pending(self):
protocol = _make_protocol()
future = asyncio.get_running_loop().create_future()
protocol._pending_messages.append(future)
exc = OSError("connection reset")
protocol.close(exc)
with self.assertRaisesRegex(OSError, "connection reset"):
await future


class TestAsyncSocketReceive(AsyncUnitTest):
async def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# sock_recv_into returns 0.
mock_socket = MagicMock()
loop = asyncio.get_running_loop()

async def fake_recv_into(sock, buf):
return 0

with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)):
with self.assertRaisesRegex(OSError, "connection closed"):
await _async_socket_receive(mock_socket, 10, loop)


if __name__ == "__main__":
unittest.main()
130 changes: 130 additions & 0 deletions test/test_network_layer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
# Copyright 2026-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sync-only unit tests for network_layer.py.

These cover ``receive_message`` and ``receive_data``, which only exist on the
synchronous receive path (the async path uses ``PyMongoProtocol`` instead).
The async-only tests live in ``test/asynchronous/test_async_network_layer.py``.
"""

from __future__ import annotations

import struct
import sys
from unittest.mock import MagicMock, patch

sys.path[0:0] = [""]

from test import UnitTest, unittest

from pymongo import network_layer
from pymongo.common import MAX_MESSAGE_SIZE
from pymongo.errors import ProtocolError


def _make_header(length, request_id, response_to, op_code):
return struct.pack("<iiii", length, request_id, response_to, op_code)


def _make_compression_header(op_code, uncompressed_size, compressor_id):
return struct.pack("<iiB", op_code, uncompressed_size, compressor_id)


def _make_conn():
conn = MagicMock()
conn.conn.gettimeout.return_value = None
# PyPy calls wait_for_read() before recv_into(), which checks fileno() == -1
# as an early-exit. Without this, sock.fileno() returns a MagicMock and the
# subsequent sock.pending() > 0 comparison raises TypeError on PyPy.
conn.conn.sock.fileno.return_value = -1
return conn


class TestReceiveMessage(UnitTest):
def _patch_receive_data(self, *chunks):
"""Make receive_data return the given byte strings on successive calls."""
mock = patch.object(network_layer, "receive_data", side_effect=list(chunks))
self.addCleanup(mock.stop)
return mock.start()

def test_request_id_mismatch_raises(self):
self._patch_receive_data(
_make_header(length=32, request_id=0, response_to=99, op_code=2013)
)
with self.assertRaises(ProtocolError):
network_layer.receive_message(_make_conn(), request_id=1)

def test_length_too_small_raises(self):
self._patch_receive_data(_make_header(length=16, request_id=0, response_to=0, op_code=2013))
with self.assertRaisesRegex(ProtocolError, "not longer than standard message header"):
network_layer.receive_message(_make_conn(), request_id=None)

def test_length_exceeds_max_raises(self):
self._patch_receive_data(
_make_header(length=MAX_MESSAGE_SIZE + 1, request_id=0, response_to=0, op_code=2013)
)
with self.assertRaisesRegex(ProtocolError, "larger than server max"):
network_layer.receive_message(_make_conn(), request_id=None)

def test_normal_op_msg_unpacks(self):
body = b"x" * 16
self._patch_receive_data(
_make_header(length=32, request_id=0, response_to=0, op_code=2013), body
)
unpack = MagicMock(return_value="REPLY")
with patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}):
result = network_layer.receive_message(_make_conn(), request_id=None)
unpack.assert_called_once_with(body)
self.assertEqual(result, "REPLY")

def test_op_compressed_decompresses(self):
# length=35 -> body length = 35 - 25 = 10 (header 16 + compression sub-header 9).
compressed_body = b"y" * 10
self._patch_receive_data(
_make_header(length=35, request_id=0, response_to=0, op_code=2012),
_make_compression_header(op_code=2013, uncompressed_size=0, compressor_id=1),
compressed_body,
)
unpack = MagicMock(return_value="REPLY")
with (
patch.object(network_layer, "decompress", return_value=b"decompressed") as decompress,
patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}),
):
result = network_layer.receive_message(_make_conn(), request_id=None)
decompress.assert_called_once_with(compressed_body, 1)
unpack.assert_called_once_with(b"decompressed")
self.assertEqual(result, "REPLY")

def test_unknown_opcode_raises(self):
self._patch_receive_data(
_make_header(length=20, request_id=0, response_to=0, op_code=9999), b"data"
)
with patch.object(network_layer, "_UNPACK_REPLY", {2013: MagicMock()}):
with self.assertRaises(ProtocolError):
network_layer.receive_message(_make_conn(), request_id=None)


class TestReceiveData(UnitTest):
def test_raises_on_connection_closed(self):
# Covers the explicit `raise OSError("connection closed")` branch when
# recv_into returns 0.
conn = _make_conn()
conn.conn.recv_into.return_value = 0
with self.assertRaisesRegex(OSError, "connection closed"):
network_layer.receive_data(conn, 10, deadline=None)


if __name__ == "__main__":
unittest.main()
1 change: 1 addition & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def async_only_test(f: str) -> bool:
"test_async_loop_safety.py",
"test_async_contextvars_reset.py",
"test_async_loop_unblocked.py",
"test_async_network_layer.py",
]


Expand Down
Loading