microsoft/TypeAgent
Public · mirrored from https://github.com/microsoft/TypeAgent · Available
python/tts/speechT5/speechT5.py
89 lines · mode: code
| 1 | # Copyright (c) Microsoft Corporation. |
| 2 | # Licensed under the MIT License. |
| 3 | |
| 4 | from io import BytesIO |
| 5 | import wave; |
| 6 | import uvicorn |
| 7 | import logging |
| 8 | import torch |
| 9 | from fastapi import FastAPI |
| 10 | from fastapi.middleware.cors import CORSMiddleware |
| 11 | from fastapi.responses import JSONResponse, StreamingResponse |
| 12 | from transformers import pipeline |
| 13 | from datasets import load_dataset |
| 14 | import struct |
| 15 | from pydantic import BaseModel |
| 16 | import time |
| 17 | |
| 18 | import os |
# NOTE(review): presumably a workaround for the Intel OpenMP runtime aborting
# when a second copy of its library is loaded (e.g. torch plus another native
# library); confirm. torch is imported earlier in this file, so also confirm
# this variable is set early enough to take effect.
os.environ.update({"KMP_DUPLICATE_LIB_OK": "True"})
| 20 | |
| 21 | # Configure logging |
| 22 | logging.basicConfig(level=logging.ERROR, format='%(asctime)s - %(levelname)s - %(message)s') |
| 23 | |
# FastAPI application; uvicorn serves it when this module runs as a script.
app = FastAPI()

# Cross-origin policy: every origin, method and header is accepted.
# NOTE(review): wildcard origins combined with allow_credentials=True means any
# web page can make credentialed requests to this server (Starlette reflects
# the request's Origin in that case) — confirm this permissive policy is intended.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
| 35 | |
# Load the SpeechT5 text-to-speech pipeline and the speaker x-vectors that
# /synthesize passes to it as `speaker_embeddings`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
print("Loading model...")
synthesizer = pipeline("text-to-speech", "microsoft/speecht5_tts", device=device)
embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")


def _build_voice_tables(dataset):
    """Build the voice catalogue and the matching speaker embeddings.

    Walks ``dataset`` once. Row ``i`` contributes ``[row["filename"], str(i)]``
    to the catalogue and ``torch.tensor(row["xvector"]).unsqueeze(0)`` (the
    x-vector with a leading size-1 dimension added) to the embeddings, so each
    voice sits at the same position in both lists.

    Returns:
        A ``(filenames, speaker_embeddings)`` tuple of two equal-length lists.
    """
    names = []
    embeddings = []
    for index, row in enumerate(dataset):
        names.append([row["filename"], str(index)])
        embeddings.append(torch.tensor(row["xvector"]).unsqueeze(0))
    return names, embeddings


# /voices serves `filenames`; /synthesize selects from `speaker_embeddings`
# using the same integer index.
filenames, speaker_embeddings = _build_voice_tables(embeddings_dataset)

print("Model loaded!")
| 48 | |
| 49 | |
@app.get("/voices")
async def voices():
    """List the available voices.

    Responds 200 with a JSON array of ``[filename, index]`` pairs. ``index`` is
    the decimal-string position of the voice in ``speaker_embeddings``, i.e. the
    value /synthesize accepts as ``voiceName``.
    """
    return JSONResponse(status_code=200, content=filenames)
| 53 | |
# JSON request body accepted by POST /synthesize.
class SynthesizeRequest(BaseModel):
    # Text to turn into speech.
    text: str
    # Optional voice id: a decimal-string index into the voices listed by
    # /voices. When omitted (None), the server's default voice is used.
    voiceName: str | None = None
| 57 | |
# Voice used when a request omits voiceName (an index into speaker_embeddings).
_DEFAULT_VOICE_INDEX = 7306

# Full-scale magnitude of a signed 24-bit PCM sample.
_PCM24_MAX = 2**23 - 1


def _parse_voice_index(voice_name: str | None) -> int | None:
    """Translate a request's ``voiceName`` into an index into ``speaker_embeddings``.

    Returns ``_DEFAULT_VOICE_INDEX`` when ``voice_name`` is None and the parsed
    integer when it lies in ``range(len(speaker_embeddings))``. Returns None
    when the value is not a usable index (non-numeric, negative, or too large).
    """
    if voice_name is None:
        return _DEFAULT_VOICE_INDEX
    try:
        index = int(voice_name)
    except ValueError:
        return None
    # Reject negatives explicitly: plain list indexing would silently wrap them
    # around to voices counted from the end of the list.
    if not 0 <= index < len(speaker_embeddings):
        return None
    return index


def _encode_wav_pcm24(samples, sample_rate: int) -> BytesIO:
    """Encode mono float samples as an in-memory 24-bit PCM WAV file.

    Each sample is clamped to [-1.0, 1.0], scaled by ``_PCM24_MAX`` and
    truncated toward zero. The returned buffer is rewound to offset 0.
    """
    pack_int32 = struct.Struct('<i').pack
    # The low three bytes of a little-endian int32 are the little-endian 24-bit
    # encoding of the same value. Clamping first keeps out-of-range samples from
    # overflowing 24 bits and wrapping around to the opposite polarity.
    frames = b"".join(
        pack_int32(int(max(-1.0, min(1.0, float(sample))) * _PCM24_MAX))[:3]
        for sample in samples
    )
    buffer = BytesIO()
    # Wave_write does not close a caller-supplied file object on close().
    with wave.open(buffer, 'wb') as wav:
        wav.setnchannels(1)
        wav.setsampwidth(3)
        wav.setframerate(sample_rate)
        # A single writeframes call: calling it once per sample makes the wave
        # module seek back and re-patch the header after every frame.
        wav.writeframes(frames)
    buffer.seek(0)
    return buffer


@app.post("/synthesize")
async def synthesize(data: SynthesizeRequest):
    """Synthesize ``data.text`` and return it as a mono 24-bit PCM WAV stream.

    ``data.voiceName`` is a decimal-string index into the voices served by
    /voices; when omitted, voice ``_DEFAULT_VOICE_INDEX`` is used.

    Responses:
        200 with an ``audio/wav`` body on success.
        400 with a JSON error if ``voiceName`` is not a valid voice index.
        500 with a generic JSON error if synthesis or encoding fails; the
        exception is logged.
    """
    voice_index = _parse_voice_index(data.voiceName)
    if voice_index is None:
        return JSONResponse(content={"error": "Invalid voiceName"}, status_code=400)
    try:
        text = data.text
        print("Synthesizing with voice", voice_index, ":", text)
        start = time.time()
        # NOTE(review): this call is synchronous inside an async handler, so it
        # blocks the event loop (and all other requests) while it runs.
        speech = synthesizer(
            text,
            forward_params={"speaker_embeddings": speaker_embeddings[voice_index]},
        )
        end = time.time()
        print("Synthesized in", end - start, "seconds")
        byte_io = _encode_wav_pcm24(speech["audio"], speech["sampling_rate"])
        print("Written in", time.time() - end, "seconds")
        return StreamingResponse(byte_io, media_type="audio/wav")
    except Exception:
        logging.exception("An error occurred during synthesize")
        return JSONResponse(content={"error": "An internal error has occurred!"}, status_code=500)
| 86 | |
| 87 | |
if __name__ == "__main__":
    # Run the server directly: listen on every network interface, port 8002.
    uvicorn.run(app, port=8002, host="0.0.0.0")