mirror of https://github.com/aljazceru/transcription-api.git

fixing single model
@@ -16,6 +16,7 @@ from dataclasses import dataclass, asdict
from concurrent import futures
import threading
from datetime import datetime
import atexit

# Add current directory to path for generated protobuf imports
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -69,63 +70,95 @@ class TranscriptionSession:
    transcriptions: List[dict]


class TranscriptionEngine:
    """Core transcription engine using Whisper or SimulStreaming"""
class ModelManager:
    """Singleton manager for Whisper model to share across all connections"""

    def __init__(self, model_name: str = "large-v3"):
        self.model_name = model_name
        self.model = None
        self.processor = None
        self.online_processor = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.load_model()
    _instance = None
    _lock = threading.Lock()
    _model = None
    _device = None
    _model_name = None
    _initialized = False

    def load_model(self):
        """Load the transcription model"""
        if USE_SIMULSTREAMING:
            self._load_simulstreaming()
        else:
            self._load_whisper()
    def __new__(cls):
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def _load_simulstreaming(self):
        """Load SimulStreaming for real-time transcription"""
    def initialize(self, model_name: str = "large-v3"):
        """Initialize the model (only once)"""
        with self._lock:
            if not self._initialized:
                self._model_name = model_name
                self._device = "cuda" if torch.cuda.is_available() else "cpu"
                self._load_model()
                self._initialized = True
                logger.info(f"ModelManager initialized with {model_name} on {self._device}")

    def _load_model(self):
        """Load the Whisper model"""
        try:
            import argparse
            parser = argparse.ArgumentParser()

            # Add SimulStreaming arguments
            simulwhisper_args(parser)

            args = parser.parse_args([
                '--model_path', self.model_name,
                '--lan', 'auto',
                '--task', 'transcribe',
                '--backend', 'whisper',
                '--min-chunk-size', '0.5',
                '--beams', '1',
            ])

            # Create processor
            self.processor, self.online_processor = simul_asr_factory(args)
            logger.info(f"Loaded SimulStreaming with model: {self.model_name}")

        except Exception as e:
            logger.error(f"Failed to load SimulStreaming: {e}")
            logger.info("Falling back to standard Whisper")
            USE_SIMULSTREAMING = False
            self._load_whisper()

    def _load_whisper(self):
        """Load standard Whisper model"""
        try:
            # Use the shared volume for model caching
            download_root = os.environ.get('TORCH_HOME', '/app/models')
            self.model = whisper.load_model(self.model_name, device=self.device, download_root=download_root)
            logger.info(f"Loaded Whisper model: {self.model_name} on {self.device} from {download_root}")
            self._model = whisper.load_model(
                self._model_name,
                device=self._device,
                download_root=download_root
            )
            logger.info(f"Loaded shared Whisper model: {self._model_name} on {self._device}")
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {e}")
            raise

    def get_model(self):
        """Get the shared model instance"""
        if not self._initialized:
            raise RuntimeError("ModelManager not initialized. Call initialize() first.")
        return self._model

    def get_device(self):
        """Get the device being used"""
        return self._device

    def get_model_name(self):
        """Get the model name"""
        return self._model_name

    def cleanup(self):
        """Cleanup resources (call on shutdown)"""
        with self._lock:
            if self._model is not None:
                del self._model
                self._model = None
            self._initialized = False
            if self._device == "cuda":
                torch.cuda.empty_cache()
            logger.info("ModelManager cleaned up")


class TranscriptionEngine:
    """Core transcription engine using shared Whisper model"""

    def __init__(self, model_manager: Optional[ModelManager] = None):
        """Initialize with optional model manager (for testing)"""
        self.model_manager = model_manager or ModelManager()
        self.model = None  # Will get from manager
        self.processor = None
        self.online_processor = None
        self.device = self.model_manager.get_device()
        self.model_name = self.model_manager._model_name

    def load_model(self):
        """Load the transcription model"""
        # Model is already loaded in ModelManager
        # This method is kept for compatibility
        pass

    def get_model(self):
        """Get the shared model instance from ModelManager"""
        return self.model_manager.get_model()

    def is_speech(self, audio: np.ndarray, energy_threshold: float = 0.002, zero_crossing_threshold: int = 50) -> bool:
        """
        Simple Voice Activity Detection
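The ModelManager above combines a double-checked __new__ with a lock-guarded initialize(), so constructing it repeatedly is cheap and the Whisper weights are loaded only once. A minimal sketch of that behaviour, assuming the class exactly as defined in this diff (the thread count and the final assert are illustrative only, not part of the commit):

# Sketch only (not in the commit): concurrent callers all receive the same instance.
import threading

managers = []

def worker():
    m = ModelManager()        # __new__ hands back the single shared instance
    m.initialize("large-v3")  # only the first caller actually loads the model
    managers.append(m)

threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert all(m is managers[0] for m in managers)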
@@ -176,13 +209,14 @@ class TranscriptionEngine:
                }
            else:
                # Use standard Whisper
                if self.model:
                model = self.get_model()
                if model:
                    # Pad audio to minimum length if needed
                    if len(audio) < SAMPLE_RATE:
                        audio = np.pad(audio, (0, SAMPLE_RATE - len(audio)))

                    # Use more conservative settings to reduce hallucinations
                    result = self.model.transcribe(
                    result = model.transcribe(
                        audio,
                        language=None if language == "auto" else language,
                        fp16=self.device == "cuda",
@@ -240,8 +274,9 @@ class TranscriptionEngine:
            audio, _ = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE)

            # Transcribe with Whisper
            if self.model:
                result = self.model.transcribe(
            model = self.get_model()
            if model:
                result = model.transcribe(
                    audio,
                    language=None if config.language == "auto" else config.language,
                    task=config.task or "transcribe",
@@ -278,8 +313,9 @@ class TranscriptionEngine:
class TranscriptionServicer(transcription_pb2_grpc.TranscriptionServiceServicer):
    """gRPC service implementation"""

    def __init__(self):
        self.engine = TranscriptionEngine()
    def __init__(self, model_manager: Optional[ModelManager] = None):
        self.model_manager = model_manager or ModelManager()
        self.engine = TranscriptionEngine(self.model_manager)
        self.sessions: Dict[str, TranscriptionSession] = {}
        self.start_time = time.time()

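The servicer's new constructor accepts an optional ModelManager, which the engine's docstring flags as a testing hook. A hedged sketch of how that injection might be used; the "tiny" model name and the helper below are assumptions for illustration, not part of the commit:

# Illustration only: inject a pre-initialized manager instead of the global large-v3 one.
def make_test_servicer():
    manager = ModelManager()
    manager.initialize("tiny")             # small model keeps tests fast; name is an assumption
    return TranscriptionServicer(manager)  # the engine inside will reuse this shared model

servicer = make_test_servicer()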
@@ -419,7 +455,9 @@ async def serve_grpc(port: int = 50051):
        ]
    )

    servicer = TranscriptionServicer()
    # Use the shared model manager
    model_manager = get_global_model_manager()
    servicer = TranscriptionServicer(model_manager)
    transcription_pb2_grpc.add_TranscriptionServiceServicer_to_server(servicer, server)

    server.add_insecure_port(f'[::]:{port}')
@@ -429,12 +467,27 @@ async def serve_grpc(port: int = 50051):
    await server.wait_for_termination()


# Global model manager instance (shared across all handlers)
_global_model_manager = None

def get_global_model_manager() -> ModelManager:
    """Get or create the global model manager"""
    global _global_model_manager
    if _global_model_manager is None:
        _global_model_manager = ModelManager()
        model_name = os.environ.get('MODEL_PATH', 'large-v3')
        _global_model_manager.initialize(model_name)
    return _global_model_manager


# WebSocket support for compatibility
async def handle_websocket(websocket, path):
    """Handle WebSocket connections for compatibility"""
    import websockets

    engine = TranscriptionEngine()
    # Use the shared model manager instead of creating new engine
    model_manager = get_global_model_manager()
    engine = TranscriptionEngine(model_manager)
    session_id = str(time.time())
    audio_buffer = bytearray()

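This is the core of the fix: each WebSocket connection still gets its own lightweight TranscriptionEngine, but every engine resolves to the one model held by the global manager, so memory use no longer grows with the number of connections. A small sketch of that invariant, assuming the code in this diff (the asserts are illustrative only):

# Sketch only: two per-connection engines share one underlying Whisper model.
e1 = TranscriptionEngine(get_global_model_manager())
e2 = TranscriptionEngine(get_global_model_manager())
assert e1.model_manager is e2.model_manager
assert e1.get_model() is e2.get_model()   # a single set of weights in GPU/CPU memory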
@@ -453,8 +506,8 @@ async def handle_websocket(websocket, path):
            audio_data = base64.b64decode(data['data'])
            audio_buffer.extend(audio_data)

            # Process when we have enough audio
            min_bytes = int(SAMPLE_RATE * 0.5 * 2)
            # Process when we have enough audio (3 seconds for better accuracy)
            min_bytes = int(SAMPLE_RATE * 3.0 * 2)  # 3 seconds of PCM16

            while len(audio_buffer) >= min_bytes:
                chunk = bytes(audio_buffer[:min_bytes])
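The buffering threshold above is samples per second × seconds × bytes per sample. Assuming SAMPLE_RATE is 16000 (the rate Whisper expects), the old and new thresholds work out as follows:

# Worked numbers for min_bytes, assuming 16 kHz mono PCM16 (2 bytes per sample).
SAMPLE_RATE = 16000
old_min_bytes = int(SAMPLE_RATE * 0.5 * 2)  # 16000 bytes -> 0.5 s per chunk
new_min_bytes = int(SAMPLE_RATE * 3.0 * 2)  # 96000 bytes -> 3.0 s per chunk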
@@ -506,12 +559,24 @@ async def main():
    ws_port = int(os.environ.get('WEBSOCKET_PORT', '8765'))
    enable_websocket = os.environ.get('ENABLE_WEBSOCKET', 'true').lower() == 'true'

    tasks = [serve_grpc(grpc_port)]
    # Initialize the global model manager once at startup
    logger.info("Initializing shared model manager...")
    model_manager = get_global_model_manager()
    logger.info(f"Model manager initialized with model: {model_manager._model_name}")

    if enable_websocket:
        tasks.append(serve_websocket(ws_port))

    await asyncio.gather(*tasks)
    try:
        tasks = [serve_grpc(grpc_port)]

        if enable_websocket:
            tasks.append(serve_websocket(ws_port))

        await asyncio.gather(*tasks)
    finally:
        # Cleanup on shutdown
        logger.info("Shutting down, cleaning up model manager...")
        if model_manager:
            model_manager.cleanup()
        logger.info("Cleanup complete")


if __name__ == "__main__":
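import atexit is added in the first hunk, but its use sits outside this excerpt. One plausible wiring, shown purely as an assumption and not taken from the commit, would register the global manager's cleanup so the CUDA cache is released even if main()'s finally block never runs:

# Assumption for illustration only; this registration does not appear in the diff.
import atexit

def _cleanup_on_exit():
    if _global_model_manager is not None:
        _global_model_manager.cleanup()

atexit.register(_cleanup_on_exit)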