Fix parallel async client

This commit is contained in:
Andrea Cavalli 2024-03-30 20:09:07 +01:00
parent f7f0fd99d2
commit e4a6ce26c1
1 changed files with 47 additions and 3 deletions

View File

@ -4,6 +4,8 @@ __version__ = '0.2'
import asyncio
import pathlib
import time
import urllib
import warnings
from enum import Enum
from itertools import chain
from string import digits
@ -11,7 +13,8 @@ from typing import Any, List, Tuple, Dict, Optional
from pyrogram import raw, utils
from pyrogram.storage import Storage
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory
from thriftpy2.contrib.aio.socket import TAsyncSocket
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
from thriftpy2.transport.framed import TFramedTransportFactory
@ -112,9 +115,49 @@ class RockServerStorage(Storage):
super().__init__(name=self._session_id)
async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None,
proto_factory=TAsyncBinaryProtocolFactory(),
trans_factory=TAsyncBufferedTransportFactory(),
timeout=3000, connect_timeout=None,
cafile=None, ssl_context=None,
certfile=None, keyfile=None,
validate=True, url='',
socket_timeout=None):
if socket_timeout is not None:
warnings.warn(
"The 'socket_timeout' argument is deprecated. "
"Please use 'timeout' instead.",
DeprecationWarning,
)
timeout = socket_timeout
if url:
parsed_url = urllib.parse.urlparse(url)
host = parsed_url.hostname or host
port = parsed_url.port or port
if unix_socket:
socket = TAsyncSocket(unix_socket=unix_socket,
connect_timeout=connect_timeout,
socket_timeout=timeout)
if certfile:
warnings.warn("SSL only works with host:port, not unix_socket.")
elif host and port:
socket = TAsyncSocket(
host, port,
socket_timeout=timeout, connect_timeout=connect_timeout,
cafile=cafile, ssl_context=ssl_context,
certfile=certfile, keyfile=keyfile, validate=validate)
else:
raise ValueError("Either host/port or unix_socket"
" or url must be provided.")
transport = trans_factory.get_transport(socket)
protocol = proto_factory.get_protocol(transport)
await transport.open()
return TParallelAsyncClient(service, protocol)
async def open(self):
""" Initialize pyrogram session"""
self._client = TParallelAsyncClient(await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000))
self._client = await self.make_parallel_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)
# Column('dc_id', BIGINT, primary_key=True),
# Column('api_id', BIGINT),
@ -134,7 +177,8 @@ class RockServerStorage(Storage):
# Column('last_update_on', BIGINT),
self._peer_col = await self._client.createColumn(name=f'peers_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[8], variableTailKeys=[], hasValue=True))
self._session_data = await fetchone(self._client, self._session_col, SESSION_KEY)
fetched_session_data = await fetchone(self._client, self._session_col, SESSION_KEY)
self._session_data = fetched_session_data if fetched_session_data is not None else self._session_data
async def save(self):
""" On save we update the date """