153 lines
5.1 KiB
Python
153 lines
5.1 KiB
Python
"""Async PocketBase REST client for TipiLAN Bot.
|
|
|
|
Handles admin authentication (auto-refreshed), and CRUD operations on the
|
|
economy_users collection. Uses aiohttp, which discord.py already depends on.
|
|
|
|
Environment variables (set in .env):
|
|
PB_URL Base URL of PocketBase (default: http://127.0.0.1:8090)
|
|
PB_ADMIN_EMAIL PocketBase admin e-mail
|
|
PB_ADMIN_PASSWORD PocketBase admin password
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import logging
|
|
import os
|
|
import time
|
|
from typing import Any
|
|
|
|
import aiohttp
|
|
|
|
# Module-level logger for this client.
_log = logging.getLogger("tipiCOIN.pb")

# Connection settings, read from the environment once at import time.
# The e-mail and password default to empty strings when unset.
PB_URL = os.getenv("PB_URL", "http://127.0.0.1:8090")
PB_ADMIN_EMAIL = os.getenv("PB_ADMIN_EMAIL", "")
PB_ADMIN_PASSWORD = os.getenv("PB_ADMIN_PASSWORD", "")
# Name of the PocketBase collection every CRUD helper in this module targets.
ECONOMY_COLLECTION = "economy_users"

# Total time budget, in seconds, for a single HTTP request.
_TIMEOUT = aiohttp.ClientTimeout(total=10)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Persistent session (created once, reused for the lifetime of the process)
|
|
# ---------------------------------------------------------------------------
|
|
_session: aiohttp.ClientSession | None = None


def _get_session() -> aiohttp.ClientSession:
    """Return the shared aiohttp session, building it on demand.

    A fresh session configured with ``_TIMEOUT`` is created when no session
    exists yet or when the previous one has been closed; otherwise the
    existing session is handed back unchanged.
    """
    global _session
    reusable = _session is not None and not _session.closed
    if not reusable:
        _session = aiohttp.ClientSession(timeout=_TIMEOUT)
    return _session
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Auth token cache
|
|
# ---------------------------------------------------------------------------
|
|
_token: str = ""
_token_expiry: float = 0.0
_auth_lock = asyncio.Lock()

# Upper bound (seconds) on how long a cached token is reused, regardless of
# the lifetime the token itself advertises.
_MAX_TOKEN_AGE = 13 * 24 * 3600


def _token_ttl(token: str) -> float | None:
    """Return the seconds left until the JWT ``exp`` claim of *token*.

    Returns None when the token is not a decodable JWT or carries no usable
    ``exp`` claim, so the caller can fall back to a fixed lifetime.
    """
    import base64
    import json

    try:
        payload = token.split(".")[1]
        # JWT segments are base64url without padding; restore it for decoding.
        padded = payload + "=" * (-len(payload) % 4)
        claims = json.loads(base64.urlsafe_b64decode(padded))
        # ``exp`` is a wall-clock Unix timestamp, hence time.time() here.
        return float(claims["exp"]) - time.time()
    except (AttributeError, IndexError, KeyError, TypeError, ValueError):
        return None


async def _ensure_auth() -> str:
    """Return a valid admin auth token, logging in again when the cache is stale.

    The token is cached at module level and guarded by ``_auth_lock`` so that
    concurrent callers trigger at most one login request. The refresh deadline
    is derived from the token's own ``exp`` claim (refreshing after 90% of the
    remaining lifetime, never later than ``_MAX_TOKEN_AGE``) because the token
    duration is configured on the server and may be much shorter than the
    13 days that were previously assumed. If the claim cannot be read, the
    fixed ``_MAX_TOKEN_AGE`` is used.

    Raises:
        RuntimeError: if PocketBase answers the login with a non-200 status.
        KeyError: if the login response carries no ``token`` field.
    """
    global _token, _token_expiry
    async with _auth_lock:
        if time.monotonic() < _token_expiry:
            return _token
        session = _get_session()
        async with session.post(
            f"{PB_URL}/api/collections/_superusers/auth-with-password",
            json={"identity": PB_ADMIN_EMAIL, "password": PB_ADMIN_PASSWORD},
        ) as resp:
            if resp.status != 200:
                text = await resp.text()
                raise RuntimeError(f"PocketBase auth failed ({resp.status}): {text}")
            data = await resp.json()
        _token = data["token"]
        ttl = _token_ttl(_token)
        if ttl is None:
            refresh_in = float(_MAX_TOKEN_AGE)
        else:
            # Refresh well before expiry; never schedule a deadline in the past.
            refresh_in = max(0.0, min(float(_MAX_TOKEN_AGE), ttl * 0.9))
        _token_expiry = time.monotonic() + refresh_in
        _log.debug("PocketBase admin token refreshed")
        return _token
|
|
|
|
|
|
async def _hdrs() -> dict[str, str]:
    """Build the request headers carrying the current admin auth token."""
    token = await _ensure_auth()
    return {"Authorization": token}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CRUD helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
async def get_record(user_id: str) -> dict[str, Any] | None:
    """Fetch one economy record by Discord user_id.

    Args:
        user_id: value matched against the ``user_id`` field of the collection.

    Returns:
        The first matching record, or None if no record matches.

    Raises:
        aiohttp.ClientResponseError: if PocketBase answers with an error status.
    """
    # Escape embedded double quotes so the value cannot terminate the quoted
    # literal and alter the filter expression.
    safe_id = str(user_id).replace('"', '\\"')
    session = _get_session()
    async with session.get(
        f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
        params={"filter": f'user_id="{safe_id}"', "perPage": 1},
        headers=await _hdrs(),
    ) as resp:
        resp.raise_for_status()
        data = await resp.json()
    items = data.get("items", [])
    return items[0] if items else None
|
|
|
|
|
|
async def create_record(record: dict[str, Any]) -> dict[str, Any]:
    """Insert a new economy record and return it as stored by PocketBase.

    The returned mapping is the JSON body of the response. A status other
    than 200 or 201 raises RuntimeError carrying the status and body text.
    """
    http = _get_session()
    endpoint = f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records"
    auth_headers = await _hdrs()
    async with http.post(endpoint, json=record, headers=auth_headers) as resp:
        if resp.status in (200, 201):
            return await resp.json()
        text = await resp.text()
        raise RuntimeError(f"PocketBase create failed ({resp.status}): {text}")
|
|
|
|
|
|
async def update_record(record_id: str, data: dict[str, Any]) -> dict[str, Any]:
    """Apply a partial update (PATCH) to the record with the given PocketBase id.

    Returns the JSON body of the response; an error status raises through
    ``raise_for_status``.
    """
    http = _get_session()
    endpoint = f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records/{record_id}"
    auth_headers = await _hdrs()
    async with http.patch(endpoint, json=data, headers=auth_headers) as resp:
        resp.raise_for_status()
        updated = await resp.json()
    return updated
|
|
|
|
|
|
async def count_records() -> int:
    """Report how many records the collection holds, using one small request.

    Asks for a single-item first page and reads the ``totalItems`` field of
    the response, defaulting to 0 when the field is absent.
    """
    http = _get_session()
    endpoint = f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records"
    query = {"perPage": 1, "page": 1}
    auth_headers = await _hdrs()
    async with http.get(endpoint, params=query, headers=auth_headers) as resp:
        resp.raise_for_status()
        payload = await resp.json()
    total = payload.get("totalItems", 0)
    return int(total)
|
|
|
|
|
|
async def list_all_records(page_size: int = 500) -> list[dict[str, Any]]:
    """Fetch every record in the collection, handling PocketBase pagination.

    Args:
        page_size: number of records requested per page.

    Returns:
        All records, in the order the server returned them.

    Raises:
        aiohttp.ClientResponseError: if any page request returns an error status.

    Pages are requested until one comes back empty or shorter than the page
    size the server actually applied. The server-reported ``perPage`` is used
    for that comparison because the server may apply a different size than the
    one requested (for example when the request exceeds its maximum or is not
    positive). Comparing against the requested value would then either stop
    after the first page or never stop.
    """
    results: list[dict[str, Any]] = []
    page = 1
    session = _get_session()
    hdrs = await _hdrs()
    while True:
        async with session.get(
            f"{PB_URL}/api/collections/{ECONOMY_COLLECTION}/records",
            params={"perPage": page_size, "page": page},
            headers=hdrs,
        ) as resp:
            resp.raise_for_status()
            data = await resp.json()
        batch = data.get("items", [])
        results.extend(batch)
        # Prefer the page size the server says it used; fall back to ours.
        applied = data.get("perPage", page_size)
        if not isinstance(applied, int) or applied <= 0:
            applied = page_size
        # An empty page always ends the walk, which also guarantees termination.
        if not batch or len(batch) < applied:
            break
        page += 1
    return results
|