# Mirror of https://github.com/aljazceru/signallama.git
# (synced 2025-12-18 07:04:20 +01:00)

import asyncio
import json
import logging
import os
import re
import signal as py_signal
import sqlite3
import tempfile
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional

from settings import SIGNAL_URL, SIGNAL_NUMBER, LLM_MODEL, LLM_API_BASE, LLM_API_KEY, LLM_PROVIDER, WHISPER_URL

import aiohttp
import litellm

# Required packages:
# pip install litellm aiohttp
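
# Configuration comes from the local settings module imported above. The exact values
# depend on your deployment; a minimal settings.py might look like the sketch below
# (all values are illustrative placeholders, not the project's defaults):
#
#     SIGNAL_URL = "http://localhost:8080"        # signal-cli-rest-api base URL
#     SIGNAL_NUMBER = "+15551234567"              # registered Signal number
#     LLM_MODEL = "ollama/llama3"                 # any model string LiteLLM accepts
#     LLM_API_BASE = "http://localhost:11434"     # optional API base override
#     LLM_API_KEY = None                          # optional API key
#     LLM_PROVIDER = None                         # optional provider hint
#     WHISPER_URL = "http://localhost:9000"       # Whisper ASR webservice base URL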

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("SignalLLMBridge")

# Configure LiteLLM
litellm.set_verbose = False  # Set to True for debugging

DB_FILE = Path(__file__).parent / "history.db"
MAX_HISTORY = 10  # number of turns to keep per user


def filter_think_tags(text: str) -> str:
    """Remove content between <think></think> tags from text."""
    # Remove <think>...</think> blocks (case insensitive, multiline)
    filtered = re.sub(r'<think>.*?</think>', '', text, flags=re.IGNORECASE | re.DOTALL)
    # Clean up extra whitespace: collapse runs of blank lines into a single blank line
    filtered = re.sub(r'\n\s*\n\s*\n', '\n\n', filtered)
    return filtered.strip()
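
# Illustrative example (not executed): reasoning-style models may wrap their chain of
# thought in <think> tags; filter_think_tags() strips that block and the leftover
# whitespace before the reply is stored or sent.
#
#     filter_think_tags("<think>internal reasoning</think>\n\nHello!")  # -> "Hello!"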


def init_db(db_path: Path) -> None:
    conn = sqlite3.connect(str(db_path))
    c = conn.cursor()
    c.execute(
        '''
        CREATE TABLE IF NOT EXISTS history (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            user TEXT NOT NULL,
            role TEXT NOT NULL,
            content TEXT NOT NULL,
            timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
        )
        '''
    )
    conn.commit()
    conn.close()


class ContextManager:
    def __init__(self, db_path: Path, max_history: int = MAX_HISTORY) -> None:
        self.db_path = db_path
        self.max_history = max_history

    def add_message(self, user: str, role: str, content: str) -> None:
        conn = sqlite3.connect(str(self.db_path))
        c = conn.cursor()
        c.execute(
            'INSERT INTO history (user, role, content) VALUES (?, ?, ?)',
            (user, role, content)
        )
        conn.commit()
        # Prune old rows: keep only the newest max_history * 2 rows per user
        # (each turn stores one user message and one assistant message).
        # SQLite's "LIMIT -1 OFFSET n" selects everything beyond the first n rows.
        c.execute(
            '''
            DELETE FROM history
            WHERE id IN (
                SELECT id FROM history
                WHERE user = ?
                ORDER BY timestamp DESC
                LIMIT -1 OFFSET ?
            )
            ''', (user, self.max_history * 2)
        )
        conn.commit()
        conn.close()

    def get_history(self, user: str) -> List[Dict[str, str]]:
        conn = sqlite3.connect(str(self.db_path))
        c = conn.cursor()
        c.execute(
            'SELECT role, content FROM history WHERE user = ? ORDER BY timestamp ASC',
            (user,)
        )
        rows = c.fetchall()
        conn.close()
        return [{'role': row[0], 'content': row[1]} for row in rows]
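
# Illustrative usage (not executed): history is stored per sender and replayed to the
# LLM in chronological order, so a short session roughly behaves like this
# (the phone number is a placeholder):
#
#     ctx = ContextManager(DB_FILE)
#     ctx.add_message("+15551234567", "user", "hello")
#     ctx.add_message("+15551234567", "assistant", "hi there")
#     ctx.get_history("+15551234567")
#     # -> [{'role': 'user', 'content': 'hello'},
#     #     {'role': 'assistant', 'content': 'hi there'}]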


@dataclass
class SignalConfig:
    api_url: str
    number: str
    receive_timeout: int = 10  # seconds
    poll_interval: float = 1.0  # seconds between polls


@dataclass
class LLMConfig:
    model: str
    api_base: Optional[str] = None
    api_key: Optional[str] = None
    provider: Optional[str] = None


@dataclass
class WhisperConfig:
    api_url: str
    enabled: bool = True


class SignalLLMBridge:
    def __init__(
        self,
        signal_cfg: SignalConfig,
        llm_cfg: LLMConfig,
        whisper_cfg: WhisperConfig,
        context_mgr: ContextManager
    ) -> None:
        self.signal_cfg = signal_cfg
        self.llm_cfg = llm_cfg
        self.whisper_cfg = whisper_cfg
        self.context = context_mgr
        self.session: Optional[aiohttp.ClientSession] = None
        self.running = True

        # Configure LiteLLM settings
        if llm_cfg.api_base:
            litellm.api_base = llm_cfg.api_base
        if llm_cfg.api_key:
            litellm.api_key = llm_cfg.api_key

    async def start(self) -> None:
        init_db(self.context.db_path)
        self.session = aiohttp.ClientSession()
        loop = asyncio.get_running_loop()
        loop.add_signal_handler(py_signal.SIGINT, self._stop)
        loop.add_signal_handler(py_signal.SIGTERM, self._stop)
        logger.info("Bridge started using REST polling mode with model: %s (Whisper: %s)",
                    self.llm_cfg.model,
                    "enabled" if self.whisper_cfg.enabled else "disabled")
        await self._poll_loop()

    async def _poll_loop(self) -> None:
        # Don't URL encode the number - signal-cli-rest-api expects the raw number
        receive_url = f"{self.signal_cfg.api_url}/v1/receive/{self.signal_cfg.number}"

        while self.running:
            try:
                params = {
                    "timeout": self.signal_cfg.receive_timeout,
                    "ignore_attachments": "false",
                    "ignore_stories": "true"
                }
                async with self.session.get(receive_url, params=params) as resp:
                    resp.raise_for_status()
                    text = await resp.text()

                # Parse response - it's a JSON array
                if not text.strip():
                    await asyncio.sleep(self.signal_cfg.poll_interval)
                    continue

                logger.debug("Raw response from signal API: %s", text[:500])

                try:
                    # Response is a JSON array of messages
                    messages = json.loads(text)
                    if not isinstance(messages, list):
                        messages = [messages]  # Handle single message case
                    logger.debug("Parsed %d messages from API", len(messages))
                except json.JSONDecodeError as e:
                    logger.error("Failed to parse JSON response: %s", e)
                    logger.debug("Problematic response: %s", text[:200])
                    await asyncio.sleep(self.signal_cfg.poll_interval)
                    continue

            except Exception as e:
                logger.error("Error receiving messages: %s", e)
                await asyncio.sleep(self.signal_cfg.poll_interval)
                continue

            for msg in messages:
                try:
                    # Skip non-dict messages
                    if not isinstance(msg, dict):
                        logger.debug("Skipping non-dict message: %s", type(msg).__name__)
                        continue

                    # Extract envelope
                    envelope = msg.get('envelope', {})
                    if not envelope:
                        logger.debug("No envelope found, skipping message")
                        continue

                    # Only process messages with actual content (dataMessage)
                    data_message = envelope.get('dataMessage')
                    if not data_message:
                        # Skip typing indicators and other non-content messages
                        msg_type = 'typing' if envelope.get('typingMessage') else 'other'
                        logger.debug("Skipping %s message (no dataMessage)", msg_type)
                        continue

                    # Get sender info
                    author = (envelope.get('source') or
                              envelope.get('sourceNumber') or
                              envelope.get('sourceName'))

                    # Get message content
                    body = (data_message.get('message') or '').strip()

                    if not author:
                        logger.debug("Missing author, skipping message")
                        continue

                    # Check if this is a voice message
                    if self._is_voice_message(data_message):
                        try:
                            handled = await self._process_voice_message(data_message, author)
                            if handled:
                                continue
                        except Exception as e:
                            logger.error("Error processing voice message from %s: %s", author, e)
                            await self._send_reply(author, "Sorry, I encountered an error processing your voice message.")
                            continue

                    # Handle regular text messages
                    if not body:
                        logger.debug("No text content in message, skipping")
                        continue

                    logger.info("Received from %s: %s", author, body)
                    reply = await self._get_ai_response(body, author)
                    await self._send_reply(author, reply)

                except Exception as e:
                    logger.error("Error processing message: %s", e)
                    logger.debug("Problematic message: %s", str(msg)[:200] if 'msg' in locals() else 'N/A')

            await asyncio.sleep(self.signal_cfg.poll_interval)

        if self.session:
            await self.session.close()
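
    # For reference, the polling loop above expects signal-cli-rest-api receive items
    # shaped roughly like the following (illustrative and trimmed to the fields read
    # here; real payloads carry additional fields):
    #
    #     {
    #         "envelope": {
    #             "source": "+15551234567",
    #             "sourceNumber": "+15551234567",
    #             "sourceName": "Alice",
    #             "dataMessage": {
    #                 "message": "hello bot",
    #                 "attachments": [
    #                     {"contentType": "audio/aac", "id": "attachment-id"}
    #                 ]
    #             }
    #         }
    #     }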

    def _is_voice_message(self, data_message: Dict[str, Any]) -> bool:
        """Check if the message contains voice attachments."""
        attachments = data_message.get('attachments', [])
        if not attachments:
            return False

        voice_mime_types = [
            'audio/aac',
            'audio/mp4',
            'audio/mpeg',
            'audio/ogg',
            'audio/wav',
            'audio/webm',
            'audio/3gpp',
            'audio/amr'
        ]

        for attachment in attachments:
            content_type = attachment.get('contentType', '').lower()
            if content_type in voice_mime_types:
                return True
        return False

    async def _download_attachment(self, attachment_id: str) -> Optional[bytes]:
        """Download an attachment from the Signal API."""
        try:
            download_url = f"{self.signal_cfg.api_url}/v1/attachments/{attachment_id}"
            async with self.session.get(download_url) as resp:
                if resp.status == 200:
                    return await resp.read()
                else:
                    logger.error("Failed to download attachment %s: HTTP %d", attachment_id, resp.status)
                    return None
        except Exception as e:
            logger.error("Error downloading attachment %s: %s", attachment_id, e)
            return None

    async def _transcribe_audio(self, audio_data: bytes, filename: str = "audio.ogg") -> Optional[str]:
        """Transcribe audio using the /asr endpoint."""
        if not self.whisper_cfg.enabled:
            return None

        try:
            # Create temporary file for audio data
            with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(filename)[1]) as temp_file:
                temp_file.write(audio_data)
                temp_file_path = temp_file.name

            try:
                # Prepare multipart form data for /asr endpoint
                with open(temp_file_path, 'rb') as audio_file:
                    form_data = aiohttp.FormData()
                    form_data.add_field('audio_file', audio_file, filename=filename)
                    # You can add more fields here if needed (e.g., task, language, etc.)
                    asr_url = f"{self.whisper_cfg.api_url}/asr?output=json"
                    try:
                        async with self.session.post(asr_url, data=form_data) as resp:
                            response_text = await resp.text()
                            if resp.status == 200:
                                try:
                                    result = json.loads(response_text)
                                    logger.info("ASR API 200 response JSON: %s", result)
                                    return result.get('text', '').strip()
                                except Exception as parse_exc:
                                    logger.error("Failed to parse ASR API JSON response: %s", parse_exc, exc_info=True)
                                    logger.error("Raw response text: %s", response_text)
                                    return None
                            else:
                                logger.error("ASR transcription failed: HTTP %d - %s", resp.status, response_text)
                                return None
                    except aiohttp.ClientConnectorError:
                        logger.error("Cannot connect to ASR service at %s - is it running?", self.whisper_cfg.api_url, exc_info=True)
                        return None
                    except asyncio.TimeoutError:
                        logger.error("ASR transcription timed out", exc_info=True)
                        return None
                    except aiohttp.ClientError as e:
                        logger.error("Aiohttp client error: %s", e, exc_info=True)
                        return None
                    except Exception as e:
                        logger.error("Unexpected error during ASR transcription: %s", e, exc_info=True)
                        return None
            finally:
                # Clean up temporary file
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    pass
        except Exception as e:
            logger.error("Error transcribing audio: %s", e)
            return None
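
    # For reference, with output=json the ASR service is expected to answer with a JSON
    # object whose 'text' field holds the transcript; only that field is used above.
    # Illustrative shape (any other fields are ignored):
    #
    #     {"text": "hello this is a voice message", "language": "en", "segments": [...]}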

    async def _process_voice_message(self, data_message: Dict[str, Any], author: str) -> bool:
        """Process a voice message and send the transcription. Returns True if handled."""
        attachments = data_message.get('attachments', [])

        for attachment in attachments:
            content_type = attachment.get('contentType', '').lower()
            attachment_id = attachment.get('id')

            if not attachment_id:
                continue

            voice_mime_types = [
                'audio/aac', 'audio/mp4', 'audio/mpeg', 'audio/ogg',
                'audio/wav', 'audio/webm', 'audio/3gpp', 'audio/amr'
            ]

            if content_type in voice_mime_types:
                logger.info("Processing voice message from %s (type: %s)", author, content_type)

                # Download the audio file
                audio_data = await self._download_attachment(attachment_id)
                if not audio_data:
                    await self._send_reply(author, "Sorry, I couldn't download your voice message.")
                    return True

                # Transcribe the audio
                filename = f"voice_message.{content_type.split('/')[-1]}"
                transcription = await self._transcribe_audio(audio_data, filename)

                if transcription:
                    reply = f"Voice message transcription:\n\n{transcription}"
                    logger.info("Transcribed voice message from %s: %s", author, transcription)
                else:
                    reply = "Sorry, I couldn't transcribe your voice message."

                await self._send_reply(author, reply)
                return True

        return False

    async def _send_reply(self, recipient: str, message: str) -> None:
        """Send a reply message."""
        send_url = f"{self.signal_cfg.api_url}/v2/send"
        payload = {
            'number': self.signal_cfg.number,
            'recipients': [recipient],
            'message': message,
            'text_mode': 'normal'
        }

        try:
            async with self.session.post(send_url, json=payload) as resp_send:
                resp_send.raise_for_status()
                response_data = await resp_send.json()
                logger.info("Sent to %s (timestamp: %s)", recipient, response_data.get('timestamp', 'unknown'))
        except Exception as e:
            logger.error("Error sending to %s: %s", recipient, e)

    async def _get_ai_response(self, prompt: str, user: str) -> str:
        history = self.context.get_history(user)

        # Build messages in OpenAI format for LiteLLM
        messages = []

        # Add conversation history
        for msg in history:
            messages.append({
                'role': msg['role'],
                'content': msg['content']
            })

        # Add current user message
        messages.append({
            'role': 'user',
            'content': prompt
        })

        try:
            # Use LiteLLM async completion
            response = await litellm.acompletion(
                model=self.llm_cfg.model,
                messages=messages,
                api_base=self.llm_cfg.api_base,
                api_key=self.llm_cfg.api_key
            )

            # Extract reply from the LiteLLM response (content can be None for some models)
            reply = (response.choices[0].message.content or '').strip()

            # Filter out <think></think> content before saving and sending
            filtered_reply = filter_think_tags(reply)

            self.context.add_message(user, 'user', prompt)
            self.context.add_message(user, 'assistant', filtered_reply)

            logger.debug("Original reply length: %d, filtered length: %d", len(reply), len(filtered_reply))
            return filtered_reply

        except Exception as e:
            logger.error("LLM API error: %s", e)
            return "Sorry, I encountered an error processing your request."

    def _stop(self) -> None:
        logger.info("Shutdown signal received.")
        self.running = False


async def main() -> None:
    signal_cfg = SignalConfig(
        api_url=SIGNAL_URL,
        number=SIGNAL_NUMBER
    )
    llm_cfg = LLMConfig(
        model=LLM_MODEL,
        api_base=LLM_API_BASE,
        api_key=LLM_API_KEY,
        provider=LLM_PROVIDER
    )
    whisper_cfg = WhisperConfig(
        api_url=WHISPER_URL,
        enabled=bool(WHISPER_URL)  # Enable transcription only if a URL is provided
    )
    context_mgr = ContextManager(DB_FILE)
    bridge = SignalLLMBridge(signal_cfg, llm_cfg, whisper_cfg, context_mgr)
    await bridge.start()


if __name__ == '__main__':
    asyncio.run(main())