import os
import re
import json
import torch
import librosa
import soundfile
import torchaudio
import numpy as np
import torch.nn as nn
from tqdm import tqdm

from . import utils
from . import commons
from .models import SynthesizerTrn
from .split_utils import split_sentence
from .mel_processing import spectrogram_torch, spectrogram_torch_conv
from .download_utils import load_or_download_config, load_or_download_model

class TTS(nn.Module):
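    """Inference wrapper around SynthesizerTrn for a single language.

    Loads the config and pretrained weights (locally, or remotely when
    `use_hf` is set) and exposes synthesis through `tts_to_file`.
    """
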
    def __init__(self,
                 language,
                 device='auto',
                 use_hf=True,
                 config_path=None,
                 ckpt_path=None):
        super().__init__()
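        # Resolve 'auto' to the best available backend; note that MPS, when
        # available, takes precedence over CUDA in the checks below.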
        if device == 'auto':
            device = 'cpu'
            if torch.cuda.is_available():
                device = 'cuda'
            if torch.backends.mps.is_available():
                device = 'mps'
        if 'cuda' in device:
            assert torch.cuda.is_available()
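
        # Hyperparameters (symbol table, tone and language inventories, audio
        # settings) come from the model config, fetched remotely when use_hf
        # is set and no local config_path is given.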
        hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)

        num_languages = hps.num_languages
        num_tones = hps.num_tones
        symbols = hps.symbols

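        # Input dim is the number of spectrogram bins (filter_length // 2 + 1);
        # the segment size is expressed in frames (samples // hop_length).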
        model = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            num_tones=num_tones,
            num_languages=num_languages,
            **hps.model,
        ).to(device)

        model.eval()
        self.model = model
        self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
        self.hps = hps
        self.device = device
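
        # Load the pretrained weights; strict=True makes any architecture or
        # config mismatch fail loudly here rather than at inference time.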
        checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
        self.model.load_state_dict(checkpoint_dict['model'], strict=True)
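
        # 'ZH' checkpoints are trained on mixed Chinese/English text, so map
        # the tag to the ZH_MIX_EN frontend.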
| |
|
| | language = language.split('_')[0]
|
| | self.language = 'ZH_MIX_EN' if language == 'ZH' else language
|
| |
|

    @staticmethod
    def audio_numpy_concat(segment_data_list, sr, speed=1.):
        # Join the per-sentence waveforms, inserting 50 ms of silence
        # (scaled by speed) between segments.
        audio_segments = []
        for segment_data in segment_data_list:
            audio_segments += segment_data.reshape(-1).tolist()
            audio_segments += [0] * int((sr * 0.05) / speed)
        audio_segments = np.array(audio_segments).astype(np.float32)
        return audio_segments

    @staticmethod
    def split_sentences_into_pieces(text, language, quiet=False):
        texts = split_sentence(text, language_str=language)
        if not quiet:
            print(" > Text split to sentences.")
            print('\n'.join(texts))
            print(" > ===========================")
        return texts

    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6,
                    noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False):
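        """Synthesize `text` with the given speaker id. Writes the waveform to
        `output_path` when given, otherwise returns it as a float32 numpy array."""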
        language = self.language
        texts = self.split_sentences_into_pieces(text, language, quiet)
        audio_list = []
        # Choose an iterator: a caller-supplied progress wrapper, tqdm
        # (optionally pinned to a terminal row via `position`), or the bare
        # list when quiet.
        if pbar:
            tx = pbar(texts)
        elif position:
            tx = tqdm(texts, position=position)
        elif quiet:
            tx = texts
        else:
            tx = tqdm(texts)
        for t in tx:
            if language in ['EN', 'ZH_MIX_EN']:
                # Add a space at lowercase->uppercase boundaries so camelCase
                # runs are read as separate words.
                t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
            device = self.device
            bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
                t, language, self.hps, device, self.symbol_to_id)
            with torch.no_grad():
                x_tst = phones.to(device).unsqueeze(0)
                tones = tones.to(device).unsqueeze(0)
                lang_ids = lang_ids.to(device).unsqueeze(0)
                bert = bert.to(device).unsqueeze(0)
                ja_bert = ja_bert.to(device).unsqueeze(0)
                x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
                del phones
                speakers = torch.LongTensor([speaker_id]).to(device)
                audio = self.model.infer(
                    x_tst,
                    x_tst_lengths,
                    speakers,
                    tones,
                    lang_ids,
                    bert,
                    ja_bert,
                    sdp_ratio=sdp_ratio,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=1. / speed,
                )[0][0, 0].data.cpu().float().numpy()
                del x_tst, tones, lang_ids, bert, ja_bert, x_tst_lengths, speakers

            audio_list.append(audio)
        torch.cuda.empty_cache()
        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)

        if output_path is None:
            return audio
        else:
            if format:
                soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format)
            else:
                soundfile.write(output_path, audio, self.hps.data.sampling_rate)
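

if __name__ == '__main__':
    # Minimal usage sketch; assumes an 'EN' checkpoint is available and that
    # speaker id 0 exists in its speaker table (check hps.data.spk2id for the
    # actual name-to-id mapping). Because of the relative imports above, run
    # this as a module, e.g. `python -m <package>.api`.
    tts = TTS(language='EN', device='auto')
    tts.tts_to_file('Hello from MeloTTS.', speaker_id=0, output_path='hello.wav')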