Fix parallel async client
This commit is contained in:
parent
f7f0fd99d2
commit
e4a6ce26c1
@ -4,6 +4,8 @@ __version__ = '0.2'
|
|||||||
import asyncio
|
import asyncio
|
||||||
import pathlib
|
import pathlib
|
||||||
import time
|
import time
|
||||||
|
import urllib
|
||||||
|
import warnings
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from string import digits
|
from string import digits
|
||||||
@ -11,7 +13,8 @@ from typing import Any, List, Tuple, Dict, Optional
|
|||||||
|
|
||||||
from pyrogram import raw, utils
|
from pyrogram import raw, utils
|
||||||
from pyrogram.storage import Storage
|
from pyrogram.storage import Storage
|
||||||
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory
|
from thriftpy2.contrib.aio.socket import TAsyncSocket
|
||||||
|
from thriftpy2.contrib.aio.transport import TAsyncFramedTransportFactory, TAsyncBufferedTransportFactory
|
||||||
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
|
from thriftpy2.contrib.aio.protocol import TAsyncBinaryProtocolFactory
|
||||||
|
|
||||||
from thriftpy2.transport.framed import TFramedTransportFactory
|
from thriftpy2.transport.framed import TFramedTransportFactory
|
||||||
@ -112,9 +115,49 @@ class RockServerStorage(Storage):
|
|||||||
|
|
||||||
super().__init__(name=self._session_id)
|
super().__init__(name=self._session_id)
|
||||||
|
|
||||||
|
async def make_parallel_aio_client(self, service, host='localhost', port=9090, unix_socket=None,
|
||||||
|
proto_factory=TAsyncBinaryProtocolFactory(),
|
||||||
|
trans_factory=TAsyncBufferedTransportFactory(),
|
||||||
|
timeout=3000, connect_timeout=None,
|
||||||
|
cafile=None, ssl_context=None,
|
||||||
|
certfile=None, keyfile=None,
|
||||||
|
validate=True, url='',
|
||||||
|
socket_timeout=None):
|
||||||
|
if socket_timeout is not None:
|
||||||
|
warnings.warn(
|
||||||
|
"The 'socket_timeout' argument is deprecated. "
|
||||||
|
"Please use 'timeout' instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
|
timeout = socket_timeout
|
||||||
|
if url:
|
||||||
|
parsed_url = urllib.parse.urlparse(url)
|
||||||
|
host = parsed_url.hostname or host
|
||||||
|
port = parsed_url.port or port
|
||||||
|
if unix_socket:
|
||||||
|
socket = TAsyncSocket(unix_socket=unix_socket,
|
||||||
|
connect_timeout=connect_timeout,
|
||||||
|
socket_timeout=timeout)
|
||||||
|
if certfile:
|
||||||
|
warnings.warn("SSL only works with host:port, not unix_socket.")
|
||||||
|
elif host and port:
|
||||||
|
socket = TAsyncSocket(
|
||||||
|
host, port,
|
||||||
|
socket_timeout=timeout, connect_timeout=connect_timeout,
|
||||||
|
cafile=cafile, ssl_context=ssl_context,
|
||||||
|
certfile=certfile, keyfile=keyfile, validate=validate)
|
||||||
|
else:
|
||||||
|
raise ValueError("Either host/port or unix_socket"
|
||||||
|
" or url must be provided.")
|
||||||
|
|
||||||
|
transport = trans_factory.get_transport(socket)
|
||||||
|
protocol = proto_factory.get_protocol(transport)
|
||||||
|
await transport.open()
|
||||||
|
return TParallelAsyncClient(service, protocol)
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
""" Initialize pyrogram session"""
|
""" Initialize pyrogram session"""
|
||||||
self._client = TParallelAsyncClient(await make_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000))
|
self._client = await self.make_parallel_aio_client(rocksdb_thrift.RocksDB, host=self._hostname, port=self._port, trans_factory=TRANSPORT_FACTORY, proto_factory=PROTOCOL_FACTORY, connect_timeout=8000)
|
||||||
|
|
||||||
# Column('dc_id', BIGINT, primary_key=True),
|
# Column('dc_id', BIGINT, primary_key=True),
|
||||||
# Column('api_id', BIGINT),
|
# Column('api_id', BIGINT),
|
||||||
@ -134,7 +177,8 @@ class RockServerStorage(Storage):
|
|||||||
# Column('last_update_on', BIGINT),
|
# Column('last_update_on', BIGINT),
|
||||||
self._peer_col = await self._client.createColumn(name=f'peers_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[8], variableTailKeys=[], hasValue=True))
|
self._peer_col = await self._client.createColumn(name=f'peers_{self._session_id}', schema=rocksdb_thrift.ColumnSchema(fixedKeys=[8], variableTailKeys=[], hasValue=True))
|
||||||
|
|
||||||
self._session_data = await fetchone(self._client, self._session_col, SESSION_KEY)
|
fetched_session_data = await fetchone(self._client, self._session_col, SESSION_KEY)
|
||||||
|
self._session_data = fetched_session_data if fetched_session_data is not None else self._session_data
|
||||||
|
|
||||||
async def save(self):
|
async def save(self):
|
||||||
""" On save we update the date """
|
""" On save we update the date """
|
||||||
|
Loading…
Reference in New Issue
Block a user