115 lines
4.0 KiB
Python
115 lines
4.0 KiB
Python
import asyncio
|
|
import logging
|
|
from typing import Optional
|
|
|
|
import psycopg
|
|
from aiokit import AioThing
|
|
from izihawa_utils.exceptions import BaseError
|
|
from psycopg.rows import tuple_row
|
|
from psycopg_pool import AsyncConnectionPool
|
|
|
|
|
|
class OperationalError(BaseError):
    """Project error identified by the ``operational_error`` code.

    The class only sets two attributes on top of ``BaseError``: the error
    ``code`` and the logging ``level`` (``logging.WARNING``).  How the base
    class consumes them is defined outside this file.
    """

    code = 'operational_error'
    level = logging.WARNING
|
|
|
|
|
|
class AioPostgresPoolHolder(AioThing):
    """Own an ``AsyncConnectionPool`` and expose small query helpers.

    The pool is created lazily in :meth:`start` and closed in :meth:`stop`.
    When ``is_recycling`` is true, one extra pool slot is reserved for a
    background task (:meth:`recycling`) that holds a connection, watches its
    socket and asks the pool to ``check()`` itself when a probe query fails
    with ``psycopg.OperationalError``.
    """

    def __init__(self, conninfo, timeout=30, min_size=1, max_size=1, is_recycling=True):
        """Store pool settings; no connection is opened here.

        Args:
            conninfo: Connection string forwarded to ``AsyncConnectionPool``.
            timeout: Pool ``timeout`` and also the wake-up interval of the
                recycling task.
            min_size: Forwarded to the pool unchanged.
            max_size: Pool size available to callers; one more slot is added
                when ``is_recycling`` is true, for the watchdog connection.
            is_recycling: Whether :meth:`start` spawns the recycling task.
        """
        super().__init__()
        self.pool = None
        # Factory rather than instance: the pool is only built in start().
        self.fn = lambda: AsyncConnectionPool(
            conninfo=conninfo,
            timeout=timeout,
            min_size=min_size,
            max_size=max_size + int(is_recycling),
        )
        self.is_recycling = is_recycling
        self.recycling_task = None
        self.timeout = timeout

    async def _get_connection(self):
        """Take a connection from the pool and watch its socket.

        Returns:
            ``(event, connection)`` where ``event`` is set by the event loop
            whenever the connection's file descriptor becomes readable.
        """
        ev = asyncio.Event()
        conn = await self.pool.getconn()
        try:
            asyncio.get_running_loop().add_reader(conn.fileno(), ev.set)
        except Exception:
            # Do not leak the pool slot if the reader cannot be registered.
            await self.pool.putconn(conn)
            raise
        return ev, conn

    async def _release_connection(self, conn):
        """Unregister the socket reader of ``conn`` and return it to the pool."""
        try:
            asyncio.get_running_loop().remove_reader(conn.fileno())
        except psycopg.OperationalError:
            # NOTE(review): presumably fileno() raises this when the
            # connection is already closed; there is nothing to unregister.
            pass
        await self.pool.putconn(conn)

    async def recycling(self):
        """Background watchdog: probe a held connection when its socket wakes.

        Loops until cancelled.  Each time the watched socket becomes readable
        a ``SELECT 1`` probe is executed; on ``psycopg.OperationalError`` the
        connection is returned, ``pool.check()`` is called and a fresh
        connection is taken.  Cancellation is swallowed on purpose so that
        :meth:`stop` can simply await the task.
        """
        logging.getLogger('debug').debug({
            'action': 'start_recycling',
            'mode': 'pool',
            'stats': self.pool.get_stats(),
        })
        loop = asyncio.get_running_loop()
        # ``conn`` is None whenever no connection is currently held, so the
        # ``finally`` clause never returns the same connection twice.
        conn = None
        try:
            ev, conn = await self._get_connection()
            while True:
                try:
                    await asyncio.wait_for(ev.wait(), self.timeout)
                except asyncio.TimeoutError:
                    continue
                # Reset the flag; otherwise every later wait() returns
                # immediately and the loop spins on the probe query.
                ev.clear()
                try:
                    await conn.execute("SELECT 1")
                    # NOTE(review): the driver's own socket waiting may have
                    # replaced or removed our reader - re-register it
                    # (add_reader replaces any existing callback for the fd).
                    loop.add_reader(conn.fileno(), ev.set)
                except psycopg.OperationalError:
                    broken, conn = conn, None
                    await self._release_connection(broken)
                    await self.pool.check()
                    ev, conn = await self._get_connection()
        except asyncio.CancelledError:
            pass
        finally:
            if conn is not None:
                await self._release_connection(conn)
            logging.getLogger('debug').debug({
                'action': 'stopped_recycling',
                'mode': 'pool',
                'stats': self.pool.get_stats(),
            })

    async def start(self):
        """Create the pool, wait until it is ready and start recycling.

        Calling it again while a pool exists is a no-op.
        """
        if self.pool:
            return
        self.pool = self.fn()
        try:
            await self.pool.wait()
        except Exception:
            # Leave the holder in the "not started" state instead of keeping
            # a pool that never became ready.
            failed_pool, self.pool = self.pool, None
            await failed_pool.close()
            raise
        if self.is_recycling:
            self.recycling_task = asyncio.create_task(self.recycling())

    async def stop(self):
        """Cancel the recycling task (if any) and close the pool."""
        if not self.pool:
            return
        try:
            if self.recycling_task:
                self.recycling_task.cancel()
                await self.recycling_task
        finally:
            # Close the pool even if the recycling task ended with an error.
            self.recycling_task = None
            logging.getLogger('debug').debug({
                'action': 'close',
                'mode': 'pool',
                'stats': self.pool.get_stats(),
            })
            await self.pool.close()
            self.pool = None

    async def iterate(
        self,
        stmt: str,
        values=None,
        row_factory=tuple_row,
        cursor_name: Optional[str] = None,
        itersize: Optional[int] = None,
        statement_timeout: Optional[int] = None,
    ):
        """Execute ``stmt`` and asynchronously yield its rows.

        Args:
            stmt: SQL statement to run.
            values: Parameters passed to ``cursor.execute``.
            row_factory: Row factory for the cursor (tuples by default).
            cursor_name: Passed as ``name`` to ``conn.cursor``.
            itersize: When not None, assigned to ``cur.itersize``.
            statement_timeout: When truthy, ``SET statement_timeout`` is
                issued on the connection before the statement runs.

        Raises:
            RuntimeError: If the holder has not been started.
        """
        if not self.pool:
            raise RuntimeError('AioPostgresPoolHolder has not been started')
        async with self.pool.connection() as conn:
            if statement_timeout:
                # Must run BEFORE the query, and on the connection rather than
                # the result cursor, so it neither arrives too late nor
                # replaces the query's result set.
                # NOTE(review): a plain SET is session-scoped, so the value
                # may outlive this call on the pooled connection - confirm
                # whether SET LOCAL is preferable here.
                await conn.execute(f'SET statement_timeout = {statement_timeout};')
            async with conn.cursor(name=cursor_name, row_factory=row_factory) as cur:
                if itersize is not None:
                    cur.itersize = itersize
                await cur.execute(stmt, values)
                async for row in cur:
                    yield row

    async def execute(
        self,
        stmt: str,
        values=None,
        cursor_name: Optional[str] = None,
        row_factory=tuple_row,
    ):
        """Execute ``stmt`` on a pooled connection and discard any result.

        Raises:
            RuntimeError: If the holder has not been started.
        """
        if not self.pool:
            raise RuntimeError('AioPostgresPoolHolder has not been started')
        async with self.pool.connection() as conn:
            async with conn.cursor(name=cursor_name, row_factory=row_factory) as cur:
                await cur.execute(stmt, values)
|