Files
p5g-marvis/app/services/ai.py
2026-04-27 13:42:49 -04:00

500 lines
21 KiB
Python
Raw Blame History

This file contains invisible Unicode characters
This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
AI engine for P5G Marvis.
Phase 1: rule-based with real network data.
Phase 2: swap MARVIS_AI_MODE=openai or MARVIS_AI_MODE=ollama to route through LLM.
"""
from datetime import datetime
import re
from app.config import (
AI_MODE,
ALL_NFS,
CONTAINER_RUNTIME,
OPENAI_API_KEY,
OPENAI_MODEL,
OPENAI_BASE_URL,
OLLAMA_MODEL,
OLLAMA_URL,
)
async def answer(query: str, network_state: dict, alerts: list, logs: list[dict] | None = None) -> str:
    """Entry point: route a user query to log handlers, an LLM backend, or rules.

    Log/trace intents are checked first and short-circuit every backend; the
    AI_MODE setting then picks OpenAI, Ollama, or the rule-based fallback.
    """
    events = logs or []
    handled = await _handle_log_queries(query, network_state, alerts, events)
    if handled:
        return handled
    backend = {"openai": _call_openai, "ollama": _call_ollama}.get(AI_MODE)
    if backend is not None:
        return await backend(query, network_state, alerts, events)
    return _rule_based(query, network_state, alerts, events)
# ── Rule-based engine ──────────────────────────────────────────────────────
def _rule_based(query: str, network_state: dict, alerts: list, logs: list[dict]) -> str:
    """Answer *query* by keyword dispatch over live NF, alert, and log state."""
    lowered = query.lower()
    functions = network_state.get("nfs", [])
    cluster = network_state.get("cluster", {})
    running = [item for item in functions if item["state"] == "up"]
    offline = [item for item in functions if item["state"] == "down"]
    hits = _find_log_hits(lowered, logs)

    def mentions(*words: str) -> bool:
        # Plain substring heuristics, matched against the lowercased query.
        return any(word in lowered for word in words)

    if mentions("hello", "hi ", "hey", "howdy"):
        return ("Hello! I'm **P5G Marvis**, your AI network assistant for HPE Private 5G.\n"
                "Ask me about network health, specific functions, alerts, or performance.")
    if mentions("help", "what can", "capabilities", "commands", "features"):
        return (
            "Here's what I can help with:\n\n"
            "• **Network health** — overall P5G status overview\n"
            "• **Network functions** — ask about AMF, SMF, UPF, UDM, NRF, etc.\n"
            "• **Alerts** — active alarms and their severity\n"
            "• **Subscribers** — UE registration and session analysis\n"
            "• **Sessions** — PDU session and data plane health\n\n"
            "_Tip: Connect an LLM by setting `MARVIS_AI_MODE=openai` or `=ollama`._"
        )
    # A query that names a specific NF beats every broader topic.
    named = next((nf for nf in ALL_NFS if nf.lower() in lowered), None)
    if named is not None:
        return _nf_detail(named, functions, alerts, hits)
    if mentions("alert", "alarm", "warning", "critical", "incident", "problem", "issue"):
        return _alerts_summary(alerts)
    if mentions("log", "trace", "journal", "message", "error"):
        return _log_summary(hits, logs)
    if mentions("subscriber", "ue ", "device", "phone", "handset", "registration", "attach"):
        return _subscriber_analysis(functions, alerts, cluster, hits)
    if mentions("session", "pdu", "bearer", "user plane", "traffic", "throughput"):
        return _session_analysis(functions, alerts, cluster, hits)
    # Nothing matched — fall back to the overall health picture.
    return _health_summary(running, offline, alerts, cluster, hits)
async def _handle_log_queries(query: str, network_state: dict, alerts: list, logs: list[dict]) -> str | None:
    """Intercept log/trace intents before any AI backend runs.

    Ordering matters: stop-trace is checked before start-trace so that
    "stop trace imsi-..." is never misread as a new trace request. Returns a
    formatted answer string, or None when the query is not log-related.
    """
    # Imported here (not at module top) to avoid an import cycle with app.services.
    from app.services import log_analyzer, log_ingest
    q = query.strip()
    lowered = q.lower()
    # 1) Stop an active subscriber trace ("stop trace", "end trace", ...).
    if "trace" in lowered and any(word in lowered for word in ["stop", "end", "disable", "finish"]):
        summary = await log_ingest.stop_subscriber_trace()
        if not summary.get("started_at"):
            return " No subscriber trace is currently active."
        return (
            f"🛑 **Subscriber trace stopped** for `{summary.get('filter')}`\n\n"
            f"Started: {summary.get('started_at')}\n"
            f"Matched events: **{summary.get('matched_events', 0)}**\n"
            f"Restored nodes: {', '.join(summary.get('restored_nodes', [])) or 'none'}"
        )
    # 2) Start a trace when the query contains "trace" plus a SUPI/IMSI.
    trace_target = _extract_trace_target(q)
    if trace_target:
        state = await log_ingest.start_subscriber_trace(trace_target)
        events = log_ingest.get_subscriber_events(trace_target, limit=20)
        findings = log_analyzer.summarize_event_slice(events)
        return _format_trace_response(trace_target, state, events, findings)
    supi_query = _extract_supi_query(q)
    # Phrases that signal the user wants raw log lines back.
    asks_logs = any(
        phrase in lowered
        for phrase in ["show me the logs", "show logs", "logs for", "what do the logs show", "trace output", "recent logs"]
    )
    nf_query = _extract_nf_query(q)
    # 3) Subscriber-scoped log slice (bare SUPI, or SUPI + subscriber wording).
    if supi_query and (_is_bare_supi(q) or "subscriber" in lowered or "supi" in lowered or "imsi" in lowered or asks_logs):
        events = log_ingest.get_subscriber_events(supi_query, limit=500)
        findings = log_analyzer.summarize_event_slice(events)
        return _format_log_slice(
            title=f"Subscriber logs for `{supi_query}`",
            events=events,
            findings=findings,
            empty_message=f" No recent logs matched subscriber `{supi_query}`.",
        )
    # 4) Process/NF-scoped log slice.
    if nf_query and ("process" in lowered or asks_logs or "show me" in lowered):
        events = log_ingest.get_process_events(nf_query, limit=500)
        findings = log_analyzer.summarize_event_slice(events)
        return _format_log_slice(
            title=f"Process logs for `{nf_query}`",
            events=events,
            findings=findings,
            empty_message=f" No recent logs are buffered for process `{nf_query}`.",
        )
    # Not a log/trace question — let the caller pick an engine.
    return None
def _health_summary(up: list, down: list, alerts: list, cluster: dict, log_hits: list[dict]) -> str:
    """Build the overall health report: NF states, node placement, alerts, log hits."""
    stamp = datetime.now().strftime("%H:%M:%S")
    critical = [a for a in alerts if a.get("severity") == "critical"]
    warnings = [a for a in alerts if a.get("severity") != "critical"]
    members = cluster.get("nodes", [])
    report = [f"**P5G Network Health — {stamp}**\n"]
    if up:
        report.append(f"✅ **{len(up)} UP**: {', '.join(_nf_label(n) for n in up)}")
    if down:
        report.append(f"🔴 **{len(down)} DOWN**: {', '.join(_nf_label(n) for n in down)}")
        report.append(" ⚡ Action: inspect the node shown for each affected NF before pulling logs.")
    if members:
        report.append(f"\n**Cluster nodes ({len(members)})**")
        for node in members:
            active = [nf["name"] for nf in node.get("nfs", []) if nf.get("state") == "up"]
            inactive = [nf["name"] for nf in node.get("nfs", []) if nf.get("state") == "down"]
            role = node.get("role", "AP")
            suffix = ", local" if node.get("current") else ""
            report.append(
                f"• **{node['hostname']}** ({role}{suffix})"
                f" — running: {', '.join(active) or 'none'}"
            )
            if inactive:
                report.append(f" down here: {', '.join(inactive)}")
    if alerts:
        report.append(f"\n⚠️ **{len(alerts)} alert(s)** — {len(critical)} critical, {len(warnings)} warning")
        for alert in alerts[:4]:
            marker = "🔴" if alert.get("severity") == "critical" else "🟡"
            report.append(f" {marker} {alert['name']}: {alert.get('summary', alert.get('instance', ''))}")
    else:
        report.append("\n✅ **No active alerts**")
    if not down and not alerts:
        report.append("\n🟢 All systems nominal.")
    if log_hits:
        report.append(f"\n🧾 **Relevant log hits ({len(log_hits)})**")
        for hit in log_hits[:4]:
            report.append(
                f"{hit.get('timestamp','')}{hit.get('node','unknown')} {hit.get('nf','SYSTEM')}: "
                f"{_trim_message(hit.get('message',''))}"
            )
    return "\n".join(report)
def _nf_detail(nf_name: str, nfs: list, alerts: list, log_hits: list[dict]) -> str:
    """Status, node placement, related alerts, and log evidence for one NF."""
    record = next((entry for entry in nfs if entry["name"] == nf_name), None)
    related_alerts = [
        a for a in alerts
        if nf_name in a.get("name", "") or nf_name.lower() in a.get("instance", "").lower()
    ]
    related_logs = [hit for hit in log_hits if hit.get("nf") == nf_name]
    if not record or record["state"] == "unknown":
        return (f" No Prometheus data found for **{nf_name}**.\n"
                f"Check: `{CONTAINER_RUNTIME} ps | grep {nf_name.lower()}`")
    marker = "" if record["state"] == "up" else "🔴"
    placements = record.get("nodes", [])
    out = [f"{marker} **{nf_name}** is **{record['state'].upper()}**"]
    if placements:
        placed = ", ".join(
            f"{node['hostname']} ({'/'.join(node.get('roles', []))})"
            for node in placements
        )
        out.append(f"Nodes: {placed}")
    out.append(f"Instance: `{record.get('instance', 'n/a')}`")
    if related_alerts:
        out.append(f"\n⚠️ {len(related_alerts)} alert(s) for {nf_name}:")
        for alert in related_alerts:
            out.append(f"{alert['name']}: {alert.get('summary', '')}")
    else:
        out.append("No active alerts for this function.")
    if related_logs:
        out.append(f"\n🧾 Recent {nf_name} log evidence:")
        for hit in related_logs[:4]:
            out.append(
                f"{hit.get('timestamp','')} on {hit.get('node','unknown')}: "
                f"{_trim_message(hit.get('message',''))}"
            )
    return "\n".join(out)
def _alerts_summary(alerts: list) -> str:
if not alerts:
return "✅ **No active alerts.** Network is running cleanly."
crit = [a for a in alerts if a.get("severity") == "critical"]
warn = [a for a in alerts if a.get("severity") != "critical"]
lines = [f"⚠️ **{len(alerts)} active alert(s)** — {len(crit)} critical, {len(warn)} warning\n"]
for a in alerts:
icon = "🔴" if a.get("severity") == "critical" else "🟡"
lines.append(f"{icon} **{a['name']}**")
if a.get("summary"):
lines.append(f" {a['summary']}")
if a.get("instance"):
lines.append(f" `{a['instance']}`")
return "\n".join(lines)
def _subscriber_analysis(nfs: list, alerts: list, cluster: dict, log_hits: list[dict]) -> str:
    """Registration health: AMF/SMF status plus subscriber-flavoured alerts and logs."""
    amf_record = next((entry for entry in nfs if entry["name"] == "AMF"), None)
    smf_record = next((entry for entry in nfs if entry["name"] == "SMF"), None)
    out = ["**Subscriber & Registration Analysis**\n"]
    out.append(f"AMF (registration/mobility): {_nf_sentence(amf_record, 'subscribers cannot register')}")
    out.append(f"SMF (session management): {_nf_sentence(smf_record, 'no new data sessions')}")
    alert_terms = ("ue", "subscriber", "session", "attach", "registration")
    related = [a for a in alerts if any(term in a.get("name", "").lower() for term in alert_terms)]
    if related:
        out.append(f"\n⚠️ {len(related)} subscriber-related alert(s) active.")
    else:
        out.append("\nNo subscriber-related alerts detected.")
    log_terms = ("imsi", "supi", "registration", "attach", "subscriber")
    evidence = [hit for hit in log_hits if any(term in hit.get("message", "").lower() for term in log_terms)]
    if evidence:
        out.append("\nRecent subscriber-related log evidence:")
        for hit in evidence[:4]:
            out.append(
                f"{hit.get('nf','SYSTEM')} on {hit.get('node','unknown')}: {_trim_message(hit.get('message',''))}"
            )
    out.append(_cluster_scope(cluster))
    return "\n".join(out)
def _session_analysis(nfs: list, alerts: list, cluster: dict, log_hits: list[dict]) -> str:
    """Data-plane health: SMF/UPF status plus session-related log evidence."""
    smf_record = next((entry for entry in nfs if entry["name"] == "SMF"), None)
    upf_record = next((entry for entry in nfs if entry["name"] == "UPF"), None)
    out = ["**PDU Session & Data Plane Analysis**\n"]
    out.append(f"SMF: {_nf_sentence(smf_record, 'session setup is blocked')}")
    out.append(f"UPF: {_nf_sentence(upf_record, 'user-plane forwarding is blocked')}")
    # Both NFs must be up for PDU sessions to establish end-to-end.
    smf_up = smf_record is not None and smf_record["state"] == "up"
    upf_up = upf_record is not None and upf_record["state"] == "up"
    if smf_up and upf_up:
        out.append("\nBoth SMF and UPF operational — sessions should be establishing normally.")
    else:
        out.append("\n⚡ **Impact**: PDU sessions will fail until both SMF and UPF are operational.")
    evidence = [hit for hit in log_hits if hit.get("nf") in {"SMF", "UPF"}]
    if evidence:
        out.append("\nRecent session/data-plane log evidence:")
        for hit in evidence[:4]:
            out.append(
                f"{hit.get('nf','SYSTEM')} on {hit.get('node','unknown')}: {_trim_message(hit.get('message',''))}"
            )
    out.append(_cluster_scope(cluster))
    return "\n".join(out)
def _log_summary(log_hits: list[dict], logs: list[dict]) -> str:
    """Render matched log lines, or fall back to the newest ingested entry."""
    if not logs:
        return " No ingested logs are currently available."
    if log_hits:
        out = [f"🧾 **Relevant log matches ({len(log_hits)})**\n"]
        for hit in log_hits[:8]:
            out.append(
                f"{hit.get('timestamp','')}{hit.get('node','unknown')} {hit.get('nf','SYSTEM')}: "
                f"{_trim_message(hit.get('message',''))}"
            )
        return "\n".join(out)
    # No hits — surface the newest buffered event so the answer is still useful.
    newest = max(logs, key=lambda event: event.get("epoch", 0.0), default=None)
    if newest:
        return (
            " I do not see direct log matches for that question.\n\n"
            f"Latest ingested log: {newest.get('timestamp','')} on {newest.get('node','unknown')} "
            f"{newest.get('nf','SYSTEM')}{_trim_message(newest.get('message',''))}"
        )
    return " No relevant log matches were found."
def _extract_supi_query(query: str) -> str:
lowered = query.lower()
match = re.search(r"(imsi-\d{6,20}|\b\d{6,20}\b)", lowered)
if not match:
return ""
token = match.group(1)
if token.startswith("imsi-"):
return token
return f"imsi-{token}"
def _is_bare_supi(query: str) -> bool:
cleaned = query.strip().lower()
return bool(re.fullmatch(r"(imsi-\d{6,20}|\d{6,20})", cleaned))
def _extract_nf_query(query: str) -> str:
    """Return the first known NF name mentioned in the query, else ''."""
    upper = query.upper()
    return next((name for name in ALL_NFS if name in upper), "")
def _extract_trace_target(query: str) -> str:
    """Return the SUPI to trace when *query* is a trace request, else ''.

    Fix: the old second guard — ``any(word in lowered for word in
    ["start", "run", "begin", "trace"])`` — was dead code, because "trace"
    was in its own keyword list and the first check already guaranteed
    "trace" is present, making the condition always true. Behavior is
    unchanged: any query containing "trace" plus a SUPI starts a trace
    (explicit stop requests are intercepted earlier by _handle_log_queries).
    """
    if "trace" not in query.lower():
        return ""
    return _extract_supi_query(query)
def _format_log_slice(*, title: str, events: list[dict], findings: list[dict], empty_message: str) -> str:
    """Render a buffered log slice: rule findings first, then the newest lines."""
    if not events:
        return empty_message
    out = [f"🧾 **{title}**", f"Buffered lines: **{len(events)}**\n"]
    if findings:
        out.append("Rule hits:")
        for item in findings[:6]:
            out.append(
                f"• **{item['severity'].upper()}** {item['nf']} on {item.get('node','unknown')}: "
                f"{item['description']}"
            )
            out.append(f" Fix: {item['remediation']}")
        out.append("")
    out.append("Recent log lines:")
    for event in events[-12:]:
        out.append(
            f"{event.get('timestamp','')}{event.get('node','unknown')} {event.get('nf','SYSTEM')}: "
            f"{_trim_message(event.get('message',''), 220)}"
        )
    return "\n".join(out)
def _format_trace_response(target: str, state: dict, events: list[dict], findings: list[dict]) -> str:
    """Render live trace status: override level, rule findings, and trace lines."""
    out = [
        f"🔎 **Subscriber trace active** for `{target}`",
        f"Level override: **{state.get('level', 'debug')}**",
        f"Nodes updated: {', '.join(state.get('nodes', [])) or 'none'}",
        f"Matched events so far: **{state.get('matched_events', 0)}**\n",
    ]
    if findings:
        out.append("Current rule-based diagnosis:")
        for item in findings[:5]:
            out.append(
                f"• **{item['severity'].upper()}** {item['nf']} on {item.get('node','unknown')}: "
                f"{item['description']}"
            )
            out.append(f" Fix: {item['remediation']}")
        out.append("")
    if events:
        out.append("Current trace lines:")
        for event in events[-10:]:
            out.append(
                f"{event.get('timestamp','')}{event.get('node','unknown')} {event.get('nf','SYSTEM')}: "
                f"{_trim_message(event.get('message',''), 220)}"
            )
    else:
        out.append("No matching subscriber logs have arrived yet.")
    out.append("\nUse `stop trace` when the attach/session test is complete.")
    return "\n".join(out)
def _nf_label(nf: dict) -> str:
placements = nf.get("nodes", [])
if not placements:
return nf["name"]
return f"{nf['name']} on {', '.join(node['hostname'] for node in placements)}"
def _nf_sentence(nf: dict | None, impact: str) -> str:
if not nf:
return "○ N/A"
if nf.get("state") == "up":
nodes = ", ".join(node["hostname"] for node in nf.get("nodes", [])) or nf.get("instance", "unknown host")
return f"✅ UP on {nodes}"
return f"🔴 DOWN — {impact}"
def _cluster_scope(cluster: dict) -> str:
nodes = cluster.get("nodes", [])
if not nodes:
return "\nCluster discovery is not available."
details = ", ".join(f"{node['hostname']} ({node.get('role', 'AP')})" for node in nodes)
return f"\nCluster scope checked: {details}"
# ── LLM backends ──────────────────────────────────────────────────────────
def _build_context(network_state: dict, alerts: list, logs: list[dict]) -> str:
    """Condense network state, alerts, and recent logs into an LLM prompt context."""
    functions = network_state.get("nfs", [])
    up_names = [entry["name"] for entry in functions if entry["state"] == "up"]
    down_names = [entry["name"] for entry in functions if entry["state"] == "down"]
    members = network_state.get("cluster", {}).get("nodes", [])
    node_summary = ", ".join(f"{node['hostname']} ({node.get('role', 'AP')})" for node in members) or "none"
    # Only the most recent ten log lines are included to keep the prompt small.
    tail = logs[-10:] if logs else []
    log_summary = "; ".join(
        f"{entry.get('timestamp','')} {entry.get('node','unknown')} {entry.get('nf','SYSTEM')}: {_trim_message(entry.get('message',''), 120)}"
        for entry in tail
    ) or "none"
    alert_summary = ", ".join(a.get("name", "") for a in alerts[:5]) or "none"
    return (
        f"NFs UP: {', '.join(up_names) or 'none'}\n"
        f"NFs DOWN: {', '.join(down_names) or 'none'}\n"
        f"Cluster nodes: {node_summary}\n"
        f"Active alerts: {alert_summary}\n"
        f"Recent logs: {log_summary}"
    )
async def _call_openai(query: str, network_state: dict, alerts: list, logs: list[dict]) -> str:
    """Answer via an OpenAI-compatible chat-completions endpoint.

    On any failure (connection error, bad JSON, missing ``choices``) the
    exception text is prepended and the rule-based engine answers instead.
    """
    try:
        import httpx  # imported lazily so the dependency is only needed in LLM mode
        ctx = _build_context(network_state, alerts, logs)
        messages = [
            {"role": "system", "content":
                f"You are P5G Marvis, an AI network assistant for HPE Private 5G.\n"
                f"Current network state:\n{ctx}\n\nRespond concisely, use markdown."},
            {"role": "user", "content": query},
        ]
        base = OPENAI_BASE_URL.rstrip("/")
        headers = {"Content-Type": "application/json"}
        if OPENAI_API_KEY:
            headers["Authorization"] = f"Bearer {OPENAI_API_KEY}"
        # disable cert verification for self-signed local LLM servers
        # NOTE(review): this disables TLS verification for EVERY base URL other
        # than api.openai.com, not just localhost — confirm that is acceptable
        # when pointing at a remote OpenAI-compatible gateway.
        verify = base.startswith("https://api.openai.com")
        async with httpx.AsyncClient(timeout=120, verify=verify) as client:
            resp = await client.post(
                f"{base}/v1/chat/completions",
                headers=headers,
                json={"model": OPENAI_MODEL, "messages": messages, "max_tokens": 1024},
            )
            msg = resp.json()["choices"][0]["message"]
            # some reasoning models put the answer in content, others in reasoning_content
            return msg.get("content") or msg.get("reasoning_content") or "(empty response)"
    except Exception as e:
        # Broad on purpose: any LLM-side failure degrades to the rule engine.
        return f"LLM error: {e}\n\n" + _rule_based(query, network_state, alerts, logs)
async def _call_ollama(query: str, network_state: dict, alerts: list, logs: list[dict]) -> str:
    """Answer via a local Ollama ``/api/generate`` endpoint (non-streaming).

    Any failure (connection, JSON, missing model) degrades to the rule-based
    engine with the error text prepended.
    """
    try:
        import httpx  # lazy import — only needed when MARVIS_AI_MODE=ollama
        ctx = _build_context(network_state, alerts, logs)
        prompt = (f"You are P5G Marvis, an AI network assistant.\n"
                  f"Network state:\n{ctx}\n\nUser: {query}\nAssistant:")
        async with httpx.AsyncClient(timeout=60) as client:
            resp = await client.post(
                f"{OLLAMA_URL}/api/generate",
                json={"model": OLLAMA_MODEL, "prompt": prompt, "stream": False},
            )
            return resp.json().get("response", "No response.")
    except Exception as e:
        # Broad on purpose: keep the assistant responsive even if the LLM is down.
        return f"Ollama error: {e}\n\n" + _rule_based(query, network_state, alerts, logs)
def _find_log_hits(query: str, logs: list[dict]) -> list[dict]:
terms = [term for term in re.findall(r"[a-z0-9_-]+", query.lower()) if len(term) >= 3]
if not logs or not terms:
return []
hits = []
for event in logs:
haystack = " ".join(
[
str(event.get("nf", "")).lower(),
str(event.get("node", "")).lower(),
str(event.get("source", "")).lower(),
str(event.get("message", "")).lower(),
]
)
score = sum(1 for term in terms if term in haystack)
if score:
event_copy = dict(event)
event_copy["_score"] = score
hits.append(event_copy)
hits.sort(key=lambda event: (event.get("_score", 0), event.get("epoch", 0.0)), reverse=True)
return hits
def _trim_message(message: str, limit: int = 160) -> str:
message = " ".join(str(message).split())
if len(message) <= limit:
return message
return message[: limit - 3] + "..."