stt_whisper.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. from fastapi import FastAPI, WebSocket, WebSocketDisconnect
  2. import whisper
  3. import webrtcvad
  4. import numpy as np
  5. from pydub import AudioSegment
  6. import scipy.io.wavfile as wavfile
  7. import io
  8. import asyncio
  9. app = FastAPI()
  10. # Whisper 모델 로드 (tiny 모델로 실시간성 유지)
  11. model = whisper.load_model("tiny")
  12. # VAD 설정
  13. vad = webrtcvad.Vad()
  14. vad.set_mode(1) # 0~3 (3이 가장 엄격), 1은 중간 수준의 감도
  15. # 클라이언트 관리
  16. clients = {}
  17. # 발화 감지 설정
  18. SAMPLE_RATE = 16000 # Whisper와 VAD가 요구하는 샘플레이트
  19. FRAME_DURATION_MS = 30 # VAD 프레임 길이 (10, 20, 30ms 중 선택)
  20. SILENCE_DURATION = 1.0 # 침묵 지속 시간 (초)
  21. def is_speech_vad(audio_chunk):
  22. """webrtcvad를 사용한 발화 감지"""
  23. # 16-bit PCM으로 변환
  24. audio = np.frombuffer(audio_chunk, dtype=np.int16)
  25. frame_size = (SAMPLE_RATE * FRAME_DURATION_MS / 1000) * 2 # 바이트 단위
  26. if len(audio) < frame_size:
  27. return False
  28. return vad.is_speech(audio[:frame_size], SAMPLE_RATE)
  29. async def process_audio_in_memory(audio_buffer):
  30. """메모리에서 오디오 처리 및 텍스트 변환"""
  31. audio_segment = AudioSegment.from_file(io.BytesIO(audio_buffer), format="webm")
  32. # WAV 변환 (Mono, 16-bit PCM, 16kHz 샘플링)
  33. audio_segment = audio_segment.set_channels(1).set_frame_rate(SAMPLE_RATE).set_sample_width(2)
  34. wav_buffer = io.BytesIO()
  35. audio_segment.export(wav_buffer, format="wav")
  36. wav_buffer.seek(0)
  37. # Whisper로 음성 인식
  38. result = model.transcribe(wav_buffer, fp16=False)
  39. return result["text"]
  40. @app.websocket("/audio-stream")
  41. async def websocket_endpoint(websocket: WebSocket):
  42. await websocket.accept()
  43. client_id = str(id(websocket))
  44. clients[client_id] = websocket
  45. print(f"Client {client_id} connected")
  46. audio_buffer = bytearray()
  47. last_speech_time = 0
  48. silence_start = None
  49. try:
  50. while True:
  51. # 오디오 청크 수신
  52. audio_chunk = await websocket.receive_bytes()
  53. # 오디오 데이터를 새로운 버퍼에 저장 (기존 데이터 누적 방지)
  54. audio_buffer = bytearray(audio_chunk) # 🔥 새로운 데이터로 덮어쓰기
  55. # 수신 크기 확인
  56. print(f"Received data size: {len(audio_chunk)} bytes")
  57. # 오디오 바이너리 데이터 => 숫자배열(numpy)로 해석
  58. audio_np = np.frombuffer(audio_buffer, dtype=np.int16).copy()
  59. # WAV 파일로 저장 (덮어쓰기)
  60. output_file = "recorded_audio.wav"
  61. wavfile.write(output_file, 16000, audio_np)
  62. # STT 처리
  63. stt_result = model.transcribe(output_file, language="ko")
  64. transcription = stt_result["text"]
  65. # 클라이언트에 데이터 전송
  66. await websocket.send_text(transcription)
  67. # VAD로 발화 감지
  68. # if is_speech_vad(audio_chunk):
  69. # last_speech_time = asyncio.get_event_loop().time()
  70. # silence_start = None
  71. # await websocket.send_text("Speech detected...")
  72. # else:
  73. # if silence_start is None:
  74. # silence_start = asyncio.get_event_loop().time()
  75. # elif (asyncio.get_event_loop().time() - silence_start) > SILENCE_DURATION and last_speech_time > 0:
  76. # # 침묵이 지속되면 음성 인식 수행
  77. # transcription = await process_audio_in_memory(bytes(audio_buffer))
  78. # await websocket.send_text(transcription)
  79. # audio_buffer = bytearray() # 버퍼 초기화
  80. # silence_start = None
  81. # last_speech_time = 0
  82. except WebSocketDisconnect:
  83. print(f"Client {client_id} disconnected")
  84. del clients[client_id]
  85. except Exception as e:
  86. print(f"Error: {e}")
  87. await websocket.send_text(f"Error: {str(e)}")
  88. if __name__ == "__main__":
  89. import uvicorn
  90. uvicorn.run(app, host="0.0.0.0", port=8000)