Fix: load a single shared Whisper model instead of one per connection

2025-09-11 12:08:33 +02:00
parent ab17a8ac21
commit cf48a5c2fc


@@ -16,6 +16,7 @@ from dataclasses import dataclass, asdict
from concurrent import futures
import threading
from datetime import datetime
import atexit
# Add current directory to path for generated protobuf imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -69,63 +70,95 @@ class TranscriptionSession:
transcriptions: List[dict]
class TranscriptionEngine:
"""Core transcription engine using Whisper or SimulStreaming"""
class ModelManager:
"""Singleton manager for Whisper model to share across all connections"""
def __init__(self, model_name: str = "large-v3"):
self.model_name = model_name
self.model = None
self.processor = None
self.online_processor = None
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.load_model()
_instance = None
_lock = threading.Lock()
_model = None
_device = None
_model_name = None
_initialized = False
def load_model(self):
"""Load the transcription model"""
if USE_SIMULSTREAMING:
self._load_simulstreaming()
else:
self._load_whisper()
def __new__(cls):
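# Double-checked locking: the unlocked check is the fast path once the
# singleton exists; the second check inside the lock stops two threads
# that both saw None from each creating an instance.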
if cls._instance is None:
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def _load_simulstreaming(self):
"""Load SimulStreaming for real-time transcription"""
def initialize(self, model_name: str = "large-v3"):
"""Initialize the model (only once)"""
with self._lock:
if not self._initialized:
self._model_name = model_name
self._device = "cuda" if torch.cuda.is_available() else "cpu"
self._load_model()
self._initialized = True
logger.info(f"ModelManager initialized with {model_name} on {self._device}")
def _load_model(self):
"""Load the Whisper model"""
try:
import argparse
parser = argparse.ArgumentParser()
# Add SimulStreaming arguments
simulwhisper_args(parser)
args = parser.parse_args([
'--model_path', self.model_name,
'--lan', 'auto',
'--task', 'transcribe',
'--backend', 'whisper',
'--min-chunk-size', '0.5',
'--beams', '1',
])
# Create processor
self.processor, self.online_processor = simul_asr_factory(args)
logger.info(f"Loaded SimulStreaming with model: {self.model_name}")
except Exception as e:
logger.error(f"Failed to load SimulStreaming: {e}")
logger.info("Falling back to standard Whisper")
global USE_SIMULSTREAMING  # rebind the module-level flag, not a local
USE_SIMULSTREAMING = False
self._load_whisper()
def _load_whisper(self):
"""Load standard Whisper model"""
try:
# Use the shared volume for model caching
download_root = os.environ.get('TORCH_HOME', '/app/models')
self.model = whisper.load_model(self.model_name, device=self.device, download_root=download_root)
logger.info(f"Loaded Whisper model: {self.model_name} on {self.device} from {download_root}")
self._model = whisper.load_model(
self._model_name,
device=self._device,
download_root=download_root
)
logger.info(f"Loaded shared Whisper model: {self._model_name} on {self._device}")
except Exception as e:
logger.error(f"Failed to load Whisper model: {e}")
raise
def get_model(self):
"""Get the shared model instance"""
if not self._initialized:
raise RuntimeError("ModelManager not initialized. Call initialize() first.")
return self._model
def get_device(self):
"""Get the device being used"""
return self._device
def get_model_name(self):
"""Get the model name"""
return self._model_name
def cleanup(self):
"""Cleanup resources (call on shutdown)"""
with self._lock:
if self._model is not None:
del self._model
self._model = None
self._initialized = False
if self._device == "cuda":
torch.cuda.empty_cache()
logger.info("ModelManager cleaned up")
class TranscriptionEngine:
"""Core transcription engine using shared Whisper model"""
def __init__(self, model_manager: Optional[ModelManager] = None):
"""Initialize with optional model manager (for testing)"""
self.model_manager = model_manager or ModelManager()
self.model = None # Will get from manager
self.processor = None
self.online_processor = None
self.device = self.model_manager.get_device()
self.model_name = self.model_manager.get_model_name()
def load_model(self):
"""Load the transcription model"""
# Model is already loaded in ModelManager
# This method is kept for compatibility
pass
def get_model(self):
"""Get the shared model instance from ModelManager"""
return self.model_manager.get_model()
def is_speech(self, audio: np.ndarray, energy_threshold: float = 0.002, zero_crossing_threshold: int = 50) -> bool:
"""
Simple Voice Activity Detection
@@ -176,13 +209,14 @@ class TranscriptionEngine:
}
else:
# Use standard Whisper
if self.model:
model = self.get_model()
if model:
# Pad audio to minimum length if needed
if len(audio) < SAMPLE_RATE:
audio = np.pad(audio, (0, SAMPLE_RATE - len(audio)))
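# np.pad with (0, n) appends n zeros, bringing very short clips up to one
# second (SAMPLE_RATE samples); very short inputs are a common
# hallucination trigger for Whisper.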
# Use more conservative settings to reduce hallucinations
result = self.model.transcribe(
result = model.transcribe(
audio,
language=None if language == "auto" else language,
fp16=self.device == "cuda",
@@ -240,8 +274,9 @@ class TranscriptionEngine:
audio, _ = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE)
# Transcribe with Whisper
if self.model:
result = self.model.transcribe(
model = self.get_model()
if model:
result = model.transcribe(
audio,
language=None if config.language == "auto" else config.language,
task=config.task or "transcribe",
@@ -278,8 +313,9 @@ class TranscriptionEngine:
class TranscriptionServicer(transcription_pb2_grpc.TranscriptionServiceServicer):
"""gRPC service implementation"""
def __init__(self):
self.engine = TranscriptionEngine()
def __init__(self, model_manager: Optional[ModelManager] = None):
self.model_manager = model_manager or ModelManager()
self.engine = TranscriptionEngine(self.model_manager)
self.sessions: Dict[str, TranscriptionSession] = {}
self.start_time = time.time()
@@ -419,7 +455,9 @@ async def serve_grpc(port: int = 50051):
]
)
servicer = TranscriptionServicer()
# Use the shared model manager
model_manager = get_global_model_manager()
servicer = TranscriptionServicer(model_manager)
transcription_pb2_grpc.add_TranscriptionServiceServicer_to_server(servicer, server)
server.add_insecure_port(f'[::]:{port}')
@@ -429,12 +467,27 @@ async def serve_grpc(port: int = 50051):
await server.wait_for_termination()
# Global model manager instance (shared across all handlers)
_global_model_manager = None
def get_global_model_manager() -> ModelManager:
"""Get or create the global model manager"""
global _global_model_manager
if _global_model_manager is None:
_global_model_manager = ModelManager()
model_name = os.environ.get('MODEL_PATH', 'large-v3')
_global_model_manager.initialize(model_name)
return _global_model_manager
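# Usage sketch, matching the handlers below: repeated calls return the
# same object, so the Whisper weights sit in (GPU) memory exactly once
# per process:
#
#     engine = TranscriptionEngine(get_global_model_manager())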
# WebSocket support for compatibility
async def handle_websocket(websocket, path):
"""Handle WebSocket connections for compatibility"""
import websockets
engine = TranscriptionEngine()
# Use the shared model manager instead of creating a new engine per connection
model_manager = get_global_model_manager()
engine = TranscriptionEngine(model_manager)
session_id = str(time.time())
audio_buffer = bytearray()
@@ -453,8 +506,8 @@ async def handle_websocket(websocket, path):
audio_data = base64.b64decode(data['data'])
audio_buffer.extend(audio_data)
# Process when we have enough audio
min_bytes = int(SAMPLE_RATE * 0.5 * 2)
# Process when we have enough audio (3 seconds for better accuracy)
min_bytes = int(SAMPLE_RATE * 3.0 * 2) # 3 seconds of PCM16
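# Assuming SAMPLE_RATE = 16000 Hz: 16000 samples/s * 3 s * 2 bytes per
# PCM16 sample = 96000 bytes per chunk.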
while len(audio_buffer) >= min_bytes:
chunk = bytes(audio_buffer[:min_bytes])
@@ -506,12 +559,24 @@ async def main():
ws_port = int(os.environ.get('WEBSOCKET_PORT', '8765'))
enable_websocket = os.environ.get('ENABLE_WEBSOCKET', 'true').lower() == 'true'
tasks = [serve_grpc(grpc_port)]
# Initialize the global model manager once at startup
logger.info("Initializing shared model manager...")
model_manager = get_global_model_manager()
logger.info(f"Model manager initialized with model: {model_manager._model_name}")
if enable_websocket:
tasks.append(serve_websocket(ws_port))
await asyncio.gather(*tasks)
try:
tasks = [serve_grpc(grpc_port)]
if enable_websocket:
tasks.append(serve_websocket(ws_port))
await asyncio.gather(*tasks)
finally:
# Cleanup on shutdown
logger.info("Shutting down, cleaning up model manager...")
if model_manager:
model_manager.cleanup()
logger.info("Cleanup complete")
if __name__ == "__main__":