"""Fluent Bit log ingestion, buffering, and alert-context persistence.""" from __future__ import annotations import asyncio import json import sqlite3 from collections import deque from datetime import UTC, datetime from hashlib import sha1 from pathlib import Path from typing import Any from app.config import ( ALL_NFS, LOG_ALERT_CONTEXT_AFTER, LOG_ALERT_CONTEXT_BEFORE, LOG_ALLOWED_NFS, LOG_ALERT_CONTEXT_DB_MAX_ROWS, LOG_ALERT_CONTEXT_DB_PATH, LOG_BUFFER_LINES, LOG_AUTO_CONFIGURE, LOG_FLUENTBIT_MATCH, LOG_INGEST_ENABLED, LOG_RECEIVER_BIND_HOST, LOG_RECEIVER_FORMAT, LOG_RECEIVER_HOST, LOG_RECEIVER_PORT, LOG_TRACE_BUFFER_LINES, ) from app.services import pls _server: asyncio.base_events.Server | None = None _events: deque[dict[str, Any]] = deque(maxlen=max(LOG_BUFFER_LINES, 1)) _trace_events: deque[dict[str, Any]] = deque(maxlen=max(LOG_TRACE_BUFFER_LINES, LOG_BUFFER_LINES, 1)) _ingested_total = 0 _parse_errors = 0 _last_event_at: str | None = None _db_initialized = False _allowed_nfs = {nf.upper() for nf in LOG_ALLOWED_NFS} def _db_path() -> Path: return Path(LOG_ALERT_CONTEXT_DB_PATH) def _ensure_db() -> None: global _db_initialized if _db_initialized: return path = _db_path() path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(path) try: conn.execute( """ CREATE TABLE IF NOT EXISTS alert_context ( id TEXT PRIMARY KEY, fingerprint TEXT UNIQUE, created_at TEXT NOT NULL, event_ts TEXT NOT NULL, category TEXT NOT NULL, nf TEXT, node TEXT, severity TEXT, description TEXT, remediation TEXT, source TEXT, match_message TEXT, before_context TEXT, after_context TEXT ) """ ) conn.commit() finally: conn.close() _db_initialized = True def _trim_db(conn: sqlite3.Connection) -> None: conn.execute( """ DELETE FROM alert_context WHERE id NOT IN ( SELECT id FROM alert_context ORDER BY event_ts DESC, created_at DESC LIMIT ? ) """, (max(LOG_ALERT_CONTEXT_DB_MAX_ROWS, 1),), ) def _parse_timestamp(value: Any) -> tuple[float, str]: if value is None: now = datetime.now(UTC) return now.timestamp(), now.isoformat() if isinstance(value, (int, float)): raw = float(value) if raw > 1_000_000_000_000: raw = raw / 1_000_000.0 elif raw > 10_000_000_000: raw = raw / 1000.0 dt = datetime.fromtimestamp(raw, UTC) return raw, dt.isoformat() text = str(value).strip() if text.isdigit(): return _parse_timestamp(int(text)) normalized = text.replace("Z", "+00:00") for candidate in (normalized, normalized.replace(" ", "T")): try: dt = datetime.fromisoformat(candidate) if dt.tzinfo is None: dt = dt.replace(tzinfo=UTC) else: dt = dt.astimezone(UTC) return dt.timestamp(), dt.isoformat() except ValueError: continue now = datetime.now(UTC) return now.timestamp(), now.isoformat() def _candidate_fields(payload: dict[str, Any]) -> list[str]: candidates = [] for key in ( "message", "MESSAGE", "log", "msg", "systemd_unit", "_SYSTEMD_UNIT", "syslog_identifier", "SYSLOG_IDENTIFIER", "_COMM", "comm", "_EXE", "container_name", "tag", ): value = payload.get(key) if value not in (None, ""): candidates.append(str(value)) return candidates def _infer_nf(payload: dict[str, Any], message: str) -> str: haystack = " ".join(_candidate_fields(payload) + [message]).lower() aliases = { "upf": "UPF", "amf": "AMF", "smf": "SMF", "udm": "UDM", "udr": "UDR", "nrf": "NRF", "ausf": "AUSF", "pcf": "PCF", "mme": "MME", "sgwc": "SGWC", "dra": "DRA", "dsm": "DSM", "aaa": "AAA", "bmsc": "BMSC", "chf": "CHF", "smsf": "SMSF", "eir": "EIR", "licensed": "LICENSED", "prometheus": "PROMETHEUS", "alertmanager": "ALERTMANAGER", "fluent-bit": "FLUENT-BIT", } for needle, label in aliases.items(): if needle in haystack: return label return "SYSTEM" def _normalize_event(payload: dict[str, Any], remote_host: str) -> dict[str, Any]: ts_value = ( payload.get("timestamp") or payload.get("@timestamp") or payload.get("time") or payload.get("date") or payload.get("_SOURCE_REALTIME_TIMESTAMP") ) epoch, ts_iso = _parse_timestamp(ts_value) node = ( payload.get("hostname") or payload.get("host") or payload.get("_HOSTNAME") or payload.get("syslog_hostname") or remote_host ) source = ( payload.get("systemd_unit") or payload.get("_SYSTEMD_UNIT") or payload.get("syslog_identifier") or payload.get("SYSLOG_IDENTIFIER") or payload.get("_COMM") or payload.get("tag") or "unknown" ) message = ( payload.get("message") or payload.get("MESSAGE") or payload.get("log") or payload.get("msg") or "" ) message = str(message).strip() tag = str(payload.get("tag", "")) nf = _infer_nf(payload, message) fingerprint = sha1(f"{ts_iso}|{node}|{nf}|{source}|{message}".encode("utf-8")).hexdigest() return { "id": fingerprint, "timestamp": ts_iso, "epoch": epoch, "node": str(node), "nf": nf, "source": str(source), "tag": tag, "message": message, "raw": payload, } async def _ingest_payload(payload: dict[str, Any], remote_host: str) -> None: global _ingested_total, _last_event_at event = _normalize_event(payload, remote_host) if event.get("nf", "").upper() not in _allowed_nfs: return _events.append(event) _trace_events.append(event) _ingested_total += 1 _last_event_at = event["timestamp"] async def _handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: global _parse_errors peer = writer.get_extra_info("peername") remote_host = peer[0] if isinstance(peer, tuple) and peer else "unknown" try: while not reader.at_eof(): line = await reader.readline() if not line: break text = line.decode("utf-8", errors="replace").strip() if not text: continue try: payload = json.loads(text) if isinstance(payload, dict): await _ingest_payload(payload, remote_host) elif isinstance(payload, list): for item in payload: if isinstance(item, dict): await _ingest_payload(item, remote_host) except Exception: _parse_errors += 1 finally: writer.close() await writer.wait_closed() async def startup() -> None: global _server _ensure_db() if not LOG_INGEST_ENABLED or _server is not None: return _server = await asyncio.start_server(_handle_client, LOG_RECEIVER_BIND_HOST, LOG_RECEIVER_PORT) if LOG_AUTO_CONFIGURE: try: await configure_site_output() except Exception: pass async def shutdown() -> None: global _server if _server is None: return _server.close() await _server.wait_closed() _server = None def receiver_status() -> dict[str, Any]: return { "enabled": LOG_INGEST_ENABLED, "bind_host": LOG_RECEIVER_BIND_HOST, "receiver_host": LOG_RECEIVER_HOST, "port": LOG_RECEIVER_PORT, "format": LOG_RECEIVER_FORMAT, "allowed_nfs": sorted(_allowed_nfs), "buffer_lines": LOG_BUFFER_LINES, "trace_buffer_lines": LOG_TRACE_BUFFER_LINES, "context_before": LOG_ALERT_CONTEXT_BEFORE, "context_after": LOG_ALERT_CONTEXT_AFTER, "db_path": str(_db_path()), "ingested_total": _ingested_total, "parse_errors": _parse_errors, "last_event_at": _last_event_at, "current_buffer_size": len(_events), } def current_output_config(receiver_host: str) -> dict[str, Any]: return { "name": "tcp", "match": LOG_FLUENTBIT_MATCH, "host": receiver_host, "port": LOG_RECEIVER_PORT, "format": LOG_RECEIVER_FORMAT, } def default_input_config() -> dict[str, Any]: return { "name": "systemd", "path": "/var/log/journal", "tag": "marvis.systemd", "read_from_tail": "on", "strip_underscores": "off", } async def _resolve_receiver_host() -> str: if LOG_RECEIVER_HOST: return LOG_RECEIVER_HOST cluster = await pls.get_cluster_status() if isinstance(cluster, dict): current_node = cluster.get("current_node") if isinstance(current_node, str) and current_node: return pls.node_host(current_node) system = await pls.get_system_info() if isinstance(system, dict) and system.get("hostname"): return str(system["hostname"]) return "127.0.0.1" def _merged_fluentbit_config(config: dict[str, Any], receiver_host: str) -> dict[str, Any]: merged = dict(config or {}) pipeline = dict(merged.get("pipeline") or {}) inputs = list(pipeline.get("inputs") or []) outputs = list(pipeline.get("outputs") or []) desired = current_output_config(receiver_host) if not inputs: inputs = [default_input_config()] filtered = [] for output in outputs: if not isinstance(output, dict): continue is_existing_marvis = ( output.get("name") == "tcp" and output.get("port") == LOG_RECEIVER_PORT and output.get("format") == LOG_RECEIVER_FORMAT ) if not is_existing_marvis: filtered.append(output) filtered.append(desired) pipeline["inputs"] = inputs pipeline["outputs"] = filtered merged["pipeline"] = pipeline if "parsers" not in merged: merged["parsers"] = list(config.get("parsers") or []) if isinstance(config, dict) else [] return merged async def configure_site_output() -> dict[str, Any]: current = await pls.get_fluentbit_config() if not isinstance(current, dict): raise RuntimeError("Could not read current Fluent Bit config from PLS") receiver_host = await _resolve_receiver_host() desired = _merged_fluentbit_config(current, receiver_host) updated = await pls.put_fluentbit_config(desired) if not isinstance(updated, dict): raise RuntimeError("PLS rejected Fluent Bit config update") return { "receiver_host": receiver_host, "receiver_port": LOG_RECEIVER_PORT, "match": LOG_FLUENTBIT_MATCH, "config": updated, } def get_events(limit: int | None = None, node: str | None = None, nf: str | None = None, imsi: str | None = None) -> list[dict[str, Any]]: events = list(_trace_events if imsi else _events) if node: node_l = node.lower() events = [event for event in events if event.get("node", "").lower() == node_l] if nf: nf_u = nf.upper() events = [event for event in events if event.get("nf", "").upper() == nf_u] if imsi: needle = imsi.strip() events = [event for event in events if needle and needle in event.get("message", "")] events.sort(key=lambda event: event.get("epoch", 0.0)) if limit is not None: return events[-limit:] return events def record_alert_context( *, category: str, nf: str, node: str, severity: str, description: str, remediation: str, source: str, event: dict[str, Any], before_context: list[dict[str, Any]], after_context: list[dict[str, Any]], ) -> str: _ensure_db() fingerprint = sha1( "|".join( [ category, nf, node, severity, description, remediation, event.get("timestamp", ""), event.get("message", ""), ] ).encode("utf-8") ).hexdigest() alert_id = sha1(f"{fingerprint}|{source}".encode("utf-8")).hexdigest() conn = sqlite3.connect(_db_path()) try: conn.execute( """ INSERT OR REPLACE INTO alert_context ( id, fingerprint, created_at, event_ts, category, nf, node, severity, description, remediation, source, match_message, before_context, after_context ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( alert_id, fingerprint, datetime.now(UTC).isoformat(), event.get("timestamp", ""), category, nf, node, severity, description, remediation, source, event.get("message", ""), json.dumps(before_context), json.dumps(after_context), ), ) _trim_db(conn) conn.commit() finally: conn.close() return alert_id def recent_alert_context(limit: int = 20) -> list[dict[str, Any]]: _ensure_db() conn = sqlite3.connect(_db_path()) conn.row_factory = sqlite3.Row try: rows = conn.execute( """ SELECT id, created_at, event_ts, category, nf, node, severity, description, remediation, source, match_message, before_context, after_context FROM alert_context ORDER BY event_ts DESC, created_at DESC LIMIT ? """, (limit,), ).fetchall() return [dict(row) for row in rows] finally: conn.close() def known_nfs() -> list[str]: return list(ALL_NFS)