python/qmp: allow sockets to be passed to connect()

Allow existing sockets to be passed to connect(). The changes are pretty
minimal, and this allows for far greater flexibility in setting up
communications with an endpoint.

Signed-off-by: John Snow <jsnow@redhat.com>
Message-id: 20230517163406.2593480-2-jsnow@redhat.com
Signed-off-by: John Snow <jsnow@redhat.com>
master
John Snow 2023-05-17 12:34:02 -04:00
parent ab72522797
commit 9341b2a6b9
1 changed files with 15 additions and 6 deletions

View File

@ -370,7 +370,7 @@ class AsyncProtocol(Generic[T]):
@upper_half
@require(Runstate.IDLE)
async def connect(self, address: SocketAddrT,
async def connect(self, address: Union[SocketAddrT, socket.socket],
ssl: Optional[SSLContext] = None) -> None:
"""
Connect to the server and begin processing message queues.
@ -615,7 +615,7 @@ class AsyncProtocol(Generic[T]):
self.logger.debug("Connection accepted.")
@upper_half
async def _do_connect(self, address: SocketAddrT,
async def _do_connect(self, address: Union[SocketAddrT, socket.socket],
ssl: Optional[SSLContext] = None) -> None:
"""
Acting as the transport client, initiate a connection to a server.
@ -634,9 +634,17 @@ class AsyncProtocol(Generic[T]):
# otherwise yield.
await asyncio.sleep(0)
self.logger.debug("Connecting to %s ...", address)
if isinstance(address, tuple):
if isinstance(address, socket.socket):
self.logger.debug("Connecting with existing socket: "
"fd=%d, family=%r, type=%r",
address.fileno(), address.family, address.type)
connect = asyncio.open_connection(
limit=self._limit,
ssl=ssl,
sock=address,
)
elif isinstance(address, tuple):
self.logger.debug("Connecting to %s ...", address)
connect = asyncio.open_connection(
address[0],
address[1],
@ -644,13 +652,14 @@ class AsyncProtocol(Generic[T]):
limit=self._limit,
)
else:
self.logger.debug("Connecting to file://%s ...", address)
connect = asyncio.open_unix_connection(
path=address,
ssl=ssl,
limit=self._limit,
)
self._reader, self._writer = await connect
self._reader, self._writer = await connect
self.logger.debug("Connected.")
@upper_half