Files
geoip_block_generator/scheduler.py
Mateusz Gruszczyński a02ceaad64 heartbeat every 15 mins
2026-02-25 10:04:29 +01:00

426 lines
16 KiB
Python

#!/usr/bin/env python3
"""
GeoIP Country Scanner Daemon - Incremental Update Mode
"""
import schedule
import time
import sys
import signal
import os
import sqlite3
import concurrent.futures
from datetime import datetime
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
import threading
import traceback
sys.path.insert(0, str(Path(__file__).parent))
from geoip_handler import GeoIPHandler
import config
# --- Shared daemon state ---------------------------------------------------
running = True  # main-loop flag; cleared by signal_handler on SIGTERM/SIGINT
log_lock = threading.Lock()  # serializes log_safe() output across scan threads
write_lock = threading.Lock()  # serializes database writes (one writer at a time)
active_scans = {}  # country_code -> {'start_time', 'progress', 'last_update', 'is_update'}
active_scans_lock = threading.Lock()  # guards every read/write of active_scans
def heartbeat():
    """Emit a periodic liveness log line (scheduled every 15 minutes).

    Reports the daemon's shutdown flag, the next scheduled job time, and
    the number of registered schedule jobs.
    """
    # BUGFIX: interpolate the real `running` flag instead of the hard-coded
    # literal "True", which would misreport state during shutdown.
    log_safe(f"[{datetime.now()}] HEARTBEAT running={running} next_run={schedule.next_run()} jobs={len(schedule.jobs)}")
def compute_maxmind_workers():
    """Size the MaxMind thread pool for one country scan.

    Splits a CPU-derived total thread budget evenly across the scans that
    are currently active, then clamps the per-country share between the
    MAXMIND_WORKERS_MIN and MAXMIND_WORKERS_MAX environment overrides.
    """
    with active_scans_lock:
        concurrent_scans = len(active_scans) or 1
    budget = max(32, cpu_count() * 6)  # e.g. 16 cores -> 96-thread budget
    share = max(4, budget // concurrent_scans)
    floor = int(os.getenv('MAXMIND_WORKERS_MIN', '6'))
    ceiling = int(os.getenv('MAXMIND_WORKERS_MAX', '48'))
    return max(floor, min(ceiling, share))
def signal_handler(signum, frame):
    """Request a graceful shutdown of the scheduler loop.

    Installed for SIGTERM/SIGINT; announces the signal and clears the
    module-level ``running`` flag so the main loop exits.
    """
    global running
    print(f"\n[{datetime.now()}] Received signal {signum}, shutting down...", flush=True)
    sys.stdout.flush()
    running = False
def log_safe(message):
    """Write one log line to stdout atomically w.r.t. other threads."""
    with log_lock:
        sys.stdout.write(str(message) + "\n")
        sys.stdout.flush()
        sys.stdout.flush()
def update_scan_progress(country_code, progress_msg):
    """Record the latest progress message for an in-flight country scan.

    Silently does nothing when the country has already finished and been
    removed from the active_scans registry.
    """
    with active_scans_lock:
        entry = active_scans.get(country_code)
        if entry is not None:
            entry['progress'] = progress_msg
            entry['last_update'] = datetime.now()
def progress_callback_factory(country_code):
    """Return a one-argument callback that forwards progress messages
    for *country_code* into the shared active_scans registry."""
    return lambda msg: update_scan_progress(country_code, msg)
def print_active_scans():
    """Dump a status table of every scan currently in flight.

    Returns immediately when nothing is running.  Holds active_scans_lock
    for the whole dump so entries cannot mutate mid-print.
    """
    with active_scans_lock:
        if not active_scans:
            return
        separator = "=" * 70
        print("\n" + separator, flush=True)
        print("ACTIVE SCANS STATUS:", flush=True)
        print(separator, flush=True)
        for country, info in sorted(active_scans.items()):
            seconds = (datetime.now() - info['start_time']).total_seconds()
            status = info.get('progress', 'Unknown')
            label = "UPDATE" if info.get('is_update', False) else "SCAN"
            print(f" {country} [{label}]: {status} | {seconds:.0f}s", flush=True)
        print(separator + "\n", flush=True)
        sys.stdout.flush()
def scan_single_country(country_code, is_update=False):
    """Scan one country's networks and persist them to the cache.

    Registers itself in the shared active_scans table, gathers networks
    from MaxMind (primary) and GitHub (supplement/fallback), then saves
    under the global write lock.  Never raises: all failures are folded
    into the returned result dict.

    Args:
        country_code: ISO-style country code to scan.
        is_update: True -> incremental cache update; False -> full scan.

    Returns:
        dict with keys 'country', 'success', 'networks', 'error', 'mode'.
    """
    try:
        # Announce this scan so print_active_scans() / the progress
        # callback can report on it.
        with active_scans_lock:
            active_scans[country_code] = {
                'start_time': datetime.now(),
                'progress': 'Starting...',
                'last_update': datetime.now(),
                'is_update': is_update
            }
        start_time = time.time()
        mode = "INCREMENTAL UPDATE" if is_update else "FULL SCAN"
        print(f"[START] {country_code} - {mode}...", flush=True)
        sys.stdout.flush()
        progress_cb = progress_callback_factory(country_code)
        handler = GeoIPHandler()
        print(f"[{country_code}] Scanning MaxMind + GitHub...", flush=True)
        # Worker count adapts to how many scans run concurrently.
        maxmind_workers = compute_maxmind_workers()
        # NOTE(review): len(active_scans) read here without the lock —
        # benign for a log line, but racy in principle.
        print(f"[{country_code}] MaxMind workers: {maxmind_workers} (active scans: {len(active_scans)})", flush=True)
        maxmind_networks = handler._scan_maxmind_for_country(
            country_code,
            progress_callback=progress_cb,
            workers=maxmind_workers
        )
        if maxmind_networks:
            print(f"[{country_code}] MaxMind: {len(maxmind_networks):,} networks, checking GitHub...", flush=True)
            # GitHub acts as a supplement: only networks MaxMind missed
            # are appended.
            github_networks = handler._fetch_from_github(country_code)
            if github_networks:
                maxmind_set = set(maxmind_networks)
                github_set = set(github_networks)
                missing = github_set - maxmind_set
                if missing:
                    maxmind_networks.extend(missing)
                    print(f"[{country_code}] GitHub added {len(missing):,} new networks", flush=True)
                else:
                    print(f"[{country_code}] GitHub: {len(github_networks):,} networks (no new)", flush=True)
                source = 'maxmind+github'
            else:
                print(f"[{country_code}] GitHub: no data", flush=True)
                source = 'maxmind'
            networks = maxmind_networks
        else:
            # MaxMind empty -> GitHub becomes the sole fallback source.
            print(f"[{country_code}] MaxMind found nothing, trying GitHub...", flush=True)
            networks = handler._fetch_from_github(country_code)
            source = 'github' if networks else None
        if networks:
            # Database writes are serialized across all scan threads.
            with write_lock:
                print(f"[{country_code}] Acquired write lock, saving to database...", flush=True)
                if is_update:
                    saved = handler._update_cache_incremental(country_code, networks, source)
                else:
                    saved = handler._save_to_cache(country_code, networks, source)
                # NOTE(review): printed while still holding the lock; the
                # actual release happens when the `with` block exits.
                print(f"[{country_code}] Released write lock", flush=True)
            elapsed = time.time() - start_time
            # De-register from the active-scans table before reporting.
            with active_scans_lock:
                active_scans.pop(country_code, None)
            if saved:
                print(f"[DONE] {country_code}: {len(networks)} networks in {elapsed:.1f}s ({mode})", flush=True)
                sys.stdout.flush()
                return {'country': country_code, 'success': True, 'networks': len(networks), 'error': None, 'mode': mode}
            else:
                print(f"[ERROR] {country_code}: Failed to save to cache", flush=True)
                sys.stdout.flush()
                return {'country': country_code, 'success': False, 'networks': 0, 'error': 'Failed to save', 'mode': mode}
        else:
            # No data from either source: clean up and report failure.
            with active_scans_lock:
                active_scans.pop(country_code, None)
            print(f"[ERROR] {country_code}: No data found", flush=True)
            sys.stdout.flush()
            return {'country': country_code, 'success': False, 'networks': 0, 'error': 'No data found', 'mode': mode}
    except Exception as e:
        # Catch-all so one country's failure never kills the thread pool;
        # the error is carried back in the result dict instead.
        with active_scans_lock:
            active_scans.pop(country_code, None)
        print(f"[ERROR] {country_code}: {e}", flush=True)
        sys.stdout.flush()
        import traceback
        traceback.print_exc()
        # 'mode' may be unset if the exception fired before it was assigned.
        return {'country': country_code, 'success': False, 'networks': 0, 'error': str(e), 'mode': 'UNKNOWN'}
def scan_all_countries_incremental(parallel_workers=None, max_age_hours=168):
    """Scan or refresh GeoIP data for every country that needs it.

    Refreshes the MaxMind database when stale, then fans out over a
    thread pool: countries never cached get a full scan, stale ones an
    incremental update.  Database writes inside the workers remain
    serialized by the module-level write lock.

    Args:
        parallel_workers: Thread-pool size; None picks min(cpu_count(), 16).
        max_age_hours: Cache entries older than this count as stale.

    Returns:
        True when the run finished (even with per-country failures),
        False when the run itself crashed.
    """
    log_safe(f"[{datetime.now()}] Starting INCREMENTAL country scan...")
    try:
        handler = GeoIPHandler()
        if handler.needs_update():
            log_safe("Updating MaxMind database...")
            result = handler.download_database()
            if not result.get('success'):
                # Best-effort: keep going with the existing database.
                log_safe(f"Warning: Database update failed - {result.get('error')}")
        log_safe("\nChecking cache status...")
        missing, stale = handler.get_countries_needing_scan(max_age_hours)
        log_safe(f"Missing countries (never scanned): {len(missing)}")
        log_safe(f"Stale countries (needs update): {len(stale)}")
        if missing:
            log_safe(f"Missing: {', '.join(sorted(missing))}")
        if stale:
            log_safe(f"Stale: {', '.join(sorted(stale))}")
        total = len(missing) + len(stale)
        if total == 0:
            log_safe("\n✓ All countries are up to date!")
            return True
        if parallel_workers is None:
            parallel_workers = min(cpu_count(), 16)
        log_safe(f"\nProcessing {total} countries using {parallel_workers} parallel workers...")
        log_safe(f" - {len(missing)} new countries (full scan)")
        log_safe(f" - {len(stale)} stale countries (incremental update)")
        log_safe(f"Note: Database writes are serialized with write lock")
        log_safe(f"Estimated time: {total / parallel_workers * 3:.1f} minutes\n")
        start_time = datetime.now()
        completed = 0
        success_count = 0
        failed_countries = []
        results_list = []
        last_progress_time = time.time()
        last_status_print = time.time()

        def print_progress(force=False):
            # Throttled progress line: at most one every 30 s unless forced.
            nonlocal last_progress_time
            current_time = time.time()
            if not force and (current_time - last_progress_time) < 30:
                return
            last_progress_time = current_time
            elapsed = (datetime.now() - start_time).total_seconds()
            avg_time = elapsed / completed if completed > 0 else 0
            remaining = (total - completed) * avg_time if completed > 0 else 0
            # BUGFIX: the bar glyphs had degraded to empty strings ("" * n),
            # which always rendered an empty "[]" bar; restore block chars.
            filled = int(completed / total * 40)
            progress_bar = "█" * filled + "░" * (40 - filled)
            msg = (f"[{progress_bar}] {completed}/{total} ({100*completed/total:.1f}%) | "
                   f"Elapsed: {elapsed:.0f}s | ETA: {remaining:.0f}s")
            print(msg, flush=True)
            sys.stdout.flush()

        log_safe("Starting parallel execution...")
        sys.stdout.flush()
        # New countries -> full scan (False); stale -> incremental (True).
        tasks = [(country, False) for country in missing] + [(country, True) for country in stale]
        with ThreadPoolExecutor(max_workers=parallel_workers) as executor:
            future_to_country = {
                executor.submit(scan_single_country, country, is_update): country
                for country, is_update in tasks
            }
            log_safe(f"Submitted {len(future_to_country)} tasks\n")
            sys.stdout.flush()
            pending = set(future_to_country.keys())
            while pending:
                # Wake at least every 10 s even when nothing finished, so
                # the active-scans status table keeps printing below.
                done, pending = concurrent.futures.wait(
                    pending,
                    timeout=10,
                    return_when=concurrent.futures.FIRST_COMPLETED
                )
                for future in done:
                    # scan_single_country never raises, so result() is safe.
                    result = future.result()
                    results_list.append(result)
                    completed += 1
                    if result['success']:
                        success_count += 1
                    else:
                        failed_countries.append(result['country'])
                print_progress(force=bool(done))
                current_time = time.time()
                if current_time - last_status_print >= 30:
                    print_active_scans()
                    last_status_print = current_time
        print("\n", flush=True)
        sys.stdout.flush()
        elapsed = (datetime.now() - start_time).total_seconds()
        log_safe("=" * 70)
        log_safe("SCAN RESULTS (sorted by country):")
        log_safe("=" * 70)
        for result in sorted(results_list, key=lambda x: x['country']):
            mode_str = f"[{result.get('mode', 'UNKNOWN')}]"
            if result['success']:
                log_safe(f" {result['country']}: ✓ {result['networks']:,} networks {mode_str}")
            else:
                log_safe(f" {result['country']}: ✗ {result['error']} {mode_str}")
        log_safe("=" * 70)
        log_safe(f"\n[{datetime.now()}] Incremental scan complete!")
        log_safe(f"✓ Success: {success_count}/{total} countries")
        log_safe(f" - New countries: {len([r for r in results_list if r.get('mode') == 'FULL SCAN' and r['success']])}")
        log_safe(f" - Updated countries: {len([r for r in results_list if r.get('mode') == 'INCREMENTAL UPDATE' and r['success']])}")
        log_safe(f" Time: {elapsed:.1f}s ({elapsed/60:.1f} minutes)")
        log_safe(f" Average: {elapsed/total:.1f}s per country\n")
        if failed_countries:
            log_safe(f"✗ Failed: {', '.join(failed_countries)}\n")
        return True
    except Exception as e:
        log_safe(f"[{datetime.now()}] ERROR: {e}")
        # traceback is imported at module level; the redundant local
        # import was removed.
        traceback.print_exc()
        sys.stdout.flush()
        return False
if __name__ == '__main__':
    # --- Startup banner ---------------------------------------------------
    print("=" * 70, flush=True)
    print("GeoIP Country Scanner Daemon", flush=True)
    print("=" * 70, flush=True)
    print(f"Started: {datetime.now()}", flush=True)
    print(f"Data dir: {config.GEOIP_DB_DIR}", flush=True)
    print(f"CPU cores: {cpu_count()}", flush=True)
    sys.stdout.flush()
    # --- Kill switch: exit cleanly when the scheduler is disabled ---------
    scheduler_enabled = os.getenv('SCHEDULER_ENABLED', 'true').lower() == 'true'
    if not scheduler_enabled:
        print("\n[DISABLED] SCHEDULER_ENABLED=false - exiting", flush=True)
        print("=" * 70, flush=True)
        sys.stdout.flush()
        sys.exit(0)
    print("=" * 70, flush=True)
    sys.stdout.flush()
    # --- Graceful-shutdown hooks ------------------------------------------
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)
    # --- Configuration from environment -----------------------------------
    scan_time = os.getenv('SCAN_TIME', '02:00')
    scan_interval = os.getenv('SCAN_INTERVAL', '7d')
    scan_on_startup = os.getenv('SCAN_ON_STARTUP', 'true').lower() == 'true'
    cache_max_age_hours = int(os.getenv('CACHE_MAX_AGE_HOURS', '168'))
    parallel_workers = int(os.getenv('PARALLEL_WORKERS', '16'))
    if parallel_workers == 0:
        # 0 means "auto": size the pool from the CPU count, capped at 16.
        parallel_workers = min(cpu_count(), 16)
    print(f"\n[CONFIG] Scheduler: enabled", flush=True)
    print(f"[CONFIG] Parallel: {parallel_workers} workers", flush=True)
    print(f"[CONFIG] Interval: {scan_interval}", flush=True)
    print(f"[CONFIG] Time: {scan_time}", flush=True)
    print(f"[CONFIG] Startup scan: {scan_on_startup}", flush=True)
    print(f"[CONFIG] Cache max age: {cache_max_age_hours}h ({cache_max_age_hours/24:.1f} days)", flush=True)
    sys.stdout.flush()
    # Bind the configured arguments once; schedule jobs call this lambda.
    scan_function = lambda: scan_all_countries_incremental(parallel_workers, cache_max_age_hours)
    # --- Optional immediate scan before entering the schedule loop --------
    if scan_on_startup:
        print("\n[STARTUP] Running incremental scan...\n", flush=True)
        sys.stdout.flush()
        scan_function()
    else:
        print("\n[STARTUP] Skipping (SCAN_ON_STARTUP=false)", flush=True)
        sys.stdout.flush()
    # --- Register the recurring job per SCAN_INTERVAL ---------------------
    if scan_interval == 'daily':
        schedule.every().day.at(scan_time).do(scan_function)
        print(f"\n[SCHEDULER] Daily at {scan_time}", flush=True)
    elif scan_interval == 'weekly':
        schedule.every().monday.at(scan_time).do(scan_function)
        print(f"\n[SCHEDULER] Weekly (Monday at {scan_time})", flush=True)
    elif scan_interval == 'monthly':
        # Approximated as a fixed 30-day period (SCAN_TIME is not applied).
        schedule.every(30).days.do(scan_function)
        print(f"\n[SCHEDULER] Monthly (every 30 days)", flush=True)
    elif scan_interval.endswith('h'):
        hours = int(scan_interval[:-1])
        schedule.every(hours).hours.do(scan_function)
        print(f"\n[SCHEDULER] Every {hours} hours", flush=True)
    elif scan_interval.endswith('d'):
        days = int(scan_interval[:-1])
        schedule.every(days).days.do(scan_function)
        print(f"\n[SCHEDULER] Every {days} days", flush=True)
    else:
        print(f"\n[ERROR] Invalid SCAN_INTERVAL: {scan_interval}", flush=True)
        sys.stdout.flush()
        sys.exit(1)
    next_run = schedule.next_run()
    if next_run:
        print(f"[SCHEDULER] Next run: {next_run}", flush=True)
    print("\nScheduler running. Press Ctrl+C to stop.\n", flush=True)
    sys.stdout.flush()
    # heartbeat
    schedule.every(15).minutes.do(heartbeat)
    # --- Main loop: poll pending jobs once per minute ---------------------
    # NOTE(review): the 60 s sleep means a SIGTERM/SIGINT may take up to a
    # minute to be acted on, since the flag is only re-checked per iteration.
    while running:
        try:
            schedule.run_pending()
        except Exception as e:
            # Keep the daemon alive even if a scheduled job blows up.
            log_safe(f"[{datetime.now()}] ERROR in run_pending: {e}")
            traceback.print_exc()
            sys.stdout.flush()
        time.sleep(60)
    print("\n[SHUTDOWN] Stopped gracefully.", flush=True)
    sys.stdout.flush()
    sys.exit(0)