from __future__ import annotations from typing import Any from flask import session from app.core_settings import AppSettings, get_settings from app.storage.kiosk_settings import SQLiteKioskSettingsRepository VALID_MODES = {"public", "private"} USER_MODE_PREFIX = "user:" DEFAULT_WIDGETS = ["hero", "history", "strings", "status", "production", "comparison", "importStatus"] DEFAULT_HERO_METRICS = ["ac_power", "dc_power_total", "energy_today", "energy_total"] DEFAULT_CHART_GROUPS = [{"id": "overview", "title": None, "metric_ids": ["ac_power", "dc_power_total", "inverter_temp"]}] VALID_WIDGETS = {"hero", "quickMetrics", "history", "status", "strings", "production", "comparison", "distribution", "importStatus"} VALID_REALTIME_RANGES = {"today", "yesterday", "6h", "12h", "24h", "48h", "7d"} VALID_ANALYTICS_RANGES = {"today", "yesterday", "7d", "30d", "90d", "365d", "custom"} class KioskSettingsService: def __init__(self, settings: AppSettings | None = None) -> None: self.settings = settings or get_settings() self.repository = SQLiteKioskSettingsRepository(self.settings.storage["sqlite_path"]) def get(self, mode: str) -> dict[str, Any]: normalized_mode = self._normalize_mode(mode) stored = self.repository.get(normalized_mode) if stored is None: return self._default_payload(normalized_mode) return self._sanitize_payload(normalized_mode, stored, persist_if_changed=False) def update(self, mode: str, payload: dict[str, Any], updated_by: str | None = None) -> dict[str, Any]: normalized_mode = self._normalize_mode(mode) merged = {**self.get(normalized_mode), **(payload or {})} cleaned = self._sanitize_payload(normalized_mode, merged, persist_if_changed=False) return self.repository.upsert(normalized_mode, cleaned, updated_by=updated_by) def update_from_session(self, mode: str, payload: dict[str, Any]) -> dict[str, Any]: updated_by = session.get("auth_user") return self.update(mode, payload, updated_by=updated_by) def _default_payload(self, mode: str) -> dict[str, Any]: return { "mode": mode, "widgets": list(DEFAULT_WIDGETS), "hero_metric_ids": list(DEFAULT_HERO_METRICS), "realtime_range": self._default_realtime_range(), "analytics_range": self._default_analytics_range(), "analytics_bucket": self._default_analytics_bucket(), "compare_mode": self._default_compare_mode(), "chart_groups": self._normalize_chart_groups(None), "updated_at": None, "updated_by": None, } def _sanitize_payload(self, mode: str, payload: dict[str, Any], persist_if_changed: bool = False) -> dict[str, Any]: cleaned = { "mode": mode, "widgets": self._normalize_widgets(payload.get("widgets")), "hero_metric_ids": self._normalize_metric_ids(payload.get("hero_metric_ids"), DEFAULT_HERO_METRICS), "realtime_range": self._normalize_realtime_range(payload.get("realtime_range")), "analytics_range": self._normalize_analytics_range(payload.get("analytics_range")), "analytics_bucket": self._normalize_bucket(payload.get("analytics_bucket")), "compare_mode": self._normalize_compare_mode(payload.get("compare_mode")), "chart_groups": self._normalize_chart_groups(payload.get("chart_groups")), "updated_at": payload.get("updated_at"), "updated_by": payload.get("updated_by"), } if persist_if_changed: return self.repository.upsert(mode, cleaned, updated_by=cleaned.get("updated_by")) return cleaned def _normalize_mode(self, mode: str) -> str: normalized = (mode or "").strip().lower() if normalized in VALID_MODES: return normalized if normalized.startswith(USER_MODE_PREFIX) and len(normalized) > len(USER_MODE_PREFIX): return normalized raise ValueError("Mode must be one of: public, private") def _normalize_widgets(self, widgets: Any) -> list[str]: if not isinstance(widgets, list): return list(DEFAULT_WIDGETS) normalized: list[str] = [] for item in widgets: widget = str(item or "").strip() if widget in VALID_WIDGETS and widget not in normalized: normalized.append(widget) return normalized or list(DEFAULT_WIDGETS) def _normalize_metric_ids(self, metric_ids: Any, defaults: list[str]) -> list[str]: if not isinstance(metric_ids, list): return list(defaults) normalized: list[str] = [] for item in metric_ids: value = str(item or "").strip() if value and value not in normalized: normalized.append(value) return normalized[:12] or list(defaults) def _normalize_realtime_range(self, value: Any) -> str: normalized = str(value or self._default_realtime_range()).strip() return normalized if normalized in VALID_REALTIME_RANGES else self._default_realtime_range() def _normalize_analytics_range(self, value: Any) -> str: normalized = str(value or self._default_analytics_range()).strip() return normalized if normalized in VALID_ANALYTICS_RANGES else self._default_analytics_range() def _normalize_bucket(self, value: Any) -> str: normalized = str(value or self._default_analytics_bucket()).strip() return normalized if normalized in self.settings.analytics["bucket_labels"] else self._default_analytics_bucket() def _normalize_compare_mode(self, value: Any) -> str: normalized = str(value or self._default_compare_mode()).strip() return normalized if normalized in self.settings.analytics["compare_modes"] else self._default_compare_mode() def _normalize_chart_groups(self, groups: Any) -> list[dict[str, Any]]: if not isinstance(groups, list): return [dict(item) for item in DEFAULT_CHART_GROUPS] normalized: list[dict[str, Any]] = [] for index, item in enumerate(groups): if not isinstance(item, dict): continue raw_metrics = item.get("metric_ids") if not isinstance(raw_metrics, list): continue metric_ids: list[str] = [] for metric in raw_metrics: value = str(metric or "").strip() if value and value not in metric_ids: metric_ids.append(value) if not metric_ids: continue chart_id = str(item.get("id") or f"chart_{index + 1}").strip() or f"chart_{index + 1}" title_raw = item.get("title") title = None if title_raw is None else str(title_raw).strip() or None normalized.append({"id": chart_id[:80], "title": title, "metric_ids": metric_ids[:24]}) return normalized[:8] or [dict(item) for item in DEFAULT_CHART_GROUPS] def _default_realtime_range(self) -> str: raw = str(self.settings.realtime.get("history_default_range", "12h")) return raw if raw in VALID_REALTIME_RANGES else "12h" def _default_analytics_range(self) -> str: raw = str(self.settings.analytics.get("default_range", "30d")) return raw if raw in VALID_ANALYTICS_RANGES else "30d" def _default_analytics_bucket(self) -> str: raw = str(self.settings.analytics.get("default_bucket", "day")) return raw if raw in self.settings.analytics["bucket_labels"] else "day" def _default_compare_mode(self) -> str: raw = str(self.settings.analytics.get("default_compare", "none")) return raw if raw in self.settings.analytics["compare_modes"] else "none" _kiosk_settings_service: KioskSettingsService | None = None def get_kiosk_settings_service() -> KioskSettingsService: global _kiosk_settings_service if _kiosk_settings_service is None: _kiosk_settings_service = KioskSettingsService() return _kiosk_settings_service