-
Notifications
You must be signed in to change notification settings - Fork 1.2k
PYTHON-5781 Coverage increase for network_layer.py
#2774
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
aclark4life
wants to merge
12
commits into
mongodb:master
Choose a base branch
from
aclark4life:PYTHON-5781
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+285
−0
Open
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
0b861f7
PYTHON-5781 Increase coverage for network_layer.py
aclark4life a3db870
PYTHON-5781 Nit fixes
aclark4life c156316
PYTHON-5781 Async split for test_network_layer.py
aclark4life 26b46b9
Address Noah review by proxy
aclark4life 15a3581
Rename vars
aclark4life ff58841
Update OP_COMPRESSED comment
aclark4life 61326cd
Noah + Copilot review
aclark4life dd67fbf
PYTHON-5781 Drop trivial tests, add sync-only receive_message/receive…
aclark4life 5ba827c
PYTHON-5781 Fix TestReceiveData failures on PyPy
aclark4life 776cc06
Noah feedback
aclark4life 347b2fc
Noah feedback
aclark4life 1821730
Merge branch 'master' into PYTHON-5781
aclark4life File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,154 @@ | ||
| # Copyright 2026-present MongoDB, Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Async-only unit tests for network_layer.py.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import asyncio | ||
| import struct | ||
| import sys | ||
| from unittest.mock import AsyncMock, MagicMock, patch | ||
|
|
||
| sys.path[0:0] = [""] | ||
|
|
||
| from test.asynchronous import AsyncUnitTest, unittest | ||
|
|
||
| from pymongo.common import MAX_MESSAGE_SIZE | ||
| from pymongo.errors import ProtocolError | ||
| from pymongo.network_layer import PyMongoProtocol, _async_socket_receive | ||
|
|
||
|
|
||
| def _make_protocol(timeout=None): | ||
| # PyMongoProtocol.__init__ calls asyncio.get_running_loop(), so this helper | ||
| # must be called from inside an async test method. | ||
| protocol = PyMongoProtocol(timeout=timeout) | ||
| mock_transport = MagicMock() | ||
| mock_transport.is_closing.return_value = False | ||
| protocol.transport = mock_transport | ||
| return protocol | ||
|
|
||
|
|
||
| def _make_header(length, request_id, response_to, op_code): | ||
| return struct.pack("<iiii", length, request_id, response_to, op_code) | ||
|
|
||
|
|
||
| class TestPyMongoProtocol(AsyncUnitTest): | ||
| def _make_proto_with_header(self, header_bytes, max_size=MAX_MESSAGE_SIZE): | ||
| protocol = _make_protocol() | ||
| protocol._max_message_size = max_size | ||
| protocol._header = memoryview(bytearray(header_bytes)) | ||
| return protocol | ||
|
|
||
| def test_normal_op_msg(self): | ||
| header = _make_header(length=32, request_id=1, response_to=99, op_code=2013) | ||
| protocol = self._make_proto_with_header(header) | ||
| body_len, op_code, response_to, expecting_compression = protocol.process_header() | ||
| self.assertEqual(body_len, 16) | ||
| self.assertEqual(op_code, 2013) | ||
| self.assertEqual(response_to, 99) | ||
| self.assertFalse(expecting_compression) | ||
|
|
||
| def test_op_compressed(self): | ||
| # OP_COMPRESSED=2012; process_header strips the 9-byte compression sub-header | ||
| # (op code + uncompressed size + compressor id), then the 16-byte standard header. | ||
| # length=35 → after compression sub-header: 26 → body: 10 | ||
| header = _make_header(length=35, request_id=1, response_to=0, op_code=2012) | ||
| protocol = self._make_proto_with_header(header) | ||
| body_len, op_code, _response_to, expecting_compression = protocol.process_header() | ||
| self.assertEqual(body_len, 10) | ||
| self.assertEqual(op_code, 2012) | ||
| self.assertTrue(expecting_compression) | ||
|
|
||
| def test_op_compressed_length_too_small_raises(self): | ||
| header = _make_header(length=25, request_id=1, response_to=0, op_code=2012) | ||
| protocol = self._make_proto_with_header(header) | ||
| with self.assertRaises(ProtocolError): | ||
| protocol.process_header() | ||
|
|
||
| def test_non_compressed_length_too_small_raises(self): | ||
| header = _make_header(length=16, request_id=1, response_to=0, op_code=2013) | ||
| protocol = self._make_proto_with_header(header) | ||
| with self.assertRaises(ProtocolError): | ||
| protocol.process_header() | ||
|
|
||
| def test_length_exceeds_max_raises(self): | ||
| header = _make_header( | ||
| length=MAX_MESSAGE_SIZE + 1, request_id=1, response_to=0, op_code=2013 | ||
| ) | ||
| protocol = self._make_proto_with_header(header) | ||
| with self.assertRaises(ProtocolError): | ||
| protocol.process_header() | ||
|
|
||
| def test_op_reply_op_code(self): | ||
| header = _make_header(length=20, request_id=0, response_to=0, op_code=1) | ||
| protocol = self._make_proto_with_header(header) | ||
| body_len, op_code, _response_to, expecting_compression = protocol.process_header() | ||
| self.assertEqual(body_len, 4) | ||
| self.assertEqual(op_code, 1) | ||
| self.assertFalse(expecting_compression) | ||
|
|
||
| def test_compression_header_snappy_compressor_id(self): | ||
| protocol = _make_protocol() | ||
| # <iiB: little-endian, i32 op code=2013, i32 uncompressed size=0, u8 compressor id=1 (snappy) | ||
| data = struct.pack("<iiB", 2013, 0, 1) | ||
| protocol._compression_header = memoryview(bytearray(data)) | ||
| op_code, compressor_id = protocol.process_compression_header() | ||
| self.assertEqual(op_code, 2013) | ||
| self.assertEqual(compressor_id, 1) | ||
|
|
||
| def test_compression_header_zlib_compressor_id(self): | ||
| protocol = _make_protocol() | ||
| data = struct.pack("<iiB", 2013, 0, 2) | ||
| protocol._compression_header = memoryview(bytearray(data)) | ||
| _, compressor_id = protocol.process_compression_header() | ||
| self.assertEqual(compressor_id, 2) | ||
|
|
||
| def test_close_aborts_transport(self): | ||
| protocol = _make_protocol() | ||
| protocol.close() | ||
| self.assertTrue(protocol.transport.abort.called) | ||
|
|
||
| def test_connection_lost_twice_does_not_raise(self): | ||
| protocol = _make_protocol() | ||
| protocol.connection_lost(None) | ||
| protocol.connection_lost(None) | ||
|
|
||
| async def test_close_with_exception_propagates_to_pending(self): | ||
| protocol = _make_protocol() | ||
| future = asyncio.get_running_loop().create_future() | ||
| protocol._pending_messages.append(future) | ||
| exc = OSError("connection reset") | ||
| protocol.close(exc) | ||
| with self.assertRaisesRegex(OSError, "connection reset"): | ||
| await future | ||
|
|
||
|
|
||
| class TestAsyncSocketReceive(AsyncUnitTest): | ||
| async def test_raises_on_connection_closed(self): | ||
| # Covers the explicit `raise OSError("connection closed")` branch when | ||
| # sock_recv_into returns 0. | ||
| mock_socket = MagicMock() | ||
| loop = asyncio.get_running_loop() | ||
|
|
||
| async def fake_recv_into(sock, buf): | ||
| return 0 | ||
|
|
||
| with patch.object(loop, "sock_recv_into", new=AsyncMock(side_effect=fake_recv_into)): | ||
| with self.assertRaisesRegex(OSError, "connection closed"): | ||
| await _async_socket_receive(mock_socket, 10, loop) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,130 @@ | ||
| # Copyright 2026-present MongoDB, Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # https://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| """Sync-only unit tests for network_layer.py. | ||
|
|
||
| These cover ``receive_message`` and ``receive_data``, which only exist on the | ||
| synchronous receive path (the async path uses ``PyMongoProtocol`` instead). | ||
| The async-only tests live in ``test/asynchronous/test_async_network_layer.py``. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import struct | ||
| import sys | ||
| from unittest.mock import MagicMock, patch | ||
|
|
||
| sys.path[0:0] = [""] | ||
|
|
||
| from test import UnitTest, unittest | ||
|
|
||
| from pymongo import network_layer | ||
| from pymongo.common import MAX_MESSAGE_SIZE | ||
| from pymongo.errors import ProtocolError | ||
|
|
||
|
|
||
| def _make_header(length, request_id, response_to, op_code): | ||
| return struct.pack("<iiii", length, request_id, response_to, op_code) | ||
|
|
||
|
|
||
| def _make_compression_header(op_code, uncompressed_size, compressor_id): | ||
| return struct.pack("<iiB", op_code, uncompressed_size, compressor_id) | ||
|
|
||
|
|
||
| def _make_conn(): | ||
| conn = MagicMock() | ||
| conn.conn.gettimeout.return_value = None | ||
| # PyPy calls wait_for_read() before recv_into(), which checks fileno() == -1 | ||
| # as an early-exit. Without this, sock.fileno() returns a MagicMock and the | ||
| # subsequent sock.pending() > 0 comparison raises TypeError on PyPy. | ||
| conn.conn.sock.fileno.return_value = -1 | ||
| return conn | ||
|
|
||
|
|
||
| class TestReceiveMessage(UnitTest): | ||
| def _patch_receive_data(self, *chunks): | ||
| """Make receive_data return the given byte strings on successive calls.""" | ||
| mock = patch.object(network_layer, "receive_data", side_effect=list(chunks)) | ||
| self.addCleanup(mock.stop) | ||
| return mock.start() | ||
|
|
||
| def test_request_id_mismatch_raises(self): | ||
| self._patch_receive_data( | ||
| _make_header(length=32, request_id=0, response_to=99, op_code=2013) | ||
| ) | ||
| with self.assertRaises(ProtocolError): | ||
| network_layer.receive_message(_make_conn(), request_id=1) | ||
|
|
||
| def test_length_too_small_raises(self): | ||
| self._patch_receive_data(_make_header(length=16, request_id=0, response_to=0, op_code=2013)) | ||
| with self.assertRaisesRegex(ProtocolError, "not longer than standard message header"): | ||
| network_layer.receive_message(_make_conn(), request_id=None) | ||
|
|
||
| def test_length_exceeds_max_raises(self): | ||
| self._patch_receive_data( | ||
| _make_header(length=MAX_MESSAGE_SIZE + 1, request_id=0, response_to=0, op_code=2013) | ||
| ) | ||
| with self.assertRaisesRegex(ProtocolError, "larger than server max"): | ||
| network_layer.receive_message(_make_conn(), request_id=None) | ||
|
|
||
| def test_normal_op_msg_unpacks(self): | ||
| body = b"x" * 16 | ||
| self._patch_receive_data( | ||
| _make_header(length=32, request_id=0, response_to=0, op_code=2013), body | ||
| ) | ||
| unpack = MagicMock(return_value="REPLY") | ||
| with patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}): | ||
| result = network_layer.receive_message(_make_conn(), request_id=None) | ||
| unpack.assert_called_once_with(body) | ||
| self.assertEqual(result, "REPLY") | ||
|
|
||
| def test_op_compressed_decompresses(self): | ||
| # length=35 -> body length = 35 - 25 = 10 (header 16 + compression sub-header 9). | ||
| compressed_body = b"y" * 10 | ||
| self._patch_receive_data( | ||
| _make_header(length=35, request_id=0, response_to=0, op_code=2012), | ||
| _make_compression_header(op_code=2013, uncompressed_size=0, compressor_id=1), | ||
| compressed_body, | ||
| ) | ||
| unpack = MagicMock(return_value="REPLY") | ||
| with ( | ||
| patch.object(network_layer, "decompress", return_value=b"decompressed") as decompress, | ||
| patch.object(network_layer, "_UNPACK_REPLY", {2013: unpack}), | ||
| ): | ||
| result = network_layer.receive_message(_make_conn(), request_id=None) | ||
| decompress.assert_called_once_with(compressed_body, 1) | ||
| unpack.assert_called_once_with(b"decompressed") | ||
| self.assertEqual(result, "REPLY") | ||
|
|
||
| def test_unknown_opcode_raises(self): | ||
| self._patch_receive_data( | ||
| _make_header(length=20, request_id=0, response_to=0, op_code=9999), b"data" | ||
| ) | ||
| with patch.object(network_layer, "_UNPACK_REPLY", {2013: MagicMock()}): | ||
| with self.assertRaises(ProtocolError): | ||
| network_layer.receive_message(_make_conn(), request_id=None) | ||
|
|
||
|
|
||
| class TestReceiveData(UnitTest): | ||
| def test_raises_on_connection_closed(self): | ||
| # Covers the explicit `raise OSError("connection closed")` branch when | ||
| # recv_into returns 0. | ||
| conn = _make_conn() | ||
| conn.conn.recv_into.return_value = 0 | ||
| with self.assertRaisesRegex(OSError, "connection closed"): | ||
| network_layer.receive_data(conn, 10, deadline=None) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similarly, the tests in this class that do not perform asynchronous I/O don't need to be
async defeither.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not addressed.