"""Fluent Bit log ingestion, buffering, and alert-context persistence.""" from __future__ import annotations import asyncio import json import re import sqlite3 from collections import deque from datetime import UTC, datetime from hashlib import sha1 from pathlib import Path from typing import Any from app.config import ( ALL_NFS, LOG_ALERT_CONTEXT_AFTER, LOG_ALERT_CONTEXT_BEFORE, LOG_ALLOWED_NFS, LOG_ALERT_CONTEXT_DB_MAX_ROWS, LOG_ALERT_CONTEXT_DB_PATH, LOG_BUFFER_LINES, LOG_AUTO_CONFIGURE, LOG_FLUENTBIT_MATCH, LOG_INGEST_ENABLED, LOG_PROCESS_BUFFER_LINES, LOG_RECEIVER_BIND_HOST, LOG_RECEIVER_FORMAT, LOG_RECEIVER_HOST, LOG_RECEIVER_PORT, LOG_SUBSCRIBER_BUFFER_LINES, LOG_TRACE_DEBUG_LEVEL, LOG_TRACE_BUFFER_LINES, LOG_TRACE_TARGET_SERVICES, ) from app.services import pls _server: asyncio.base_events.Server | None = None _allowed_nfs = {nf.upper() for nf in LOG_ALLOWED_NFS} _events: deque[dict[str, Any]] = deque(maxlen=max(LOG_BUFFER_LINES, 1)) _trace_events: deque[dict[str, Any]] = deque(maxlen=max(LOG_TRACE_BUFFER_LINES, LOG_BUFFER_LINES, 1)) _process_events: dict[str, deque[dict[str, Any]]] = { nf.upper(): deque(maxlen=max(LOG_PROCESS_BUFFER_LINES, 1)) for nf in _allowed_nfs if nf != "SYSTEM" } _subscriber_events: dict[str, deque[dict[str, Any]]] = {} _ingested_total = 0 _parse_errors = 0 _last_event_at: str | None = None _db_initialized = False _supi_pattern = re.compile(r"(imsi-\d{6,20}|\b\d{6,20}\b)", re.IGNORECASE) _trace_state: dict[str, Any] = { "active": False, "filter": "", "normalized": "", "started_at": None, "matched_events": 0, "nodes": [], "services": list(LOG_TRACE_TARGET_SERVICES), "level": LOG_TRACE_DEBUG_LEVEL, "original_levels": {}, } def _db_path() -> Path: return Path(LOG_ALERT_CONTEXT_DB_PATH) def _ensure_db() -> None: global _db_initialized if _db_initialized: return path = _db_path() path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(path) try: conn.execute( """ CREATE TABLE IF NOT EXISTS alert_context ( id TEXT PRIMARY KEY, fingerprint TEXT UNIQUE, created_at TEXT NOT NULL, event_ts TEXT NOT NULL, category TEXT NOT NULL, nf TEXT, node TEXT, severity TEXT, description TEXT, remediation TEXT, source TEXT, match_message TEXT, before_context TEXT, after_context TEXT ) """ ) conn.commit() finally: conn.close() _db_initialized = True def _trim_db(conn: sqlite3.Connection) -> None: conn.execute( """ DELETE FROM alert_context WHERE id NOT IN ( SELECT id FROM alert_context ORDER BY event_ts DESC, created_at DESC LIMIT ? ) """, (max(LOG_ALERT_CONTEXT_DB_MAX_ROWS, 1),), ) def _parse_timestamp(value: Any) -> tuple[float, str]: if value is None: now = datetime.now(UTC) return now.timestamp(), now.isoformat() if isinstance(value, (int, float)): raw = float(value) if raw > 1_000_000_000_000: raw = raw / 1_000_000.0 elif raw > 10_000_000_000: raw = raw / 1000.0 dt = datetime.fromtimestamp(raw, UTC) return raw, dt.isoformat() text = str(value).strip() if text.isdigit(): return _parse_timestamp(int(text)) normalized = text.replace("Z", "+00:00") for candidate in (normalized, normalized.replace(" ", "T")): try: dt = datetime.fromisoformat(candidate) if dt.tzinfo is None: dt = dt.replace(tzinfo=UTC) else: dt = dt.astimezone(UTC) return dt.timestamp(), dt.isoformat() except ValueError: continue now = datetime.now(UTC) return now.timestamp(), now.isoformat() def _candidate_fields(payload: dict[str, Any]) -> list[str]: candidates = [] for key in ( "message", "MESSAGE", "log", "msg", "systemd_unit", "_SYSTEMD_UNIT", "syslog_identifier", "SYSLOG_IDENTIFIER", "_COMM", "comm", "_EXE", "container_name", "tag", ): value = payload.get(key) if value not in (None, ""): candidates.append(str(value)) return candidates def _infer_nf(payload: dict[str, Any], message: str) -> str: haystack = " ".join(_candidate_fields(payload) + [message]).lower() aliases = { "upf": "UPF", "amf": "AMF", "smf": "SMF", "udm": "UDM", "udr": "UDR", "nrf": "NRF", "ausf": "AUSF", "pcf": "PCF", "mme": "MME", "sgwc": "SGWC", "dra": "DRA", "dsm": "DSM", "aaa": "AAA", "bmsc": "BMSC", "chf": "CHF", "smsf": "SMSF", "eir": "EIR", "licensed": "LICENSED", "prometheus": "PROMETHEUS", "alertmanager": "ALERTMANAGER", "fluent-bit": "FLUENT-BIT", } for needle, label in aliases.items(): if needle in haystack: return label return "SYSTEM" def _normalize_supi(value: str | None) -> str: if not value: return "" text = str(value).strip().lower() if not text: return "" if text.startswith("imsi-"): digits = "".join(ch for ch in text[5:] if ch.isdigit()) return f"imsi-{digits}" if digits else text digits = "".join(ch for ch in text if ch.isdigit()) if digits: return f"imsi-{digits}" return text def _extract_supis(message: str) -> list[str]: matches = [] for raw in _supi_pattern.findall(message or ""): normalized = _normalize_supi(raw) if normalized and normalized not in matches: matches.append(normalized) return matches def _matches_trace(event: dict[str, Any]) -> bool: if not _trace_state.get("active"): return False normalized = _trace_state.get("normalized", "") if not normalized: return False message = str(event.get("message", "")).lower() if normalized in message: return True digits = normalized.removeprefix("imsi-") return bool(digits and digits in message) def _normalize_event(payload: dict[str, Any], remote_host: str) -> dict[str, Any]: ts_value = ( payload.get("timestamp") or payload.get("@timestamp") or payload.get("time") or payload.get("date") or payload.get("_SOURCE_REALTIME_TIMESTAMP") ) epoch, ts_iso = _parse_timestamp(ts_value) node = ( payload.get("hostname") or payload.get("host") or payload.get("_HOSTNAME") or payload.get("syslog_hostname") or remote_host ) source = ( payload.get("systemd_unit") or payload.get("_SYSTEMD_UNIT") or payload.get("syslog_identifier") or payload.get("SYSLOG_IDENTIFIER") or payload.get("_COMM") or payload.get("tag") or "unknown" ) message = ( payload.get("message") or payload.get("MESSAGE") or payload.get("log") or payload.get("msg") or "" ) message = str(message).strip() supis = _extract_supis(message) tag = str(payload.get("tag", "")) nf = _infer_nf(payload, message) fingerprint = sha1(f"{ts_iso}|{node}|{nf}|{source}|{message}".encode("utf-8")).hexdigest() return { "id": fingerprint, "timestamp": ts_iso, "epoch": epoch, "node": str(node), "nf": nf, "source": str(source), "tag": tag, "message": message, "supis": supis, "raw": payload, } async def _ingest_payload(payload: dict[str, Any], remote_host: str) -> None: global _ingested_total, _last_event_at event = _normalize_event(payload, remote_host) if event.get("nf", "").upper() not in _allowed_nfs: return _events.append(event) _trace_events.append(event) nf_key = event.get("nf", "").upper() if nf_key: _process_events.setdefault(nf_key, deque(maxlen=max(LOG_PROCESS_BUFFER_LINES, 1))).append(event) for supi in event.get("supis", []): _subscriber_events.setdefault( supi, deque(maxlen=max(LOG_SUBSCRIBER_BUFFER_LINES, 1)), ).append(event) if _matches_trace(event): _trace_state["matched_events"] = int(_trace_state.get("matched_events", 0)) + 1 _ingested_total += 1 _last_event_at = event["timestamp"] async def _handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: global _parse_errors peer = writer.get_extra_info("peername") remote_host = peer[0] if isinstance(peer, tuple) and peer else "unknown" try: while not reader.at_eof(): line = await reader.readline() if not line: break text = line.decode("utf-8", errors="replace").strip() if not text: continue try: payload = json.loads(text) if isinstance(payload, dict): await _ingest_payload(payload, remote_host) elif isinstance(payload, list): for item in payload: if isinstance(item, dict): await _ingest_payload(item, remote_host) except Exception: _parse_errors += 1 finally: writer.close() await writer.wait_closed() async def startup() -> None: global _server _ensure_db() if not LOG_INGEST_ENABLED or _server is not None: return _server = await asyncio.start_server(_handle_client, LOG_RECEIVER_BIND_HOST, LOG_RECEIVER_PORT) if LOG_AUTO_CONFIGURE: try: await configure_site_output() except Exception: pass async def shutdown() -> None: global _server if _server is None: return _server.close() await _server.wait_closed() _server = None def receiver_status() -> dict[str, Any]: return { "enabled": LOG_INGEST_ENABLED, "bind_host": LOG_RECEIVER_BIND_HOST, "receiver_host": LOG_RECEIVER_HOST, "port": LOG_RECEIVER_PORT, "format": LOG_RECEIVER_FORMAT, "allowed_nfs": sorted(_allowed_nfs), "buffer_lines": LOG_BUFFER_LINES, "process_buffer_lines": LOG_PROCESS_BUFFER_LINES, "subscriber_buffer_lines": LOG_SUBSCRIBER_BUFFER_LINES, "trace_buffer_lines": LOG_TRACE_BUFFER_LINES, "context_before": LOG_ALERT_CONTEXT_BEFORE, "context_after": LOG_ALERT_CONTEXT_AFTER, "db_path": str(_db_path()), "ingested_total": _ingested_total, "parse_errors": _parse_errors, "last_event_at": _last_event_at, "current_buffer_size": len(_events), "process_buffers": sorted(_process_events.keys()), "subscriber_buffers": len(_subscriber_events), "trace": { "active": bool(_trace_state.get("active")), "filter": _trace_state.get("filter", ""), "started_at": _trace_state.get("started_at"), "matched_events": _trace_state.get("matched_events", 0), "nodes": list(_trace_state.get("nodes", [])), "services": list(_trace_state.get("services", [])), "level": _trace_state.get("level", LOG_TRACE_DEBUG_LEVEL), }, } def current_output_config(receiver_host: str) -> dict[str, Any]: return { "name": "tcp", "match": LOG_FLUENTBIT_MATCH, "host": receiver_host, "port": LOG_RECEIVER_PORT, "format": LOG_RECEIVER_FORMAT, } def default_input_config() -> dict[str, Any]: return { "name": "systemd", "path": "/var/log/journal", "tag": "marvis.systemd", "read_from_tail": "on", "strip_underscores": "off", } async def _resolve_receiver_host() -> str: if LOG_RECEIVER_HOST: return LOG_RECEIVER_HOST cluster = await pls.get_cluster_status() if isinstance(cluster, dict): current_node = cluster.get("current_node") if isinstance(current_node, str) and current_node: return pls.node_host(current_node) system = await pls.get_system_info() if isinstance(system, dict) and system.get("hostname"): return str(system["hostname"]) return "127.0.0.1" def _merged_fluentbit_config(config: dict[str, Any], receiver_host: str) -> dict[str, Any]: merged = dict(config or {}) pipeline = dict(merged.get("pipeline") or {}) inputs = list(pipeline.get("inputs") or []) outputs = list(pipeline.get("outputs") or []) desired = current_output_config(receiver_host) if not inputs: inputs = [default_input_config()] filtered = [] for output in outputs: if not isinstance(output, dict): continue is_existing_marvis = ( output.get("name") == "tcp" and output.get("port") == LOG_RECEIVER_PORT and output.get("format") == LOG_RECEIVER_FORMAT ) if not is_existing_marvis: filtered.append(output) filtered.append(desired) pipeline["inputs"] = inputs pipeline["outputs"] = filtered merged["pipeline"] = pipeline if "parsers" not in merged: merged["parsers"] = list(config.get("parsers") or []) if isinstance(config, dict) else [] return merged async def configure_site_output() -> dict[str, Any]: current = await pls.get_fluentbit_config() if not isinstance(current, dict): raise RuntimeError("Could not read current Fluent Bit config from PLS") receiver_host = await _resolve_receiver_host() desired = _merged_fluentbit_config(current, receiver_host) updated = await pls.put_fluentbit_config(desired) if not isinstance(updated, dict): raise RuntimeError("PLS rejected Fluent Bit config update") return { "receiver_host": receiver_host, "receiver_port": LOG_RECEIVER_PORT, "match": LOG_FLUENTBIT_MATCH, "config": updated, } def _sort_and_limit(events: list[dict[str, Any]], limit: int | None = None) -> list[dict[str, Any]]: deduped: dict[str, dict[str, Any]] = {} for event in events: deduped[event.get("id", str(id(event)))] = event ordered = sorted(deduped.values(), key=lambda event: event.get("epoch", 0.0)) if limit is not None: return ordered[-limit:] return ordered def get_process_events(nf: str, limit: int | None = None) -> list[dict[str, Any]]: nf_key = str(nf or "").upper() events = list(_process_events.get(nf_key, [])) return _sort_and_limit(events, limit) def get_subscriber_events(supi_or_fragment: str, limit: int | None = None) -> list[dict[str, Any]]: normalized = _normalize_supi(supi_or_fragment) fragment = str(supi_or_fragment or "").strip().lower() if not normalized and not fragment: return [] matches: list[dict[str, Any]] = [] for supi, events in _subscriber_events.items(): digits = supi.removeprefix("imsi-") if normalized and (supi == normalized or normalized in supi or normalized.removeprefix("imsi-") in digits): matches.extend(events) continue if fragment and (fragment in supi.lower() or fragment in digits): matches.extend(events) return _sort_and_limit(matches, limit) async def _trace_target_nodes() -> list[dict[str, Any]]: cluster = await pls.get_cluster_status() nodes = [] if isinstance(cluster, dict): for node in cluster.get("nodes", []): host = pls.node_host(node.get("name", "")) if host: nodes.append({"name": node.get("name", ""), "host": host}) if not nodes: system = await pls.get_system_info() host = str(system.get("hostname", "") if isinstance(system, dict) else "") or "127.0.0.1" nodes.append({"name": host, "host": host}) deduped = {} for node in nodes: deduped[node["host"]] = node return list(deduped.values()) async def start_subscriber_trace(supi_or_fragment: str) -> dict[str, Any]: normalized = _normalize_supi(supi_or_fragment) fragment = str(supi_or_fragment or "").strip() if not normalized and not fragment: raise RuntimeError("A SUPI or SUPI fragment is required to start a trace") if _trace_state.get("active"): await stop_subscriber_trace() target_nodes = await _trace_target_nodes() original_levels: dict[str, dict[str, Any]] = {} applied_nodes: list[str] = [] for node in target_nodes: host = node["host"] current = await pls.get_log_config(host=host) if not isinstance(current, dict): continue original_levels[host] = current updated = dict(current) updated["level"] = LOG_TRACE_DEBUG_LEVEL await pls.put_log_config(updated, host=host) applied_nodes.append(host) _trace_state.update( { "active": True, "filter": fragment, "normalized": normalized or fragment.lower(), "started_at": datetime.now(UTC).isoformat(), "matched_events": 0, "nodes": applied_nodes, "services": list(LOG_TRACE_TARGET_SERVICES), "level": LOG_TRACE_DEBUG_LEVEL, "original_levels": original_levels, } ) return receiver_status()["trace"] async def stop_subscriber_trace() -> dict[str, Any]: original_levels = dict(_trace_state.get("original_levels", {})) restored_nodes: list[str] = [] for host, config in original_levels.items(): try: if isinstance(config, dict): await pls.put_log_config(config, host=host) restored_nodes.append(host) except Exception: continue summary = { "filter": _trace_state.get("filter", ""), "started_at": _trace_state.get("started_at"), "matched_events": _trace_state.get("matched_events", 0), "restored_nodes": restored_nodes, } _trace_state.update( { "active": False, "filter": "", "normalized": "", "started_at": None, "matched_events": 0, "nodes": [], "services": list(LOG_TRACE_TARGET_SERVICES), "level": LOG_TRACE_DEBUG_LEVEL, "original_levels": {}, } ) return summary def get_events(limit: int | None = None, node: str | None = None, nf: str | None = None, imsi: str | None = None) -> list[dict[str, Any]]: if imsi: events = get_subscriber_events(imsi, limit=None) elif nf: events = get_process_events(nf, limit=None) else: events = list(_events) if node: node_l = node.lower() events = [event for event in events if event.get("node", "").lower() == node_l] if nf: nf_u = nf.upper() events = [event for event in events if event.get("nf", "").upper() == nf_u] if imsi: needle = str(imsi).strip().lower() normalized = _normalize_supi(imsi) digits = normalized.removeprefix("imsi-") if normalized else "" events = [ event for event in events if needle and ( needle in event.get("message", "").lower() or any( needle in supi.lower() or (digits and digits in supi.removeprefix("imsi-")) for supi in event.get("supis", []) ) ) ] return _sort_and_limit(events, limit) def record_alert_context( *, category: str, nf: str, node: str, severity: str, description: str, remediation: str, source: str, event: dict[str, Any], before_context: list[dict[str, Any]], after_context: list[dict[str, Any]], ) -> str: _ensure_db() fingerprint = sha1( "|".join( [ category, nf, node, severity, description, remediation, event.get("timestamp", ""), event.get("message", ""), ] ).encode("utf-8") ).hexdigest() alert_id = sha1(f"{fingerprint}|{source}".encode("utf-8")).hexdigest() conn = sqlite3.connect(_db_path()) try: conn.execute( """ INSERT OR REPLACE INTO alert_context ( id, fingerprint, created_at, event_ts, category, nf, node, severity, description, remediation, source, match_message, before_context, after_context ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( alert_id, fingerprint, datetime.now(UTC).isoformat(), event.get("timestamp", ""), category, nf, node, severity, description, remediation, source, event.get("message", ""), json.dumps(before_context), json.dumps(after_context), ), ) _trim_db(conn) conn.commit() finally: conn.close() return alert_id def recent_alert_context(limit: int = 20) -> list[dict[str, Any]]: _ensure_db() conn = sqlite3.connect(_db_path()) conn.row_factory = sqlite3.Row try: rows = conn.execute( """ SELECT id, created_at, event_ts, category, nf, node, severity, description, remediation, source, match_message, before_context, after_context FROM alert_context ORDER BY event_ts DESC, created_at DESC LIMIT ? """, (limit,), ).fetchall() return [dict(row) for row in rows] finally: conn.close() def known_nfs() -> list[str]: return list(ALL_NFS)