From cf48a5c2fcb4172fcd4180e0100861668a151794 Mon Sep 17 00:00:00 2001
From: Aljaz Ceru
Date: Thu, 11 Sep 2025 12:08:33 +0200
Subject: [PATCH] Share a single Whisper model across all connections

---
 src/transcription_server.py | 191 ++++++++++++++++++++------------
 1 file changed, 128 insertions(+), 63 deletions(-)

diff --git a/src/transcription_server.py b/src/transcription_server.py
index c53f9e0..d98a28e 100644
--- a/src/transcription_server.py
+++ b/src/transcription_server.py
@@ -16,6 +16,7 @@ from dataclasses import dataclass, asdict
 from concurrent import futures
 import threading
 from datetime import datetime
+import atexit
 
 # Add current directory to path for generated protobuf imports
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -69,63 +70,95 @@ class TranscriptionSession:
     transcriptions: List[dict]
 
 
-class TranscriptionEngine:
-    """Core transcription engine using Whisper or SimulStreaming"""
+class ModelManager:
+    """Singleton manager for the Whisper model, shared across all connections"""
 
-    def __init__(self, model_name: str = "large-v3"):
-        self.model_name = model_name
-        self.model = None
-        self.processor = None
-        self.online_processor = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.load_model()
+    _instance = None
+    _lock = threading.Lock()
+    _model = None
+    _device = None
+    _model_name = None
+    _initialized = False
 
-    def load_model(self):
-        """Load the transcription model"""
-        if USE_SIMULSTREAMING:
-            self._load_simulstreaming()
-        else:
-            self._load_whisper()
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
 
-    def _load_simulstreaming(self):
-        """Load SimulStreaming for real-time transcription"""
+    def initialize(self, model_name: str = "large-v3"):
+        """Initialize the model (only once)"""
+        with self._lock:
+            if not self._initialized:
+                self._model_name = model_name
+                self._device = "cuda" if torch.cuda.is_available() else "cpu"
+                self._load_model()
+                self._initialized = True
+                logger.info(f"ModelManager initialized with {model_name} on {self._device}")
+
+    def _load_model(self):
+        """Load the Whisper model"""
         try:
-            import argparse
-            parser = argparse.ArgumentParser()
-
-            # Add SimulStreaming arguments
-            simulwhisper_args(parser)
-
-            args = parser.parse_args([
-                '--model_path', self.model_name,
-                '--lan', 'auto',
-                '--task', 'transcribe',
-                '--backend', 'whisper',
-                '--min-chunk-size', '0.5',
-                '--beams', '1',
-            ])
-
-            # Create processor
-            self.processor, self.online_processor = simul_asr_factory(args)
-            logger.info(f"Loaded SimulStreaming with model: {self.model_name}")
-
-        except Exception as e:
-            logger.error(f"Failed to load SimulStreaming: {e}")
-            logger.info("Falling back to standard Whisper")
-            USE_SIMULSTREAMING = False
-            self._load_whisper()
-
-    def _load_whisper(self):
-        """Load standard Whisper model"""
-        try:
-            # Use the shared volume for model caching
             download_root = os.environ.get('TORCH_HOME', '/app/models')
-            self.model = whisper.load_model(self.model_name, device=self.device, download_root=download_root)
-            logger.info(f"Loaded Whisper model: {self.model_name} on {self.device} from {download_root}")
+            self._model = whisper.load_model(
+                self._model_name,
+                device=self._device,
+                download_root=download_root
+            )
+            logger.info(f"Loaded shared Whisper model: {self._model_name} on {self._device}")
         except Exception as e:
             logger.error(f"Failed to load Whisper model: {e}")
             raise
 
+    def get_model(self):
+        """Get the shared model instance"""
+        if not self._initialized:
+            raise RuntimeError("ModelManager not initialized. Call initialize() first.")
+        return self._model
+
+    def get_device(self):
+        """Get the device being used"""
+        return self._device
+
+    def get_model_name(self):
+        """Get the model name"""
+        return self._model_name
+
+    def cleanup(self):
+        """Clean up resources (call on shutdown)"""
+        with self._lock:
+            if self._model is not None:
+                del self._model
+                self._model = None
+            self._initialized = False
+            if self._device == "cuda":
+                torch.cuda.empty_cache()
+            logger.info("ModelManager cleaned up")
+
+
+class TranscriptionEngine:
+    """Core transcription engine using the shared Whisper model"""
+
+    def __init__(self, model_manager: Optional[ModelManager] = None):
+        """Initialize with an optional model manager (injectable for testing)"""
+        self.model_manager = model_manager or ModelManager()
+        self.model = None  # Retrieved from the manager on demand
+        self.processor = None
+        self.online_processor = None
+        self.device = self.model_manager.get_device()
+        self.model_name = self.model_manager.get_model_name()
+
+    def load_model(self):
+        """Load the transcription model"""
+        # The model is already loaded by ModelManager;
+        # this method is kept for backward compatibility.
+        pass
+
+    def get_model(self):
+        """Get the shared model instance from ModelManager"""
+        return self.model_manager.get_model()
+
     def is_speech(self, audio: np.ndarray, energy_threshold: float = 0.002, zero_crossing_threshold: int = 50) -> bool:
         """
         Simple Voice Activity Detection
@@ -176,13 +209,14 @@ class TranscriptionEngine:
             }
         else:
             # Use standard Whisper
-            if self.model:
+            model = self.get_model()
+            if model:
                 # Pad audio to minimum length if needed
                 if len(audio) < SAMPLE_RATE:
                     audio = np.pad(audio, (0, SAMPLE_RATE - len(audio)))
 
                 # Use more conservative settings to reduce hallucinations
-                result = self.model.transcribe(
+                result = model.transcribe(
                     audio,
                     language=None if language == "auto" else language,
                     fp16=self.device == "cuda",
@@ -240,8 +274,9 @@ class TranscriptionEngine:
             audio, _ = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE)
 
             # Transcribe with Whisper
-            if self.model:
-                result = self.model.transcribe(
+            model = self.get_model()
+            if model:
+                result = model.transcribe(
                     audio,
                     language=None if config.language == "auto" else config.language,
                     task=config.task or "transcribe",
@@ -278,8 +313,9 @@ class TranscriptionEngine:
 class TranscriptionServicer(transcription_pb2_grpc.TranscriptionServiceServicer):
     """gRPC service implementation"""
 
-    def __init__(self):
-        self.engine = TranscriptionEngine()
+    def __init__(self, model_manager: Optional[ModelManager] = None):
+        self.model_manager = model_manager or ModelManager()
+        self.engine = TranscriptionEngine(self.model_manager)
         self.sessions: Dict[str, TranscriptionSession] = {}
         self.start_time = time.time()
 
@@ -419,7 +455,9 @@ async def serve_grpc(port: int = 50051):
         ]
     )
 
-    servicer = TranscriptionServicer()
+    # Use the shared model manager
+    model_manager = get_global_model_manager()
+    servicer = TranscriptionServicer(model_manager)
     transcription_pb2_grpc.add_TranscriptionServiceServicer_to_server(servicer, server)
 
     server.add_insecure_port(f'[::]:{port}')
@@ -429,12 +467,27 @@ async def serve_grpc(port: int = 50051):
     await server.wait_for_termination()
 
 
+# Global model manager instance (shared across all handlers)
+_global_model_manager = None
+
+def get_global_model_manager() -> ModelManager:
+    """Get or create the global model manager"""
+    global _global_model_manager
+    if _global_model_manager is None:
+        _global_model_manager = ModelManager()
+        model_name = os.environ.get('MODEL_PATH', 'large-v3')
+        _global_model_manager.initialize(model_name)
+    return _global_model_manager
+
+
 # WebSocket support for compatibility
 async def handle_websocket(websocket, path):
     """Handle WebSocket connections for compatibility"""
     import websockets
 
-    engine = TranscriptionEngine()
+    # Use the shared model manager instead of creating a new engine
+    model_manager = get_global_model_manager()
+    engine = TranscriptionEngine(model_manager)
     session_id = str(time.time())
     audio_buffer = bytearray()
 
@@ -453,8 +506,8 @@ async def handle_websocket(websocket, path):
                     audio_data = base64.b64decode(data['data'])
                     audio_buffer.extend(audio_data)
 
-                    # Process when we have enough audio
-                    min_bytes = int(SAMPLE_RATE * 0.5 * 2)
+                    # Process when we have enough audio (3 seconds for better accuracy)
+                    min_bytes = int(SAMPLE_RATE * 3.0 * 2)  # 3 seconds of PCM16
 
                     while len(audio_buffer) >= min_bytes:
                         chunk = bytes(audio_buffer[:min_bytes])
@@ -506,12 +559,24 @@ async def main():
     ws_port = int(os.environ.get('WEBSOCKET_PORT', '8765'))
     enable_websocket = os.environ.get('ENABLE_WEBSOCKET', 'true').lower() == 'true'
 
-    tasks = [serve_grpc(grpc_port)]
+    # Initialize the global model manager once at startup
+    logger.info("Initializing shared model manager...")
+    model_manager = get_global_model_manager()
+    logger.info(f"Model manager initialized with model: {model_manager.get_model_name()}")
 
-    if enable_websocket:
-        tasks.append(serve_websocket(ws_port))
-
-    await asyncio.gather(*tasks)
+    try:
+        tasks = [serve_grpc(grpc_port)]
+
+        if enable_websocket:
+            tasks.append(serve_websocket(ws_port))
+
+        await asyncio.gather(*tasks)
+    finally:
+        # Cleanup on shutdown
+        logger.info("Shutting down, cleaning up model manager...")
+        if model_manager:
+            model_manager.cleanup()
+        logger.info("Cleanup complete")
 
 
 if __name__ == "__main__":
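
---

Reviewer note, not part of the patch: below is a minimal smoke test of the
contract the new ModelManager is meant to satisfy -- initialize() is idempotent
and every TranscriptionEngine, whether it receives a manager or not, reads the
same loaded model. It assumes src/transcription_server.py is importable as
transcription_server, a writable cache directory, and that the small "tiny"
model is acceptable for the check.

import os
os.environ.setdefault('TORCH_HOME', os.path.expanduser('~/.cache/whisper'))  # writable model cache

from transcription_server import ModelManager, TranscriptionEngine

manager = ModelManager()
manager.initialize("tiny")       # first call loads the model
manager.initialize("large-v3")   # later calls are no-ops; the name stays "tiny"

engine_a = TranscriptionEngine(manager)
engine_b = TranscriptionEngine()  # no argument: falls back to the singleton

assert ModelManager() is manager                     # __new__ always returns the one instance
assert engine_a.get_model() is engine_b.get_model()  # a single shared model object
assert engine_a.model_name == engine_b.model_name == "tiny"

manager.cleanup()  # drops the model and, on CUDA, empties the cache

Because __new__ always hands back the one instance, injecting the manager (as
TranscriptionServicer does) and constructing it implicitly are equivalent at
runtime; the parameter mainly gives tests a seam to stub.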
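One more note on the WebSocket change: the new min_bytes threshold is plain
PCM16 arithmetic. The sketch below spells it out, assuming SAMPLE_RATE is
16000 Hz as in stock Whisper.

SAMPLE_RATE = 16_000        # Hz, Whisper's native rate (assumed)
BYTES_PER_SAMPLE = 2        # 16-bit signed PCM
SECONDS = 3.0               # raised from 0.5 s to reduce hallucinated fragments

min_bytes = int(SAMPLE_RATE * SECONDS * BYTES_PER_SAMPLE)
assert min_bytes == 96_000  # ~94 KiB buffered before each transcribe() call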