1
0
forked from sass/tipibot

Refactor db error catch to one helper

This commit is contained in:
Rene Arumetsa
2026-05-04 17:31:16 +03:00
parent 6d344a47f4
commit 192888625e
2 changed files with 79 additions and 89 deletions

View File

@@ -16,17 +16,13 @@ from typing import TypedDict
import aiohttp import aiohttp
from . import pb_client from . import pb_client
from .pb_client import DatabaseError
import strings import strings
_txn_log = logging.getLogger("tipiCOIN.txn") _txn_log = logging.getLogger("tipiCOIN.txn")
class DatabaseError(Exception):
"""Raised when PocketBase is unreachable or returns an error."""
pass
def _txn(event: str, **fields) -> None: def _txn(event: str, **fields) -> None:
"""Log a single economy transaction to the transactions logger.""" """Log a single economy transaction to the transactions logger."""
body = " ".join(f"{k}={v}" for k, v in fields.items()) body = " ".join(f"{k}={v}" for k, v in fields.items())

View File

@@ -21,6 +21,11 @@ import aiohttp
import config import config
class DatabaseError(Exception):
"""Raised when PocketBase is unreachable or returns an error."""
pass
_log = logging.getLogger("tipiCOIN.pb") _log = logging.getLogger("tipiCOIN.pb")
PB_URL = config.PB_URL PB_URL = config.PB_URL
@@ -57,17 +62,20 @@ async def _ensure_auth() -> str:
if time.monotonic() < _token_expiry: if time.monotonic() < _token_expiry:
return _token return _token
session = _get_session() session = _get_session()
async with session.post( try:
f"{PB_URL}/api/collections/_superusers/auth-with-password", async with session.post(
json={"identity": PB_ADMIN_EMAIL, "password": PB_ADMIN_PASSWORD}, f"{PB_URL}/api/collections/_superusers/auth-with-password",
) as resp: json={"identity": PB_ADMIN_EMAIL, "password": PB_ADMIN_PASSWORD},
if resp.status != 200: ) as resp:
text = await resp.text() if resp.status != 200:
raise RuntimeError(f"PocketBase auth failed ({resp.status}): {text}") text = await resp.text()
data = await resp.json() raise DatabaseError(f"PocketBase auth failed ({resp.status}): {text}")
_token = data["token"] data = await resp.json()
_token_expiry = time.monotonic() + 13 * 24 * 3600 # refresh well before expiry _token = data["token"]
_log.debug("PocketBase admin token refreshed") _token_expiry = time.monotonic() + 13 * 24 * 3600 # refresh well before expiry
_log.debug("PocketBase admin token refreshed")
except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e:
raise DatabaseError(f"Database unavailable: {e}") from e
return _token return _token
@@ -80,101 +88,87 @@ def _invalidate_token() -> None:
_token_expiry = 0.0 _token_expiry = 0.0
# ---------------------------------------------------------------------------
# Request helper with auth-retry and error wrapping
# ---------------------------------------------------------------------------
async def _request(method: str, url: str, **kwargs: Any) -> Any:
"""Make an authenticated request, retrying once on 401/403 by re-authing.
Returns the parsed JSON body. Raises DatabaseError on connection issues or
non-2xx responses after retrying.
"""
session = _get_session()
for attempt in range(2):
kwargs["headers"] = await _hdrs()
try:
async with session.request(method, url, **kwargs) as resp:
if resp.status in (401, 403) and attempt == 0:
_invalidate_token()
continue
if not resp.ok:
text = await resp.text()
raise DatabaseError(f"Database unavailable: {resp.status}, {text}")
return await resp.json()
except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e:
raise DatabaseError(f"Database unavailable: {e}") from e
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# CRUD helpers # CRUD helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def get_record(user_id: str) -> dict[str, Any] | None: async def get_record(user_id: str) -> dict[str, Any] | None:
"""Fetch one economy record by Discord user_id. Returns None if not found.""" """Fetch one economy record by Discord user_id. Returns None if not found."""
session = _get_session() data = await _request(
for attempt in range(2): "GET",
async with session.get( f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records", params={"filter": f'user_id="{user_id}"', "perPage": 1},
params={"filter": f'user_id="{user_id}"', "perPage": 1}, )
headers=await _hdrs(), items = data.get("items", [])
) as resp: return items[0] if items else None
if resp.status == 403 and attempt == 0:
_invalidate_token()
continue
resp.raise_for_status()
data = await resp.json()
items = data.get("items", [])
return items[0] if items else None
async def create_record(record: dict[str, Any]) -> dict[str, Any]: async def create_record(record: dict[str, Any]) -> dict[str, Any]:
"""Create a new economy record. Returns the created record (includes PB id).""" """Create a new economy record. Returns the created record (includes PB id)."""
session = _get_session() return await _request(
for attempt in range(2): "POST",
async with session.post( f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records", json=record,
json=record, )
headers=await _hdrs(),
) as resp:
if resp.status == 403 and attempt == 0:
_invalidate_token()
continue
if resp.status not in (200, 201):
text = await resp.text()
raise RuntimeError(f"PocketBase create failed ({resp.status}): {text}")
return await resp.json()
async def update_record(record_id: str, data: dict[str, Any]) -> dict[str, Any]: async def update_record(record_id: str, data: dict[str, Any]) -> dict[str, Any]:
"""PATCH an existing record by its PocketBase record id.""" """PATCH an existing record by its PocketBase record id."""
session = _get_session() return await _request(
for attempt in range(2): "PATCH",
async with session.patch( f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records/{record_id}",
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records/{record_id}", json=data,
json=data, )
headers=await _hdrs(),
) as resp:
if resp.status == 403 and attempt == 0:
_invalidate_token()
continue
resp.raise_for_status()
return await resp.json()
async def count_records() -> int: async def count_records() -> int:
"""Return the total number of records in the collection (single cheap request).""" """Return the total number of records in the collection (single cheap request)."""
session = _get_session() data = await _request(
for attempt in range(2): "GET",
async with session.get( f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records", params={"perPage": 1, "page": 1},
params={"perPage": 1, "page": 1}, )
headers=await _hdrs(), return int(data.get("totalItems", 0))
) as resp:
if resp.status == 403 and attempt == 0:
_invalidate_token()
continue
resp.raise_for_status()
data = await resp.json()
return int(data.get("totalItems", 0))
async def list_all_records(page_size: int = 500) -> list[dict[str, Any]]: async def list_all_records(page_size: int = 500) -> list[dict[str, Any]]:
"""Fetch every record in the collection, handling PocketBase pagination.""" """Fetch every record in the collection, handling PocketBase pagination."""
results: list[dict[str, Any]] = [] results: list[dict[str, Any]] = []
page = 1 page = 1
session = _get_session()
while True: while True:
hdrs = await _hdrs() data = await _request(
for attempt in range(2): "GET",
async with session.get( f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records", params={"perPage": page_size, "page": page},
params={"perPage": page_size, "page": page}, )
headers=hdrs, batch = data.get("items", [])
) as resp: results.extend(batch)
if resp.status == 403 and attempt == 0: if len(batch) < page_size:
_invalidate_token() return results
hdrs = await _hdrs() page += 1
continue
resp.raise_for_status()
data = await resp.json()
batch = data.get("items", [])
results.extend(batch)
if len(batch) < page_size:
return results
page += 1
break