1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94
|
import json
import uuid
from typing import Any, Callable, Optional
import psycopg2.extras
from aiohttp import web
from aiopg import Pool
from aiohttp_session import AbstractStorage, Session
class PgStorage(AbstractStorage):
    """aiohttp_session storage backend that keeps sessions in PostgreSQL.

    The cookie carries only the session identifier; the payload lives in the
    ``web.sessions`` table (columns ``uuid``, ``session``, ``created`` and
    ``expire``, as used by the queries below).
    """

    def __init__(  # type: ignore[no-any-unimported]
        self,
        pg_pool: Pool,
        *,
        cookie_name: str = "AIOHTTP_SESSION",
        domain: Optional[str] = None,
        max_age: Optional[int] = None,
        path: str = "/",
        secure: Optional[bool] = None,
        httponly: bool = True,
        key_factory: Callable[[], str] = lambda: uuid.uuid4().hex,
        encoder: Callable[[object], str] = psycopg2.extras.Json,
        decoder: Callable[[str], Any] = json.loads,
    ):
        """Create the storage.

        :param pg_pool: aiopg connection pool used for every query.
        :param key_factory: produces the identifier of a new session. It is
            written to the ``uuid`` column, so it must be a string PostgreSQL
            accepts as a UUID (the default is ``uuid4().hex``).
        :param encoder: wraps the session mapping before it is passed as the
            ``session`` query parameter (default ``psycopg2.extras.Json``).
        :param decoder: applied to the ``session`` column when the driver
            returns it as a string rather than an already-decoded object.

        The remaining keyword arguments are cookie settings forwarded
        unchanged to ``AbstractStorage``.
        """
        super().__init__(
            cookie_name=cookie_name,
            domain=domain,
            max_age=max_age,
            path=path,
            secure=secure,
            httponly=httponly,
            encoder=encoder,
            decoder=decoder,
        )
        self._pg = pg_pool
        self._key_factory = key_factory

    @staticmethod
    def _parse_key(cookie: str) -> Optional[uuid.UUID]:
        """Return the UUID held in *cookie*, or None if it is not a valid UUID."""
        try:
            return uuid.UUID(cookie)
        except ValueError:
            return None

    async def _delete(self, key: str) -> None:
        """Remove the stored row for session *key*, if one exists."""
        async with self._pg.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    "DELETE FROM web.sessions WHERE uuid = %s",
                    (key,),
                )

    async def load_session(self, request: web.Request) -> Session:
        """Load the session referenced by the request's session cookie.

        A new, empty session is returned when there is no cookie, when the
        cookie is not a valid UUID, or when no matching row exists.
        """
        cookie = self.load_cookie(request)
        if cookie is None:
            return Session(None, data={}, new=True, max_age=self.max_age)

        # The cookie is client-controlled: a malformed value must yield a
        # fresh session, not an unhandled ValueError (HTTP 500).
        key = self._parse_key(cookie)
        if key is None:
            return Session(None, data={}, new=True, max_age=self.max_age)

        async with self._pg.acquire() as conn:
            async with conn.cursor(
                cursor_factory=psycopg2.extras.DictCursor
            ) as cur:
                await cur.execute(
                    # The alias is required: Session looks up "created". An
                    # unaliased extract() column is named after the function,
                    # which silently discarded the creation time.
                    "SELECT session, extract(epoch from created) AS created "
                    "FROM web.sessions WHERE uuid = %s",
                    # Pass the key as text (as save_session does), so no
                    # UUID adapter registration is needed.
                    (str(key),),
                )
                row = await cur.fetchone()

        if row is None:
            return Session(None, data={}, new=True, max_age=self.max_age)

        session_data = row["session"]
        if isinstance(session_data, str):
            # Only needed if the column is delivered as text rather than
            # being decoded by the driver (as json/jsonb columns are).
            session_data = self._decoder(session_data)

        created = row["created"]
        data = {
            # extract() returns float (or Decimal where it yields numeric);
            # Session works with integer epoch seconds.
            "created": int(created) if created is not None else None,
            "session": session_data,
        }
        return Session(key, data=data, new=False, max_age=self.max_age)

    async def save_session(
        self, request: web.Request, response: web.StreamResponse, session: Session
    ) -> None:
        """Persist *session* and set the session cookie on *response*.

        - A session without an identity gets a new key from ``key_factory``.
        - An existing session that has become empty has its cookie cleared and
          its row deleted, so replaying the old cookie cannot restore it.
        - Otherwise the row is upserted with the current payload.
        """
        key = session.identity
        if key is None:
            key = self._key_factory()
            self.save_cookie(response, key, max_age=session.max_age)
        elif session.empty:
            self.save_cookie(response, "", max_age=session.max_age)
            # Drop the server-side copy too; the previous version left the
            # old payload in place, keeping an invalidated session loadable.
            await self._delete(str(key))
            return
        else:
            key = str(key)
            self.save_cookie(response, key, max_age=session.max_age)

        data = self._get_session_data(session)
        if not data:
            return

        data_encoded = self._encoder(data["session"])
        # NOTE(review): with max_age None this makes expire == created;
        # confirm that is intended by whatever consumes the expire column.
        expire = data["created"] + (session.max_age or 0)
        async with self._pg.acquire() as conn:
            async with conn.cursor() as cur:
                await cur.execute(
                    "INSERT INTO web.sessions (uuid,session,created,expire)"
                    " VALUES (%s, %s, to_timestamp(%s),to_timestamp(%s))"
                    " ON CONFLICT (uuid)"
                    " DO UPDATE"
                    " SET (session,expire)=(EXCLUDED.session, EXCLUDED.expire)",
                    [key, data_encoded, data["created"], expire],
                )
|