# p5g-marvis/app/services/log_ingest.py
"""Fluent Bit log ingestion, buffering, and alert-context persistence."""
from __future__ import annotations
import asyncio
import json
import re
import sqlite3
from collections import deque
from datetime import UTC, datetime
from hashlib import sha1
from pathlib import Path
from typing import Any

from app.config import (
    ALL_NFS,
    LOG_ALERT_CONTEXT_AFTER,
    LOG_ALERT_CONTEXT_BEFORE,
    LOG_ALERT_CONTEXT_DB_MAX_ROWS,
    LOG_ALERT_CONTEXT_DB_PATH,
    LOG_ALLOWED_NFS,
    LOG_AUTO_CONFIGURE,
    LOG_BUFFER_LINES,
    LOG_FLUENTBIT_MATCH,
    LOG_INGEST_ENABLED,
    LOG_PROCESS_BUFFER_LINES,
    LOG_RECEIVER_BIND_HOST,
    LOG_RECEIVER_FORMAT,
    LOG_RECEIVER_HOST,
    LOG_RECEIVER_PORT,
    LOG_SUBSCRIBER_BUFFER_LINES,
    LOG_TRACE_BUFFER_LINES,
    LOG_TRACE_DEBUG_LEVEL,
    LOG_TRACE_TARGET_SERVICES,
)
from app.services import pls

_server: asyncio.Server | None = None
_allowed_nfs = {nf.upper() for nf in LOG_ALLOWED_NFS}
_events: deque[dict[str, Any]] = deque(maxlen=max(LOG_BUFFER_LINES, 1))
_trace_events: deque[dict[str, Any]] = deque(
    maxlen=max(LOG_TRACE_BUFFER_LINES, LOG_BUFFER_LINES, 1)
)
_process_events: dict[str, deque[dict[str, Any]]] = {
    nf: deque(maxlen=max(LOG_PROCESS_BUFFER_LINES, 1))
    for nf in _allowed_nfs
    if nf != "SYSTEM"
}
_subscriber_events: dict[str, deque[dict[str, Any]]] = {}
_ingested_total = 0
_parse_errors = 0
_last_event_at: str | None = None
_db_initialized = False
_supi_pattern = re.compile(r"(imsi-\d{6,20}|\b\d{6,20}\b)", re.IGNORECASE)
_trace_state: dict[str, Any] = {
    "active": False,
    "filter": "",
    "normalized": "",
    "started_at": None,
    "matched_events": 0,
    "nodes": [],
    "services": list(LOG_TRACE_TARGET_SERVICES),
    "level": LOG_TRACE_DEBUG_LEVEL,
    "original_levels": {},
}


def _db_path() -> Path:
    return Path(LOG_ALERT_CONTEXT_DB_PATH)


def _ensure_db() -> None:
    global _db_initialized
    if _db_initialized:
        return
    path = _db_path()
    path.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(path)
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS alert_context (
                id TEXT PRIMARY KEY,
                fingerprint TEXT UNIQUE,
                created_at TEXT NOT NULL,
                event_ts TEXT NOT NULL,
                category TEXT NOT NULL,
                nf TEXT,
                node TEXT,
                severity TEXT,
                description TEXT,
                remediation TEXT,
                source TEXT,
                match_message TEXT,
                before_context TEXT,
                after_context TEXT
            )
            """
        )
        conn.commit()
    finally:
        conn.close()
    _db_initialized = True


def _trim_db(conn: sqlite3.Connection) -> None:
    conn.execute(
        """
        DELETE FROM alert_context
        WHERE id NOT IN (
            SELECT id
            FROM alert_context
            ORDER BY event_ts DESC, created_at DESC
            LIMIT ?
        )
        """,
        (max(LOG_ALERT_CONTEXT_DB_MAX_ROWS, 1),),
    )


def _parse_timestamp(value: Any) -> tuple[float, str]:
    if value is None:
        now = datetime.now(UTC)
        return now.timestamp(), now.isoformat()
    if isinstance(value, (int, float)):
        raw = float(value)
        # Heuristic unit detection for epoch values: microsecond timestamps
        # (journald's _SOURCE_REALTIME_TIMESTAMP) are ~1e15 today, milliseconds
        # ~1e12, and seconds ~1e9, so the cutoffs sit between those magnitudes.
        if raw > 100_000_000_000_000:
            raw = raw / 1_000_000.0
        elif raw > 100_000_000_000:
            raw = raw / 1000.0
        dt = datetime.fromtimestamp(raw, UTC)
        return raw, dt.isoformat()
    text = str(value).strip()
    if text.isdigit():
        return _parse_timestamp(int(text))
    normalized = text.replace("Z", "+00:00")
    for candidate in (normalized, normalized.replace(" ", "T")):
        try:
            dt = datetime.fromisoformat(candidate)
            if dt.tzinfo is None:
                dt = dt.replace(tzinfo=UTC)
            else:
                dt = dt.astimezone(UTC)
            return dt.timestamp(), dt.isoformat()
        except ValueError:
            continue
    now = datetime.now(UTC)
    return now.timestamp(), now.isoformat()
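
# A minimal sketch of the unit heuristic above (illustrative, not executed):
# _parse_timestamp() accepts epoch seconds, milliseconds, microseconds, and
# ISO-8601 strings, and always returns a UTC (epoch_seconds, iso_string) pair:
#
#     _parse_timestamp(1_714_225_369)          # seconds -> used as-is
#     _parse_timestamp(1_714_225_369_000)      # milliseconds -> / 1e3
#     _parse_timestamp("1714225369000000")     # microseconds (journald) -> / 1e6
#     _parse_timestamp("2026-04-27 13:42:49-04:00")
#         # -> (epoch, "2026-04-27T17:42:49+00:00")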


def _candidate_fields(payload: dict[str, Any]) -> list[str]:
    candidates = []
    for key in (
        "message",
        "MESSAGE",
        "log",
        "msg",
        "systemd_unit",
        "_SYSTEMD_UNIT",
        "syslog_identifier",
        "SYSLOG_IDENTIFIER",
        "_COMM",
        "comm",
        "_EXE",
        "container_name",
        "tag",
    ):
        value = payload.get(key)
        if value not in (None, ""):
            candidates.append(str(value))
    return candidates


def _infer_nf(payload: dict[str, Any], message: str) -> str:
    haystack = " ".join(_candidate_fields(payload) + [message]).lower()
    aliases = {
        "upf": "UPF",
        "amf": "AMF",
        "smf": "SMF",
        "udm": "UDM",
        "udr": "UDR",
        "nrf": "NRF",
        "ausf": "AUSF",
        "pcf": "PCF",
        "mme": "MME",
        "sgwc": "SGWC",
        "dra": "DRA",
        "dsm": "DSM",
        "aaa": "AAA",
        "bmsc": "BMSC",
        "chf": "CHF",
        "smsf": "SMSF",
        "eir": "EIR",
        "licensed": "LICENSED",
        "prometheus": "PROMETHEUS",
        "alertmanager": "ALERTMANAGER",
        "fluent-bit": "FLUENT-BIT",
    }
    for needle, label in aliases.items():
        if needle in haystack:
            return label
    return "SYSTEM"


def _normalize_supi(value: str | None) -> str:
    if not value:
        return ""
    text = str(value).strip().lower()
    if not text:
        return ""
    if text.startswith("imsi-"):
        digits = "".join(ch for ch in text[5:] if ch.isdigit())
        return f"imsi-{digits}" if digits else text
    digits = "".join(ch for ch in text if ch.isdigit())
    if digits:
        return f"imsi-{digits}"
    return text


def _extract_supis(message: str) -> list[str]:
    matches = []
    for raw in _supi_pattern.findall(message or ""):
        normalized = _normalize_supi(raw)
        if normalized and normalized not in matches:
            matches.append(normalized)
    return matches
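
# Illustrative examples of SUPI canonicalization (values are hypothetical
# test-PLMN IMSIs, not from real traffic):
#
#     _normalize_supi("IMSI-001010000000001")             # -> "imsi-001010000000001"
#     _normalize_supi("001010000000001")                  # -> "imsi-001010000000001"
#     _extract_supis("UE imsi-001010000000001 attached")  # -> ["imsi-001010000000001"]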


def _matches_trace(event: dict[str, Any]) -> bool:
    if not _trace_state.get("active"):
        return False
    normalized = _trace_state.get("normalized", "")
    if not normalized:
        return False
    message = str(event.get("message", "")).lower()
    if normalized in message:
        return True
    digits = normalized.removeprefix("imsi-")
    return bool(digits and digits in message)


def _normalize_event(payload: dict[str, Any], remote_host: str) -> dict[str, Any]:
    ts_value = (
        payload.get("timestamp")
        or payload.get("@timestamp")
        or payload.get("time")
        or payload.get("date")
        or payload.get("_SOURCE_REALTIME_TIMESTAMP")
    )
    epoch, ts_iso = _parse_timestamp(ts_value)
    node = (
        payload.get("hostname")
        or payload.get("host")
        or payload.get("_HOSTNAME")
        or payload.get("syslog_hostname")
        or remote_host
    )
    source = (
        payload.get("systemd_unit")
        or payload.get("_SYSTEMD_UNIT")
        or payload.get("syslog_identifier")
        or payload.get("SYSLOG_IDENTIFIER")
        or payload.get("_COMM")
        or payload.get("tag")
        or "unknown"
    )
    message = (
        payload.get("message")
        or payload.get("MESSAGE")
        or payload.get("log")
        or payload.get("msg")
        or ""
    )
    message = str(message).strip()
    supis = _extract_supis(message)
    tag = str(payload.get("tag", ""))
    nf = _infer_nf(payload, message)
    fingerprint = sha1(f"{ts_iso}|{node}|{nf}|{source}|{message}".encode("utf-8")).hexdigest()
    return {
        "id": fingerprint,
        "timestamp": ts_iso,
        "epoch": epoch,
        "node": str(node),
        "nf": nf,
        "source": str(source),
        "tag": tag,
        "message": message,
        "supis": supis,
        "raw": payload,
    }
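
# Illustrative normalization of a forwarded journald record (field values are
# hypothetical; NF inference relies on the substring aliases above):
#
#     {"_HOSTNAME": "core-1", "_SYSTEMD_UNIT": "open5gs-amfd.service",
#      "MESSAGE": "imsi-001010000000001 Registration complete",
#      "_SOURCE_REALTIME_TIMESTAMP": "1714225369000000"}
#
# becomes an event with node="core-1", source="open5gs-amfd.service", nf="AMF",
# supis=["imsi-001010000000001"], and the sha1 of
# "timestamp|node|nf|source|message" as its deduplication id.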


async def _ingest_payload(payload: dict[str, Any], remote_host: str) -> None:
    global _ingested_total, _last_event_at
    event = _normalize_event(payload, remote_host)
    if event.get("nf", "").upper() not in _allowed_nfs:
        return
    _events.append(event)
    _trace_events.append(event)
    nf_key = event.get("nf", "").upper()
    if nf_key:
        _process_events.setdefault(
            nf_key, deque(maxlen=max(LOG_PROCESS_BUFFER_LINES, 1))
        ).append(event)
    for supi in event.get("supis", []):
        _subscriber_events.setdefault(
            supi,
            deque(maxlen=max(LOG_SUBSCRIBER_BUFFER_LINES, 1)),
        ).append(event)
    if _matches_trace(event):
        _trace_state["matched_events"] = int(_trace_state.get("matched_events", 0)) + 1
    _ingested_total += 1
    _last_event_at = event["timestamp"]


async def _handle_client(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
    global _parse_errors
    peer = writer.get_extra_info("peername")
    remote_host = peer[0] if isinstance(peer, tuple) and peer else "unknown"
    try:
        while not reader.at_eof():
            line = await reader.readline()
            if not line:
                break
            text = line.decode("utf-8", errors="replace").strip()
            if not text:
                continue
            try:
                payload = json.loads(text)
                if isinstance(payload, dict):
                    await _ingest_payload(payload, remote_host)
                elif isinstance(payload, list):
                    for item in payload:
                        if isinstance(item, dict):
                            await _ingest_payload(item, remote_host)
            except Exception:
                _parse_errors += 1
    finally:
        writer.close()
        await writer.wait_closed()
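
# The receiver expects newline-delimited JSON, which is what a Fluent Bit tcp
# output configured with format "json_lines" emits -- one object (or array of
# objects) per line. A line that fails json.loads only bumps _parse_errors;
# it never tears down the connection.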


async def startup() -> None:
    global _server
    _ensure_db()
    if not LOG_INGEST_ENABLED or _server is not None:
        return
    _server = await asyncio.start_server(_handle_client, LOG_RECEIVER_BIND_HOST, LOG_RECEIVER_PORT)
    if LOG_AUTO_CONFIGURE:
        try:
            await configure_site_output()
        except Exception:
            # Best-effort: the receiver still runs even if Fluent Bit could not
            # be configured automatically (e.g. PLS unreachable at boot).
            pass


async def shutdown() -> None:
    global _server
    if _server is None:
        return
    _server.close()
    await _server.wait_closed()
    _server = None


def receiver_status() -> dict[str, Any]:
    return {
        "enabled": LOG_INGEST_ENABLED,
        "bind_host": LOG_RECEIVER_BIND_HOST,
        "receiver_host": LOG_RECEIVER_HOST,
        "port": LOG_RECEIVER_PORT,
        "format": LOG_RECEIVER_FORMAT,
        "allowed_nfs": sorted(_allowed_nfs),
        "buffer_lines": LOG_BUFFER_LINES,
        "process_buffer_lines": LOG_PROCESS_BUFFER_LINES,
        "subscriber_buffer_lines": LOG_SUBSCRIBER_BUFFER_LINES,
        "trace_buffer_lines": LOG_TRACE_BUFFER_LINES,
        "context_before": LOG_ALERT_CONTEXT_BEFORE,
        "context_after": LOG_ALERT_CONTEXT_AFTER,
        "db_path": str(_db_path()),
        "ingested_total": _ingested_total,
        "parse_errors": _parse_errors,
        "last_event_at": _last_event_at,
        "current_buffer_size": len(_events),
        "process_buffers": sorted(_process_events.keys()),
        "subscriber_buffers": len(_subscriber_events),
        "trace": {
            "active": bool(_trace_state.get("active")),
            "filter": _trace_state.get("filter", ""),
            "started_at": _trace_state.get("started_at"),
            "matched_events": _trace_state.get("matched_events", 0),
            "nodes": list(_trace_state.get("nodes", [])),
            "services": list(_trace_state.get("services", [])),
            "level": _trace_state.get("level", LOG_TRACE_DEBUG_LEVEL),
        },
    }


def current_output_config(receiver_host: str) -> dict[str, Any]:
    return {
        "name": "tcp",
        "match": LOG_FLUENTBIT_MATCH,
        "host": receiver_host,
        "port": LOG_RECEIVER_PORT,
        "format": LOG_RECEIVER_FORMAT,
    }


def default_input_config() -> dict[str, Any]:
    return {
        "name": "systemd",
        "path": "/var/log/journal",
        "tag": "marvis.systemd",
        "read_from_tail": "on",
        "strip_underscores": "off",
    }
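
# For reference, the two dicts above correspond roughly to this Fluent Bit
# YAML pipeline (a sketch; assumes LOG_FLUENTBIT_MATCH="marvis.*" and the
# module defaults for host/port/format):
#
#     pipeline:
#       inputs:
#         - name: systemd
#           path: /var/log/journal
#           tag: marvis.systemd
#           read_from_tail: on
#           strip_underscores: off
#       outputs:
#         - name: tcp
#           match: "marvis.*"
#           host: <receiver_host>
#           port: <LOG_RECEIVER_PORT>
#           format: json_lines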


async def _resolve_receiver_host() -> str:
    if LOG_RECEIVER_HOST:
        return LOG_RECEIVER_HOST
    cluster = await pls.get_cluster_status()
    if isinstance(cluster, dict):
        current_node = cluster.get("current_node")
        if isinstance(current_node, str) and current_node:
            return pls.node_host(current_node)
    system = await pls.get_system_info()
    if isinstance(system, dict) and system.get("hostname"):
        return str(system["hostname"])
    return "127.0.0.1"


def _merged_fluentbit_config(config: dict[str, Any], receiver_host: str) -> dict[str, Any]:
    merged = dict(config or {})
    pipeline = dict(merged.get("pipeline") or {})
    inputs = list(pipeline.get("inputs") or [])
    outputs = list(pipeline.get("outputs") or [])
    desired = current_output_config(receiver_host)
    if not inputs:
        inputs = [default_input_config()]
    # Drop any previous Marvis tcp output (same port/format) so repeated
    # configuration stays idempotent, then append the current one.
    filtered = []
    for output in outputs:
        if not isinstance(output, dict):
            continue
        is_existing_marvis = (
            output.get("name") == "tcp"
            and output.get("port") == LOG_RECEIVER_PORT
            and output.get("format") == LOG_RECEIVER_FORMAT
        )
        if not is_existing_marvis:
            filtered.append(output)
    filtered.append(desired)
    pipeline["inputs"] = inputs
    pipeline["outputs"] = filtered
    merged["pipeline"] = pipeline
    merged.setdefault("parsers", [])
    return merged


async def configure_site_output() -> dict[str, Any]:
    current = await pls.get_fluentbit_config()
    if not isinstance(current, dict):
        raise RuntimeError("Could not read current Fluent Bit config from PLS")
    receiver_host = await _resolve_receiver_host()
    desired = _merged_fluentbit_config(current, receiver_host)
    updated = await pls.put_fluentbit_config(desired)
    if not isinstance(updated, dict):
        raise RuntimeError("PLS rejected Fluent Bit config update")
    return {
        "receiver_host": receiver_host,
        "receiver_port": LOG_RECEIVER_PORT,
        "match": LOG_FLUENTBIT_MATCH,
        "config": updated,
    }
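
# startup() calls configure_site_output() automatically when LOG_AUTO_CONFIGURE
# is set; an operator-facing route could also trigger it on demand, e.g.
# (hypothetical call site):
#
#     result = await log_ingest.configure_site_output()
#     logger.info("forwarding to %s:%s", result["receiver_host"], result["receiver_port"])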


def _sort_and_limit(events: list[dict[str, Any]], limit: int | None = None) -> list[dict[str, Any]]:
    deduped: dict[str, dict[str, Any]] = {}
    for event in events:
        deduped[event.get("id", str(id(event)))] = event
    ordered = sorted(deduped.values(), key=lambda event: event.get("epoch", 0.0))
    if limit is not None:
        return ordered[-limit:]
    return ordered


def get_process_events(nf: str, limit: int | None = None) -> list[dict[str, Any]]:
    nf_key = str(nf or "").upper()
    events = list(_process_events.get(nf_key, []))
    return _sort_and_limit(events, limit)


def get_subscriber_events(supi_or_fragment: str, limit: int | None = None) -> list[dict[str, Any]]:
    normalized = _normalize_supi(supi_or_fragment)
    fragment = str(supi_or_fragment or "").strip().lower()
    if not normalized and not fragment:
        return []
    matches: list[dict[str, Any]] = []
    for supi, events in _subscriber_events.items():
        digits = supi.removeprefix("imsi-")
        if normalized and (
            supi == normalized
            or normalized in supi
            or normalized.removeprefix("imsi-") in digits
        ):
            matches.extend(events)
            continue
        if fragment and (fragment in supi.lower() or fragment in digits):
            matches.extend(events)
    return _sort_and_limit(matches, limit)
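
# Illustrative lookups (IMSIs hypothetical): both canonical SUPIs and bare
# digit fragments resolve against the per-subscriber buffers:
#
#     get_process_events("amf", limit=100)
#     get_subscriber_events("imsi-001010000000001", limit=50)
#     get_subscriber_events("0000001")   # fragment, matched against the digits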


async def _trace_target_nodes() -> list[dict[str, Any]]:
    cluster = await pls.get_cluster_status()
    nodes = []
    if isinstance(cluster, dict):
        for node in cluster.get("nodes", []):
            host = pls.node_host(node.get("name", ""))
            if host:
                nodes.append({"name": node.get("name", ""), "host": host})
    if not nodes:
        system = await pls.get_system_info()
        host = str(system.get("hostname", "") if isinstance(system, dict) else "") or "127.0.0.1"
        nodes.append({"name": host, "host": host})
    deduped = {}
    for node in nodes:
        deduped[node["host"]] = node
    return list(deduped.values())


async def start_subscriber_trace(supi_or_fragment: str) -> dict[str, Any]:
    normalized = _normalize_supi(supi_or_fragment)
    fragment = str(supi_or_fragment or "").strip()
    if not normalized and not fragment:
        raise RuntimeError("A SUPI or SUPI fragment is required to start a trace")
    if _trace_state.get("active"):
        await stop_subscriber_trace()
    target_nodes = await _trace_target_nodes()
    original_levels: dict[str, dict[str, Any]] = {}
    applied_nodes: list[str] = []
    for node in target_nodes:
        host = node["host"]
        current = await pls.get_log_config(host=host)
        if not isinstance(current, dict):
            continue
        original_levels[host] = current
        updated = dict(current)
        updated["level"] = LOG_TRACE_DEBUG_LEVEL
        await pls.put_log_config(updated, host=host)
        applied_nodes.append(host)
    _trace_state.update(
        {
            "active": True,
            "filter": fragment,
            "normalized": normalized or fragment.lower(),
            "started_at": datetime.now(UTC).isoformat(),
            "matched_events": 0,
            "nodes": applied_nodes,
            "services": list(LOG_TRACE_TARGET_SERVICES),
            "level": LOG_TRACE_DEBUG_LEVEL,
            "original_levels": original_levels,
        }
    )
    return receiver_status()["trace"]


async def stop_subscriber_trace() -> dict[str, Any]:
    original_levels = dict(_trace_state.get("original_levels", {}))
    restored_nodes: list[str] = []
    for host, config in original_levels.items():
        try:
            if isinstance(config, dict):
                await pls.put_log_config(config, host=host)
                restored_nodes.append(host)
        except Exception:
            continue
    summary = {
        "filter": _trace_state.get("filter", ""),
        "started_at": _trace_state.get("started_at"),
        "matched_events": _trace_state.get("matched_events", 0),
        "restored_nodes": restored_nodes,
    }
    _trace_state.update(
        {
            "active": False,
            "filter": "",
            "normalized": "",
            "started_at": None,
            "matched_events": 0,
            "nodes": [],
            "services": list(LOG_TRACE_TARGET_SERVICES),
            "level": LOG_TRACE_DEBUG_LEVEL,
            "original_levels": {},
        }
    )
    return summary
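
# Trace lifecycle, as wired above (illustrative IMSI): start raises the log
# level on every reachable node via PLS and snapshots the prior config; stop
# restores whatever was captured at start:
#
#     status = await start_subscriber_trace("001010000000001")
#     ...  # matched_events grows in receiver_status()["trace"] as logs arrive
#     summary = await stop_subscriber_trace()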


def get_events(
    limit: int | None = None,
    node: str | None = None,
    nf: str | None = None,
    imsi: str | None = None,
) -> list[dict[str, Any]]:
    if imsi:
        events = get_subscriber_events(imsi, limit=None)
    elif nf:
        events = get_process_events(nf, limit=None)
    else:
        events = list(_events)
    if node:
        node_l = node.lower()
        events = [event for event in events if event.get("node", "").lower() == node_l]
    if nf:
        nf_u = nf.upper()
        events = [event for event in events if event.get("nf", "").upper() == nf_u]
    if imsi:
        needle = str(imsi).strip().lower()
        normalized = _normalize_supi(imsi)
        digits = normalized.removeprefix("imsi-") if normalized else ""
        events = [
            event for event in events
            if needle and (
                needle in event.get("message", "").lower()
                or any(
                    needle in supi.lower() or (digits and digits in supi.removeprefix("imsi-"))
                    for supi in event.get("supis", [])
                )
            )
        ]
    return _sort_and_limit(events, limit)


def record_alert_context(
    *,
    category: str,
    nf: str,
    node: str,
    severity: str,
    description: str,
    remediation: str,
    source: str,
    event: dict[str, Any],
    before_context: list[dict[str, Any]],
    after_context: list[dict[str, Any]],
) -> str:
    _ensure_db()
    fingerprint = sha1(
        "|".join(
            [
                category,
                nf,
                node,
                severity,
                description,
                remediation,
                event.get("timestamp", ""),
                event.get("message", ""),
            ]
        ).encode("utf-8")
    ).hexdigest()
    alert_id = sha1(f"{fingerprint}|{source}".encode("utf-8")).hexdigest()
    conn = sqlite3.connect(_db_path())
    try:
        conn.execute(
            """
            INSERT OR REPLACE INTO alert_context (
                id, fingerprint, created_at, event_ts, category, nf, node, severity,
                description, remediation, source, match_message, before_context, after_context
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
            """,
            (
                alert_id,
                fingerprint,
                datetime.now(UTC).isoformat(),
                event.get("timestamp", ""),
                category,
                nf,
                node,
                severity,
                description,
                remediation,
                source,
                event.get("message", ""),
                json.dumps(before_context),
                json.dumps(after_context),
            ),
        )
        _trim_db(conn)
        conn.commit()
    finally:
        conn.close()
    return alert_id
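
# Illustrative call (argument values hypothetical); callers typically pass the
# matching event plus the LOG_ALERT_CONTEXT_BEFORE/AFTER lines captured around it:
#
#     alert_id = record_alert_context(
#         category="crash", nf="AMF", node="core-1", severity="critical",
#         description="AMF restarted", remediation="Inspect the core dump",
#         source="log-watch", event=event,
#         before_context=before_lines, after_context=after_lines,
#     )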


def recent_alert_context(limit: int = 20) -> list[dict[str, Any]]:
    _ensure_db()
    conn = sqlite3.connect(_db_path())
    conn.row_factory = sqlite3.Row
    try:
        rows = conn.execute(
            """
            SELECT id, created_at, event_ts, category, nf, node, severity, description,
                   remediation, source, match_message, before_context, after_context
            FROM alert_context
            ORDER BY event_ts DESC, created_at DESC
            LIMIT ?
            """,
            (limit,),
        ).fetchall()
        return [dict(row) for row in rows]
    finally:
        conn.close()


def known_nfs() -> list[str]:
    return list(ALL_NFS)