from __future__ import annotations import sqlite3 from contextlib import contextmanager from datetime import datetime, date from pathlib import Path from typing import Iterator from app.models import DailyEnergyRecord, HistoricalCoverage class SQLiteEnergyRepository: def __init__(self, db_path: str) -> None: self.db_path = Path(db_path) self.db_path.parent.mkdir(parents=True, exist_ok=True) self.ensure_schema() @contextmanager def connect(self) -> Iterator[sqlite3.Connection]: conn = sqlite3.connect(self.db_path) conn.row_factory = sqlite3.Row try: conn.execute("PRAGMA journal_mode=WAL") conn.execute("PRAGMA synchronous=NORMAL") yield conn conn.commit() finally: conn.close() def ensure_schema(self) -> None: with self.connect() as conn: conn.execute( """ CREATE TABLE IF NOT EXISTS daily_energy ( day TEXT PRIMARY KEY, energy_kwh REAL NOT NULL, source TEXT NOT NULL, samples_count INTEGER NOT NULL DEFAULT 0, imported_at TEXT NOT NULL ) """ ) conn.execute( "CREATE INDEX IF NOT EXISTS idx_daily_energy_imported_at ON daily_energy(imported_at)" ) def has_day(self, day: date) -> bool: with self.connect() as conn: row = conn.execute("SELECT 1 FROM daily_energy WHERE day = ? LIMIT 1", (day.isoformat(),)).fetchone() return row is not None def upsert_daily_energy(self, record: DailyEnergyRecord) -> None: imported_at = record.imported_at or datetime.utcnow() with self.connect() as conn: conn.execute( """ INSERT INTO daily_energy (day, energy_kwh, source, samples_count, imported_at) VALUES (?, ?, ?, ?, ?) ON CONFLICT(day) DO UPDATE SET energy_kwh = excluded.energy_kwh, source = excluded.source, samples_count = excluded.samples_count, imported_at = excluded.imported_at """, ( record.day.isoformat(), float(record.energy_kwh), record.source, int(record.samples_count), imported_at.isoformat(), ), ) def fetch_daily_energy(self, start_day: date, end_day: date) -> dict[date, DailyEnergyRecord]: with self.connect() as conn: rows = conn.execute( """ SELECT day, energy_kwh, source, samples_count, imported_at FROM daily_energy WHERE day >= ? AND day <= ? ORDER BY day ASC """, (start_day.isoformat(), end_day.isoformat()), ).fetchall() payload: dict[date, DailyEnergyRecord] = {} for row in rows: payload[date.fromisoformat(row["day"])] = DailyEnergyRecord( day=date.fromisoformat(row["day"]), energy_kwh=float(row["energy_kwh"]), source=row["source"], samples_count=int(row["samples_count"]), imported_at=datetime.fromisoformat(row["imported_at"]), ) return payload def coverage(self) -> HistoricalCoverage: with self.connect() as conn: row = conn.execute( """ SELECT COUNT(*) AS imported_days, MIN(day) AS first_day, MAX(day) AS last_day, COALESCE(SUM(energy_kwh), 0) AS total_energy_kwh FROM daily_energy """ ).fetchone() if row is None: return HistoricalCoverage() return HistoricalCoverage( imported_days=int(row["imported_days"] or 0), first_day=date.fromisoformat(row["first_day"]) if row["first_day"] else None, last_day=date.fromisoformat(row["last_day"]) if row["last_day"] else None, total_energy_kwh=round(float(row["total_energy_kwh"] or 0.0), 2), ) def latest_day(self) -> date | None: return self.coverage().last_day def count_missing_days(self, start_day: date, end_day: date) -> int: existing = self.fetch_daily_energy(start_day, end_day) current = start_day missing = 0 while current <= end_day: if current not in existing: missing += 1 current = current.fromordinal(current.toordinal() + 1) return missing