Fix: load a single shared Whisper model instead of one per connection

2025-09-11 12:08:33 +02:00
parent ab17a8ac21
commit cf48a5c2fc


@@ -16,6 +16,7 @@ from dataclasses import dataclass, asdict
 from concurrent import futures
 import threading
 from datetime import datetime
+import atexit

 # Add current directory to path for generated protobuf imports
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -69,63 +70,95 @@ class TranscriptionSession:
     transcriptions: List[dict]

-class TranscriptionEngine:
-    """Core transcription engine using Whisper or SimulStreaming"""
-
-    def __init__(self, model_name: str = "large-v3"):
-        self.model_name = model_name
-        self.model = None
-        self.processor = None
-        self.online_processor = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.load_model()
-
-    def load_model(self):
-        """Load the transcription model"""
-        if USE_SIMULSTREAMING:
-            self._load_simulstreaming()
-        else:
-            self._load_whisper()
-
-    def _load_simulstreaming(self):
-        """Load SimulStreaming for real-time transcription"""
-        try:
-            import argparse
-            parser = argparse.ArgumentParser()
-            # Add SimulStreaming arguments
-            simulwhisper_args(parser)
-            args = parser.parse_args([
-                '--model_path', self.model_name,
-                '--lan', 'auto',
-                '--task', 'transcribe',
-                '--backend', 'whisper',
-                '--min-chunk-size', '0.5',
-                '--beams', '1',
-            ])
-            # Create processor
-            self.processor, self.online_processor = simul_asr_factory(args)
-            logger.info(f"Loaded SimulStreaming with model: {self.model_name}")
-        except Exception as e:
-            logger.error(f"Failed to load SimulStreaming: {e}")
-            logger.info("Falling back to standard Whisper")
-            USE_SIMULSTREAMING = False
-            self._load_whisper()
-
-    def _load_whisper(self):
-        """Load standard Whisper model"""
-        try:
-            # Use the shared volume for model caching
-            download_root = os.environ.get('TORCH_HOME', '/app/models')
-            self.model = whisper.load_model(self.model_name, device=self.device, download_root=download_root)
-            logger.info(f"Loaded Whisper model: {self.model_name} on {self.device} from {download_root}")
-        except Exception as e:
-            logger.error(f"Failed to load Whisper model: {e}")
-            raise
+class ModelManager:
+    """Singleton manager for the Whisper model, shared across all connections"""
+
+    _instance = None
+    _lock = threading.Lock()
+    _model = None
+    _device = None
+    _model_name = None
+    _initialized = False
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
+
+    def initialize(self, model_name: str = "large-v3"):
+        """Initialize the model (only once)"""
+        with self._lock:
+            if not self._initialized:
+                self._model_name = model_name
+                self._device = "cuda" if torch.cuda.is_available() else "cpu"
+                self._load_model()
+                self._initialized = True
+                logger.info(f"ModelManager initialized with {model_name} on {self._device}")
+
+    def _load_model(self):
+        """Load the Whisper model"""
+        try:
+            # Use the shared volume for model caching
+            download_root = os.environ.get('TORCH_HOME', '/app/models')
+            self._model = whisper.load_model(
+                self._model_name,
+                device=self._device,
+                download_root=download_root
+            )
+            logger.info(f"Loaded shared Whisper model: {self._model_name} on {self._device}")
+        except Exception as e:
+            logger.error(f"Failed to load Whisper model: {e}")
+            raise
+
+    def get_model(self):
+        """Get the shared model instance"""
+        if not self._initialized:
+            raise RuntimeError("ModelManager not initialized. Call initialize() first.")
+        return self._model
+
+    def get_device(self):
+        """Get the device being used"""
+        return self._device
+
+    def get_model_name(self):
+        """Get the model name"""
+        return self._model_name
+
+    def cleanup(self):
+        """Clean up resources (call on shutdown)"""
+        with self._lock:
+            if self._model is not None:
+                del self._model
+                self._model = None
+            self._initialized = False
+            if self._device == "cuda":
+                torch.cuda.empty_cache()
+            logger.info("ModelManager cleaned up")
+
+
+class TranscriptionEngine:
+    """Core transcription engine using the shared Whisper model"""
+
+    def __init__(self, model_manager: Optional[ModelManager] = None):
+        """Initialize with an optional model manager (injectable for testing)"""
+        self.model_manager = model_manager or ModelManager()
+        self.model = None  # fetched from the manager on demand
+        self.processor = None
+        self.online_processor = None
+        self.device = self.model_manager.get_device()
+        self.model_name = self.model_manager.get_model_name()
+
+    def load_model(self):
+        """Kept for backward compatibility; the model is loaded by ModelManager."""
+        pass
+
+    def get_model(self):
+        """Get the shared model instance from ModelManager"""
+        return self.model_manager.get_model()

     def is_speech(self, audio: np.ndarray, energy_threshold: float = 0.002, zero_crossing_threshold: int = 50) -> bool:
         """
         Simple Voice Activity Detection
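
The double-checked locking in __new__ is what makes ModelManager process-wide: the first, unlocked check keeps the common path cheap, and the re-check inside the lock stops two racing threads from each building an instance. A minimal self-contained sketch of the same pattern, with the Whisper load stubbed out so it runs without a GPU or model weights:

import threading

class SharedResource:
    """Double-checked-locking singleton, as in ModelManager above (load stubbed)."""
    _instance = None
    _lock = threading.Lock()
    _initialized = False
    _model = None

    def __new__(cls):
        if cls._instance is None:            # fast path: skip the lock once created
            with cls._lock:
                if cls._instance is None:    # re-check: a racing thread may have won
                    cls._instance = super().__new__(cls)
        return cls._instance

    def initialize(self, model_name: str = "large-v3"):
        with self._lock:
            if not self._initialized:        # the expensive load happens exactly once
                self._model = f"loaded:{model_name}"  # stand-in for whisper.load_model
                self._initialized = True

# Many threads, one instance, one load:
seen = []
workers = [threading.Thread(target=lambda: seen.append(SharedResource())) for _ in range(8)]
for t in workers:
    t.start()
for t in workers:
    t.join()
assert all(m is seen[0] for m in seen)
SharedResource().initialize("tiny")  # idempotent: the first initialize wins
SharedResource().initialize("base")  # ignored; no reload happens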
@@ -176,13 +209,14 @@ class TranscriptionEngine:
                 }
             else:
                 # Use standard Whisper
-                if self.model:
+                model = self.get_model()
+                if model:
                     # Pad audio to minimum length if needed
                     if len(audio) < SAMPLE_RATE:
                         audio = np.pad(audio, (0, SAMPLE_RATE - len(audio)))
                     # Use more conservative settings to reduce hallucinations
-                    result = self.model.transcribe(
+                    result = model.transcribe(
                         audio,
                         language=None if language == "auto" else language,
                         fp16=self.device == "cuda",
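
The "conservative settings" comment refers to decode options that fall outside this hunk's context lines. For reference, a sketch of the kind of anti-hallucination options openai-whisper's transcribe() accepts; the values here are illustrative, not necessarily the ones this file uses:

import numpy as np
import torch

def transcribe_conservative(model, audio: np.ndarray) -> str:
    """Transcribe with decode options that bias against hallucinated text.

    model is an instance returned by whisper.load_model(...).
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    result = model.transcribe(
        audio,
        language=None,                     # auto-detect the language
        fp16=(device == "cuda"),           # half precision only on GPU
        temperature=0.0,                   # greedy decoding, no sampling fallback
        condition_on_previous_text=False,  # don't let prior output seed repetition loops
        no_speech_threshold=0.6,           # skip segments Whisper scores as silence
        logprob_threshold=-1.0,            # reject low-confidence decodes
    )
    return result["text"].strip()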
@@ -240,8 +274,9 @@ class TranscriptionEngine:
         audio, _ = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE)

         # Transcribe with Whisper
-        if self.model:
-            result = self.model.transcribe(
+        model = self.get_model()
+        if model:
+            result = model.transcribe(
                 audio,
                 language=None if config.language == "auto" else config.language,
                 task=config.task or "transcribe",
@@ -278,8 +313,9 @@ class TranscriptionEngine:
 class TranscriptionServicer(transcription_pb2_grpc.TranscriptionServiceServicer):
     """gRPC service implementation"""

-    def __init__(self):
-        self.engine = TranscriptionEngine()
+    def __init__(self, model_manager: Optional[ModelManager] = None):
+        self.model_manager = model_manager or ModelManager()
+        self.engine = TranscriptionEngine(self.model_manager)
         self.sessions: Dict[str, TranscriptionSession] = {}
         self.start_time = time.time()
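
Taking an optional model_manager in the constructor is plain dependency injection: production code falls through to the singleton, while tests can hand in a stub and never touch CUDA or the model weights. A hypothetical test sketch (the module name transcription_server and the stub are assumptions, not part of this diff):

from unittest import mock

from transcription_server import TranscriptionServicer  # module name is an assumption

def test_servicer_uses_injected_manager():
    fake_manager = mock.Mock()
    fake_manager.get_device.return_value = "cpu"
    fake_manager.get_model_name.return_value = "tiny"

    servicer = TranscriptionServicer(model_manager=fake_manager)

    # The engine reads device/model name from the injected manager, not the real singleton.
    assert servicer.engine.device == "cpu"
    assert servicer.engine.model_name == "tiny"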
@@ -419,7 +455,9 @@ async def serve_grpc(port: int = 50051):
         ]
     )

-    servicer = TranscriptionServicer()
+    # Use the shared model manager
+    model_manager = get_global_model_manager()
+    servicer = TranscriptionServicer(model_manager)
     transcription_pb2_grpc.add_TranscriptionServiceServicer_to_server(servicer, server)
     server.add_insecure_port(f'[::]:{port}')
@@ -429,12 +467,27 @@ async def serve_grpc(port: int = 50051):
     await server.wait_for_termination()

+# Global model manager instance (shared across all handlers)
+_global_model_manager = None
+
+def get_global_model_manager() -> ModelManager:
+    """Get or create the global model manager"""
+    global _global_model_manager
+    if _global_model_manager is None:
+        _global_model_manager = ModelManager()
+        model_name = os.environ.get('MODEL_PATH', 'large-v3')
+        _global_model_manager.initialize(model_name)
+    return _global_model_manager
+
 # WebSocket support for compatibility
 async def handle_websocket(websocket, path):
     """Handle WebSocket connections for compatibility"""
     import websockets
-    engine = TranscriptionEngine()
+    # Use the shared model manager instead of creating a new engine
+    model_manager = get_global_model_manager()
+    engine = TranscriptionEngine(model_manager)
     session_id = str(time.time())
     audio_buffer = bytearray()
@@ -453,8 +506,8 @@ async def handle_websocket(websocket, path):
                     audio_data = base64.b64decode(data['data'])
                     audio_buffer.extend(audio_data)

-                    # Process when we have enough audio
-                    min_bytes = int(SAMPLE_RATE * 0.5 * 2)
+                    # Process when we have enough audio (3 seconds for better accuracy)
+                    min_bytes = int(SAMPLE_RATE * 3.0 * 2)  # 3 seconds of PCM16
                     while len(audio_buffer) >= min_bytes:
                         chunk = bytes(audio_buffer[:min_bytes])
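
The threshold arithmetic is sample_rate × seconds × bytes_per_sample; at the 16 kHz mono rate Whisper expects, that is 16000 × 3.0 × 2 = 96,000 bytes per chunk. A sketch of converting such a PCM16 chunk into the float32 array Whisper consumes, assuming 16 kHz mono little-endian input (the real SAMPLE_RATE constant is defined elsewhere in this file):

import numpy as np

SAMPLE_RATE = 16000  # assumed value for this sketch
min_bytes = int(SAMPLE_RATE * 3.0 * 2)  # 3 s of mono PCM16 = 96,000 bytes

def pcm16_to_float32(chunk: bytes) -> np.ndarray:
    """Convert little-endian PCM16 bytes to the float32 [-1.0, 1.0] range Whisper expects."""
    samples = np.frombuffer(chunk, dtype='<i2')
    return samples.astype(np.float32) / 32768.0

audio = pcm16_to_float32(bytes(min_bytes))  # e.g. 3 seconds of silence
assert audio.shape == (3 * SAMPLE_RATE,)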
@@ -506,12 +559,24 @@ async def main():
     ws_port = int(os.environ.get('WEBSOCKET_PORT', '8765'))
     enable_websocket = os.environ.get('ENABLE_WEBSOCKET', 'true').lower() == 'true'

-    tasks = [serve_grpc(grpc_port)]
-    if enable_websocket:
-        tasks.append(serve_websocket(ws_port))
-    await asyncio.gather(*tasks)
+    # Initialize the global model manager once at startup
+    logger.info("Initializing shared model manager...")
+    model_manager = get_global_model_manager()
+    logger.info(f"Model manager initialized with model: {model_manager.get_model_name()}")
+
+    try:
+        tasks = [serve_grpc(grpc_port)]
+        if enable_websocket:
+            tasks.append(serve_websocket(ws_port))
+        await asyncio.gather(*tasks)
+    finally:
+        # Cleanup on shutdown
+        logger.info("Shutting down, cleaning up model manager...")
+        if model_manager:
+            model_manager.cleanup()
+        logger.info("Cleanup complete")

 if __name__ == "__main__":
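
The hunk at the top of the file adds import atexit, but the registration itself is not visible in these hunks. One plausible wiring, placed in the same module as _global_model_manager, would act as a backstop for shutdown paths that skip main()'s try/finally; this sketch is an assumption, not code shown in the diff:

import atexit

def _cleanup_on_exit():
    # Assumed wiring: release the shared model on interpreter exit.
    # Safe to run after main()'s finally block, since cleanup() checks
    # whether the model is still loaded before doing any work.
    global _global_model_manager
    if _global_model_manager is not None:
        _global_model_manager.cleanup()

atexit.register(_cleanup_on_exit)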