forked from sass/tipibot
Refactor db error catch to one helper
This commit is contained in:
@@ -16,17 +16,13 @@ from typing import TypedDict
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from . import pb_client
|
from . import pb_client
|
||||||
|
from .pb_client import DatabaseError
|
||||||
|
|
||||||
import strings
|
import strings
|
||||||
|
|
||||||
_txn_log = logging.getLogger("tipiCOIN.txn")
|
_txn_log = logging.getLogger("tipiCOIN.txn")
|
||||||
|
|
||||||
|
|
||||||
class DatabaseError(Exception):
|
|
||||||
"""Raised when PocketBase is unreachable or returns an error."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def _txn(event: str, **fields) -> None:
|
def _txn(event: str, **fields) -> None:
|
||||||
"""Log a single economy transaction to the transactions logger."""
|
"""Log a single economy transaction to the transactions logger."""
|
||||||
body = " ".join(f"{k}={v}" for k, v in fields.items())
|
body = " ".join(f"{k}={v}" for k, v in fields.items())
|
||||||
|
|||||||
@@ -21,6 +21,11 @@ import aiohttp
|
|||||||
|
|
||||||
import config
|
import config
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseError(Exception):
|
||||||
|
"""Raised when PocketBase is unreachable or returns an error."""
|
||||||
|
pass
|
||||||
|
|
||||||
_log = logging.getLogger("tipiCOIN.pb")
|
_log = logging.getLogger("tipiCOIN.pb")
|
||||||
|
|
||||||
PB_URL = config.PB_URL
|
PB_URL = config.PB_URL
|
||||||
@@ -57,17 +62,20 @@ async def _ensure_auth() -> str:
|
|||||||
if time.monotonic() < _token_expiry:
|
if time.monotonic() < _token_expiry:
|
||||||
return _token
|
return _token
|
||||||
session = _get_session()
|
session = _get_session()
|
||||||
|
try:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
f"{PB_URL}/api/collections/_superusers/auth-with-password",
|
f"{PB_URL}/api/collections/_superusers/auth-with-password",
|
||||||
json={"identity": PB_ADMIN_EMAIL, "password": PB_ADMIN_PASSWORD},
|
json={"identity": PB_ADMIN_EMAIL, "password": PB_ADMIN_PASSWORD},
|
||||||
) as resp:
|
) as resp:
|
||||||
if resp.status != 200:
|
if resp.status != 200:
|
||||||
text = await resp.text()
|
text = await resp.text()
|
||||||
raise RuntimeError(f"PocketBase auth failed ({resp.status}): {text}")
|
raise DatabaseError(f"PocketBase auth failed ({resp.status}): {text}")
|
||||||
data = await resp.json()
|
data = await resp.json()
|
||||||
_token = data["token"]
|
_token = data["token"]
|
||||||
_token_expiry = time.monotonic() + 13 * 24 * 3600 # refresh well before expiry
|
_token_expiry = time.monotonic() + 13 * 24 * 3600 # refresh well before expiry
|
||||||
_log.debug("PocketBase admin token refreshed")
|
_log.debug("PocketBase admin token refreshed")
|
||||||
|
except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e:
|
||||||
|
raise DatabaseError(f"Database unavailable: {e}") from e
|
||||||
return _token
|
return _token
|
||||||
|
|
||||||
|
|
||||||
@@ -80,76 +88,72 @@ def _invalidate_token() -> None:
|
|||||||
_token_expiry = 0.0
|
_token_expiry = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Request helper with auth-retry and error wrapping
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _request(method: str, url: str, **kwargs: Any) -> Any:
|
||||||
|
"""Make an authenticated request, retrying once on 401/403 by re-authing.
|
||||||
|
|
||||||
|
Returns the parsed JSON body. Raises DatabaseError on connection issues or
|
||||||
|
non-2xx responses after retrying.
|
||||||
|
"""
|
||||||
|
session = _get_session()
|
||||||
|
for attempt in range(2):
|
||||||
|
kwargs["headers"] = await _hdrs()
|
||||||
|
try:
|
||||||
|
async with session.request(method, url, **kwargs) as resp:
|
||||||
|
if resp.status in (401, 403) and attempt == 0:
|
||||||
|
_invalidate_token()
|
||||||
|
continue
|
||||||
|
if not resp.ok:
|
||||||
|
text = await resp.text()
|
||||||
|
raise DatabaseError(f"Database unavailable: {resp.status}, {text}")
|
||||||
|
return await resp.json()
|
||||||
|
except (aiohttp.ClientConnectorError, asyncio.TimeoutError) as e:
|
||||||
|
raise DatabaseError(f"Database unavailable: {e}") from e
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# CRUD helpers
|
# CRUD helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
async def get_record(user_id: str) -> dict[str, Any] | None:
|
async def get_record(user_id: str) -> dict[str, Any] | None:
|
||||||
"""Fetch one economy record by Discord user_id. Returns None if not found."""
|
"""Fetch one economy record by Discord user_id. Returns None if not found."""
|
||||||
session = _get_session()
|
data = await _request(
|
||||||
for attempt in range(2):
|
"GET",
|
||||||
async with session.get(
|
|
||||||
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
||||||
params={"filter": f'user_id="{user_id}"', "perPage": 1},
|
params={"filter": f'user_id="{user_id}"', "perPage": 1},
|
||||||
headers=await _hdrs(),
|
)
|
||||||
) as resp:
|
|
||||||
if resp.status == 403 and attempt == 0:
|
|
||||||
_invalidate_token()
|
|
||||||
continue
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = await resp.json()
|
|
||||||
items = data.get("items", [])
|
items = data.get("items", [])
|
||||||
return items[0] if items else None
|
return items[0] if items else None
|
||||||
|
|
||||||
|
|
||||||
async def create_record(record: dict[str, Any]) -> dict[str, Any]:
|
async def create_record(record: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Create a new economy record. Returns the created record (includes PB id)."""
|
"""Create a new economy record. Returns the created record (includes PB id)."""
|
||||||
session = _get_session()
|
return await _request(
|
||||||
for attempt in range(2):
|
"POST",
|
||||||
async with session.post(
|
|
||||||
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
||||||
json=record,
|
json=record,
|
||||||
headers=await _hdrs(),
|
)
|
||||||
) as resp:
|
|
||||||
if resp.status == 403 and attempt == 0:
|
|
||||||
_invalidate_token()
|
|
||||||
continue
|
|
||||||
if resp.status not in (200, 201):
|
|
||||||
text = await resp.text()
|
|
||||||
raise RuntimeError(f"PocketBase create failed ({resp.status}): {text}")
|
|
||||||
return await resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
async def update_record(record_id: str, data: dict[str, Any]) -> dict[str, Any]:
|
async def update_record(record_id: str, data: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""PATCH an existing record by its PocketBase record id."""
|
"""PATCH an existing record by its PocketBase record id."""
|
||||||
session = _get_session()
|
return await _request(
|
||||||
for attempt in range(2):
|
"PATCH",
|
||||||
async with session.patch(
|
|
||||||
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records/{record_id}",
|
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records/{record_id}",
|
||||||
json=data,
|
json=data,
|
||||||
headers=await _hdrs(),
|
)
|
||||||
) as resp:
|
|
||||||
if resp.status == 403 and attempt == 0:
|
|
||||||
_invalidate_token()
|
|
||||||
continue
|
|
||||||
resp.raise_for_status()
|
|
||||||
return await resp.json()
|
|
||||||
|
|
||||||
|
|
||||||
async def count_records() -> int:
|
async def count_records() -> int:
|
||||||
"""Return the total number of records in the collection (single cheap request)."""
|
"""Return the total number of records in the collection (single cheap request)."""
|
||||||
session = _get_session()
|
data = await _request(
|
||||||
for attempt in range(2):
|
"GET",
|
||||||
async with session.get(
|
|
||||||
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
||||||
params={"perPage": 1, "page": 1},
|
params={"perPage": 1, "page": 1},
|
||||||
headers=await _hdrs(),
|
)
|
||||||
) as resp:
|
|
||||||
if resp.status == 403 and attempt == 0:
|
|
||||||
_invalidate_token()
|
|
||||||
continue
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = await resp.json()
|
|
||||||
return int(data.get("totalItems", 0))
|
return int(data.get("totalItems", 0))
|
||||||
|
|
||||||
|
|
||||||
@@ -157,24 +161,14 @@ async def list_all_records(page_size: int = 500) -> list[dict[str, Any]]:
|
|||||||
"""Fetch every record in the collection, handling PocketBase pagination."""
|
"""Fetch every record in the collection, handling PocketBase pagination."""
|
||||||
results: list[dict[str, Any]] = []
|
results: list[dict[str, Any]] = []
|
||||||
page = 1
|
page = 1
|
||||||
session = _get_session()
|
|
||||||
while True:
|
while True:
|
||||||
hdrs = await _hdrs()
|
data = await _request(
|
||||||
for attempt in range(2):
|
"GET",
|
||||||
async with session.get(
|
|
||||||
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
|
||||||
params={"perPage": page_size, "page": page},
|
params={"perPage": page_size, "page": page},
|
||||||
headers=hdrs,
|
)
|
||||||
) as resp:
|
|
||||||
if resp.status == 403 and attempt == 0:
|
|
||||||
_invalidate_token()
|
|
||||||
hdrs = await _hdrs()
|
|
||||||
continue
|
|
||||||
resp.raise_for_status()
|
|
||||||
data = await resp.json()
|
|
||||||
batch = data.get("items", [])
|
batch = data.get("items", [])
|
||||||
results.extend(batch)
|
results.extend(batch)
|
||||||
if len(batch) < page_size:
|
if len(batch) < page_size:
|
||||||
return results
|
return results
|
||||||
page += 1
|
page += 1
|
||||||
break
|
|
||||||
|
|||||||
Reference in New Issue
Block a user