mirror of https://github.com/aljazceru/transcription-api.git (synced 2025-12-17 07:14:24 +01:00)
fixing single model
@@ -16,6 +16,7 @@ from dataclasses import dataclass, asdict
 from concurrent import futures
 import threading
 from datetime import datetime
+import atexit
 
 # Add current directory to path for generated protobuf imports
 sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
@@ -69,63 +70,95 @@ class TranscriptionSession:
     transcriptions: List[dict]
 
 
-class TranscriptionEngine:
-    """Core transcription engine using Whisper or SimulStreaming"""
+class ModelManager:
+    """Singleton manager for Whisper model to share across all connections"""
 
-    def __init__(self, model_name: str = "large-v3"):
-        self.model_name = model_name
-        self.model = None
-        self.processor = None
-        self.online_processor = None
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.load_model()
+    _instance = None
+    _lock = threading.Lock()
+    _model = None
+    _device = None
+    _model_name = None
+    _initialized = False
 
-    def load_model(self):
-        """Load the transcription model"""
-        if USE_SIMULSTREAMING:
-            self._load_simulstreaming()
-        else:
-            self._load_whisper()
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super().__new__(cls)
+        return cls._instance
 
-    def _load_simulstreaming(self):
-        """Load SimulStreaming for real-time transcription"""
-        try:
-            import argparse
-            parser = argparse.ArgumentParser()
-
-            # Add SimulStreaming arguments
-            simulwhisper_args(parser)
-
-            args = parser.parse_args([
-                '--model_path', self.model_name,
-                '--lan', 'auto',
-                '--task', 'transcribe',
-                '--backend', 'whisper',
-                '--min-chunk-size', '0.5',
-                '--beams', '1',
-            ])
-
-            # Create processor
-            self.processor, self.online_processor = simul_asr_factory(args)
-            logger.info(f"Loaded SimulStreaming with model: {self.model_name}")
-
-        except Exception as e:
-            logger.error(f"Failed to load SimulStreaming: {e}")
-            logger.info("Falling back to standard Whisper")
-            USE_SIMULSTREAMING = False
-            self._load_whisper()
-
-    def _load_whisper(self):
-        """Load standard Whisper model"""
-        try:
-            # Use the shared volume for model caching
-            download_root = os.environ.get('TORCH_HOME', '/app/models')
-            self.model = whisper.load_model(self.model_name, device=self.device, download_root=download_root)
-            logger.info(f"Loaded Whisper model: {self.model_name} on {self.device} from {download_root}")
+    def initialize(self, model_name: str = "large-v3"):
+        """Initialize the model (only once)"""
+        with self._lock:
+            if not self._initialized:
+                self._model_name = model_name
+                self._device = "cuda" if torch.cuda.is_available() else "cpu"
+                self._load_model()
+                self._initialized = True
+                logger.info(f"ModelManager initialized with {model_name} on {self._device}")
+
+    def _load_model(self):
+        """Load the Whisper model"""
+        try:
+            download_root = os.environ.get('TORCH_HOME', '/app/models')
+            self._model = whisper.load_model(
+                self._model_name,
+                device=self._device,
+                download_root=download_root
+            )
+            logger.info(f"Loaded shared Whisper model: {self._model_name} on {self._device}")
         except Exception as e:
             logger.error(f"Failed to load Whisper model: {e}")
             raise
 
+    def get_model(self):
+        """Get the shared model instance"""
+        if not self._initialized:
+            raise RuntimeError("ModelManager not initialized. Call initialize() first.")
+        return self._model
+
+    def get_device(self):
+        """Get the device being used"""
+        return self._device
+
+    def get_model_name(self):
+        """Get the model name"""
+        return self._model_name
+
+    def cleanup(self):
+        """Cleanup resources (call on shutdown)"""
+        with self._lock:
+            if self._model is not None:
+                del self._model
+                self._model = None
+            self._initialized = False
+            if self._device == "cuda":
+                torch.cuda.empty_cache()
+            logger.info("ModelManager cleaned up")
+
+
+class TranscriptionEngine:
+    """Core transcription engine using shared Whisper model"""
+
+    def __init__(self, model_manager: Optional[ModelManager] = None):
+        """Initialize with optional model manager (for testing)"""
+        self.model_manager = model_manager or ModelManager()
+        self.model = None  # Will get from manager
+        self.processor = None
+        self.online_processor = None
+        self.device = self.model_manager.get_device()
+        self.model_name = self.model_manager._model_name
+
+    def load_model(self):
+        """Load the transcription model"""
+        # Model is already loaded in ModelManager
+        # This method is kept for compatibility
+        pass
+
+    def get_model(self):
+        """Get the shared model instance from ModelManager"""
+        return self.model_manager.get_model()
+
     def is_speech(self, audio: np.ndarray, energy_threshold: float = 0.002, zero_crossing_threshold: int = 50) -> bool:
         """
         Simple Voice Activity Detection
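The new ModelManager relies on double-checked locking so that concurrent callers construct the singleton and load the model only once. A minimal, self-contained sketch of that pattern follows; the class name and the stubbed-out load are illustrative only, not part of the repository:

import threading

class SharedResource:
    """Illustrative double-checked-locking singleton (model load stubbed out)."""
    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        if cls._instance is None:            # fast path: no lock once created
            with cls._lock:                  # slow path: serialize first creation
                if cls._instance is None:    # re-check under the lock
                    cls._instance = super().__new__(cls)
        return cls._instance

    def initialize(self):
        with self._lock:
            if not getattr(self, "_initialized", False):
                self._model = object()       # stand-in for whisper.load_model(...)
                self._initialized = True

a = SharedResource()
b = SharedResource()
a.initialize()
b.initialize()
assert a is b and a._model is b._model       # one instance, one load

As in the commit's initialize(), only the first call does any work; later calls return immediately because the initialized flag is already set.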
@@ -176,13 +209,14 @@ class TranscriptionEngine:
             }
         else:
             # Use standard Whisper
-            if self.model:
+            model = self.get_model()
+            if model:
                 # Pad audio to minimum length if needed
                 if len(audio) < SAMPLE_RATE:
                     audio = np.pad(audio, (0, SAMPLE_RATE - len(audio)))
 
                 # Use more conservative settings to reduce hallucinations
-                result = self.model.transcribe(
+                result = model.transcribe(
                     audio,
                     language=None if language == "auto" else language,
                     fp16=self.device == "cuda",
@@ -240,8 +274,9 @@ class TranscriptionEngine:
         audio, _ = librosa.load(io.BytesIO(audio_data), sr=SAMPLE_RATE)
 
         # Transcribe with Whisper
-        if self.model:
-            result = self.model.transcribe(
+        model = self.get_model()
+        if model:
+            result = model.transcribe(
                 audio,
                 language=None if config.language == "auto" else config.language,
                 task=config.task or "transcribe",
@@ -278,8 +313,9 @@ class TranscriptionEngine:
 class TranscriptionServicer(transcription_pb2_grpc.TranscriptionServiceServicer):
     """gRPC service implementation"""
 
-    def __init__(self):
-        self.engine = TranscriptionEngine()
+    def __init__(self, model_manager: Optional[ModelManager] = None):
+        self.model_manager = model_manager or ModelManager()
+        self.engine = TranscriptionEngine(self.model_manager)
         self.sessions: Dict[str, TranscriptionSession] = {}
         self.start_time = time.time()
 
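Because TranscriptionServicer and TranscriptionEngine now accept an optional manager, tests can inject a lightweight stand-in instead of loading a real Whisper model. A hypothetical sketch, where FakeModelManager and its canned transcription result are invented for illustration:

class FakeModelManager:
    """Hypothetical stand-in that mimics the ModelManager interface for tests."""
    _model_name = "fake"

    def get_model(self):
        class _FakeModel:
            def transcribe(self, audio, **kwargs):
                # Canned result instead of running inference
                return {"text": "hello world", "segments": []}
        return _FakeModel()

    def get_device(self):
        return "cpu"

    def get_model_name(self):
        return self._model_name

# engine = TranscriptionEngine(FakeModelManager())  # no GPU or model download needed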
@@ -419,7 +455,9 @@ async def serve_grpc(port: int = 50051):
         ]
     )
 
-    servicer = TranscriptionServicer()
+    # Use the shared model manager
+    model_manager = get_global_model_manager()
+    servicer = TranscriptionServicer(model_manager)
     transcription_pb2_grpc.add_TranscriptionServiceServicer_to_server(servicer, server)
 
     server.add_insecure_port(f'[::]:{port}')
@@ -429,12 +467,27 @@ async def serve_grpc(port: int = 50051):
     await server.wait_for_termination()
 
 
+# Global model manager instance (shared across all handlers)
+_global_model_manager = None
+
+def get_global_model_manager() -> ModelManager:
+    """Get or create the global model manager"""
+    global _global_model_manager
+    if _global_model_manager is None:
+        _global_model_manager = ModelManager()
+        model_name = os.environ.get('MODEL_PATH', 'large-v3')
+        _global_model_manager.initialize(model_name)
+    return _global_model_manager
+
+
 # WebSocket support for compatibility
 async def handle_websocket(websocket, path):
     """Handle WebSocket connections for compatibility"""
     import websockets
 
-    engine = TranscriptionEngine()
+    # Use the shared model manager instead of creating new engine
+    model_manager = get_global_model_manager()
+    engine = TranscriptionEngine(model_manager)
     session_id = str(time.time())
     audio_buffer = bytearray()
 
@@ -453,8 +506,8 @@ async def handle_websocket(websocket, path):
             audio_data = base64.b64decode(data['data'])
             audio_buffer.extend(audio_data)
 
-            # Process when we have enough audio
-            min_bytes = int(SAMPLE_RATE * 0.5 * 2)
+            # Process when we have enough audio (3 seconds for better accuracy)
+            min_bytes = int(SAMPLE_RATE * 3.0 * 2)  # 3 seconds of PCM16
 
             while len(audio_buffer) >= min_bytes:
                 chunk = bytes(audio_buffer[:min_bytes])
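For scale: assuming the module's SAMPLE_RATE constant is the 16 kHz that Whisper expects and the stream is mono PCM16 (2 bytes per sample), the buffering threshold grows from 16,000 bytes to 96,000 bytes per chunk:

SAMPLE_RATE = 16000  # assumed value of the repo's constant; Whisper works on 16 kHz mono

old_min_bytes = int(SAMPLE_RATE * 0.5 * 2)  # 0.5 s of PCM16 -> 16,000 bytes
new_min_bytes = int(SAMPLE_RATE * 3.0 * 2)  # 3.0 s of PCM16 -> 96,000 bytes
print(old_min_bytes, new_min_bytes)         # 16000 96000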
@@ -506,12 +559,24 @@ async def main():
     ws_port = int(os.environ.get('WEBSOCKET_PORT', '8765'))
     enable_websocket = os.environ.get('ENABLE_WEBSOCKET', 'true').lower() == 'true'
 
-    tasks = [serve_grpc(grpc_port)]
-
-    if enable_websocket:
-        tasks.append(serve_websocket(ws_port))
-
-    await asyncio.gather(*tasks)
+    # Initialize the global model manager once at startup
+    logger.info("Initializing shared model manager...")
+    model_manager = get_global_model_manager()
+    logger.info(f"Model manager initialized with model: {model_manager._model_name}")
+
+    try:
+        tasks = [serve_grpc(grpc_port)]
+
+        if enable_websocket:
+            tasks.append(serve_websocket(ws_port))
+
+        await asyncio.gather(*tasks)
+    finally:
+        # Cleanup on shutdown
+        logger.info("Shutting down, cleaning up model manager...")
+        if model_manager:
+            model_manager.cleanup()
+        logger.info("Cleanup complete")
 
 
 if __name__ == "__main__":
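The first hunk adds import atexit, but the registration call itself is not visible in the hunks shown here. One plausible wiring, assuming a module-level hook placed next to get_global_model_manager(), would be:

import atexit

def _cleanup_on_exit():
    # Release the shared Whisper model and free CUDA memory at process exit.
    if _global_model_manager is not None:
        _global_model_manager.cleanup()

atexit.register(_cleanup_on_exit)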