2022-03-28 16:39:36 +02:00
|
|
|
from typing import Optional
|
|
|
|
|
2021-01-04 09:35:31 +01:00
|
|
|
from aiokit import AioThing
|
2022-03-28 16:39:36 +02:00
|
|
|
from psycopg.rows import tuple_row
|
|
|
|
from psycopg_pool import AsyncConnectionPool
|
2021-01-04 09:35:31 +01:00
|
|
|
|
|
|
|
|
|
|
|
class AioPostgresPoolHolder(AioThing):
    """Holder that lazily creates and owns a psycopg ``AsyncConnectionPool``.

    The pool is not built in ``__init__``: ``start()`` creates it and
    ``stop()`` closes and forgets it, so the holder can be started again
    afterwards. ``iterate()`` and ``execute()`` refuse to run until the
    holder has been started.

    Args:
        conninfo: connection string forwarded to ``AsyncConnectionPool``.
        timeout: forwarded as the pool's ``timeout`` argument.
        min_size: forwarded as the pool's ``min_size`` argument.
        max_size: forwarded as the pool's ``max_size`` argument.
    """

    def __init__(self, conninfo, timeout=30, min_size=1, max_size=4):
        super().__init__()
        # None until start() is awaited, and again after stop().
        self.pool = None
        # Factory that defers pool construction until start() runs; every
        # call builds a fresh pool with the settings captured here.
        self.fn = lambda: AsyncConnectionPool(
            conninfo=conninfo,
            timeout=timeout,
            min_size=min_size,
            max_size=max_size,
        )

    async def start(self):
        """Create the pool if it does not exist yet; repeated calls are no-ops."""
        if self.pool is None:
            self.pool = self.fn()

    async def stop(self):
        """Close the pool, if any, and reset the holder so it can be restarted."""
        if self.pool is None:
            return
        # Detach before awaiting close() so the holder is reset even if
        # close() raises; otherwise a failed close would leave a dead pool
        # behind that start() would refuse to replace.
        pool, self.pool = self.pool, None
        await pool.close()

    async def iterate(
        self,
        stmt: str,
        values=None,
        row_factory=tuple_row,
        cursor_name: Optional[str] = None,
        itersize: Optional[int] = None,
    ):
        """Execute ``stmt`` and asynchronously yield its result rows.

        A pooled connection and a cursor stay checked out while the generator
        is being consumed. If the consumer stops early, they are only released
        once the generator is closed (e.g. via ``aclose()``).

        Args:
            stmt: SQL statement to execute.
            values: parameters for ``stmt``, passed to ``cursor.execute``.
            row_factory: psycopg row factory used to build each yielded row.
            cursor_name: when non-empty, a named (server-side) cursor is
                opened instead of a client-side one.
            itersize: rows fetched per round trip by a server-side cursor.
                Only applied when ``cursor_name`` is given. Client-side
                cursors hold the whole result already and have no
                ``itersize`` attribute, so the value is ignored for them.

        Yields:
            One row per result record, built by ``row_factory``.

        Raises:
            RuntimeError: if the holder has not been started.
        """
        if self.pool is None:
            raise RuntimeError('AioPostgresPoolHolder has not been started')
        async with self.pool.connection() as conn:
            async with conn.cursor(name=cursor_name, row_factory=row_factory) as cur:
                # Only server-side cursors expose ``itersize``; assigning it
                # on a client-side cursor raises AttributeError.
                if itersize is not None and cursor_name:
                    cur.itersize = itersize
                await cur.execute(stmt, values)
                async for row in cur:
                    yield row

    async def execute(self, stmt: str, values=None, cursor_name: Optional[str] = None, row_factory=tuple_row):
        """Execute ``stmt`` on a pooled connection without fetching results.

        Args:
            stmt: SQL statement to execute.
            values: parameters for ``stmt``, passed to ``cursor.execute``.
            cursor_name: when non-empty, a named (server-side) cursor is used.
            row_factory: passed to the cursor for signature symmetry with
                ``iterate()``. No rows are fetched here, so it has no
                visible effect.

        Raises:
            RuntimeError: if the holder has not been started.
        """
        if self.pool is None:
            raise RuntimeError('AioPostgresPoolHolder has not been started')
        async with self.pool.connection() as conn:
            async with conn.cursor(name=cursor_name, row_factory=row_factory) as cur:
                await cur.execute(stmt, values)
|