From e4a6ce26c10129c4877a147bedebf2985b582d13 Mon Sep 17 00:00:00 2001 From: Andrea Cavalli Date: Sat, 30 Mar 2024 20:09:07 +0100 Subject: [PATCH] Fix parallel async client --- pyrogram_rockserver_storage/__init__.py | 50 +++++++++++++++++++++++-- 1 file changed, 47 insertions(+), 3 deletions(-) diff --git a/pyrogram_rockserver_storage/__init__.py b/pyrogram_rockserver_storage/__init__.py index 549aeca..f39bb00 100644 --- a/pyrogram_rockserver_storage/__init__.py +++ b/pyrogram_rockserver_storage/__init__.py @@ -4,6 +4,8 @@ __version__ = '0.2' import asyncio import pathlib import time +import urllib +import warnings from enum import Enum from itertools import chain from string import digits @@ -11,7 +13,8 @@ from typing import Any, List, Tuple, Dict, Optional from pyrogram import raw, utils from pyrogram.storage import Storage -from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory +from thriftpy2.contrib.aio.socket import TAsyncSocket +from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory from thriftpy2.transport.framed import TFramedTransportFactory @@ -112,9 +115,49 @@ class RockServerStorage(Storage): super().__init__(name=self._session_id) + async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None, + proto_factory=TAsyncBinaryProtocolFactory(), + trans_factory=TAsyncBufferedTransportFactory(), + timeout=3000, connect_timeout=None, + cafile=None, ssl_context=None, + certfile=None, keyfile=None, + validate=True, url='', + socket_timeout=None): + if socket_timeout is not None: + warnings.warn( + "The 'socket_timeout' argument is deprecated. " + "Please use 'timeout' instead.", + DeprecationWarning, + ) + timeout = socket_timeout + if url: + parsed_url = urllib.parse.urlparse(url) + host = parsed_url.hostname or host + port = parsed_url.port or port + if unix_socket: + socket = TAsyncSocket(unix_socket=unix_socket, + connect_timeout=connect_timeout, + socket_timeout=timeout) + if certfile: + warnings.warn("SSL only works with host:port, not unix_socket.") + elif host and port: + socket = TAsyncSocket( + host, port, + socket_timeout=timeout, connect_timeout=connect_timeout, + cafile=cafile, ssl_context=ssl_context, + certfile=certfile, keyfile=keyfile, validate=validate) + else: + raise ValueError("Either host/port or unix_socket" + " or url must be provided.") + + transport = trans_factory.get_transport(socket) + protocol = proto_factory.get_protocol(transport) + await transport.open() + return TParallelAsyncClient(service, protocol) + async def open(self): """ Initialize pyrogram session""" - self._client = TParallelAsyncClient(await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)) + self._client = await self.make_parallel_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000) # Column('dc_id', BIGINT, primary_key=True), # Column('api_id', BIGINT), @@ -134,7 +177,8 @@ class RockServerStorage(Storage): # Column('last_update_on', BIGINT), self._peer_col = await self._client.createColumn(name=f'peers_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[8], variableTailKeys=[], hasValue=True)) - self._session_data = await fetchone(self._client, self._session_col, SESSION_KEY) + fetched_session_data = await fetchone(self._client, self._session_col, SESSION_KEY) + self._session_data = fetched_session_data if fetched_session_data is not None else self._session_data async def save(self): """ On save we update the date """