| | import os |
| | import datetime |
| | import json |
| | import logging |
| | import librosa |
| | import pickle |
| | from typing import Dict |
| | import numpy as np |
| | import torch |
| | import torch.nn as nn |
| | import yaml |
| | from models.audiosep import AudioSep, get_model_class |
| |
|
| |
|
def ignore_warnings():
    """Install ``ignore`` filters for two known-noisy warnings.

    Adds, in this order, to the global :mod:`warnings` filter list:

    * ``UserWarning`` emitted from the ``torch.functional`` module;
    * warnings whose message matches the "Some weights of the model
      checkpoint at roberta-base were not used ... ['lm_head...']" notice.

    NOTE(review): if the library emitting the roberta notice reports it via
    :mod:`logging` rather than :mod:`warnings`, the second filter has no
    effect — confirm.
    """
    import warnings

    warnings.filterwarnings('ignore', category=UserWarning, module='torch.functional')

    roberta_unused_weights = r"Some weights of the model checkpoint at roberta-base were not used when initializing RobertaModel: \['lm_head\..*'\].*"
    warnings.filterwarnings('ignore', message=roberta_unused_weights)
| |
|
| |
|
| |
|
def create_logging(log_dir, filemode):
    """Configure root logging to a fresh numbered file plus the console.

    Creates ``log_dir`` if needed and picks the first name of the form
    ``0000.log``, ``0001.log``, ... that is not an existing file there. That
    path is handed to :func:`logging.basicConfig` at DEBUG level. A console
    ``StreamHandler`` at INFO level is then attached to the root logger.

    Note: per the stdlib docs, ``logging.basicConfig`` does nothing when the
    root logger already has handlers. Each call to this function also adds
    one more console handler.

    Args:
        log_dir: directory holding the numbered log files.
        filemode: file mode passed to ``basicConfig`` (e.g. "w" or "a").

    Returns:
        The :mod:`logging` module itself.
    """
    os.makedirs(log_dir, exist_ok=True)

    index = 0
    log_path = os.path.join(log_dir, "{:04d}.log".format(index))
    while os.path.isfile(log_path):
        index += 1
        log_path = os.path.join(log_dir, "{:04d}.log".format(index))

    logging.basicConfig(
        level=logging.DEBUG,
        format="%(asctime)s %(filename)s[line:%(lineno)d] %(levelname)s %(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        filename=log_path,
        filemode=filemode,
    )

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(
        logging.Formatter("%(name)-12s: %(levelname)-8s %(message)s")
    )
    logging.getLogger("").addHandler(console_handler)

    return logging
| |
|
| |
|
def float32_to_int16(x: np.ndarray) -> np.ndarray:
    """Convert floating-point audio in [-1, 1] to 16-bit integer samples.

    Samples are clipped to [-1, 1] first, so out-of-range values saturate at
    +/-32767 instead of wrapping around in the integer cast. The clipped
    values are scaled by 32767 and cast to ``np.int16``; the cast truncates
    toward zero.

    Args:
        x: array (or scalar) of floating-point samples.

    Returns:
        ``np.int16`` array with the same shape as ``x``.
    """
    clipped = np.clip(x, a_min=-1, a_max=1)
    return (clipped * 32767.0).astype(np.int16)
| |
|
| |
|
def int16_to_float32(x: np.ndarray) -> np.ndarray:
    """Convert 16-bit integer samples to ``np.float32`` audio.

    Divides by 32767, so 32767 maps to 1.0. This is the inverse of the
    scaling used by ``float32_to_int16``, up to quantisation.

    The input goes through ``np.asarray`` so plain Python ints and lists are
    accepted too. Without it, ``int / 32767.0`` is a Python ``float``, which
    has no ``.astype``.

    Args:
        x: integer samples (array, list or scalar).

    Returns:
        ``np.float32`` array (or scalar) with the same shape as ``x``.
    """
    return (np.asarray(x) / 32767.0).astype(np.float32)
| |
|
| |
|
def parse_yaml(config_yaml: str) -> Dict:
    r"""Load a YAML configuration file.

    Args:
        config_yaml (str): path of the YAML file to read.

    Returns:
        yaml_dict (Dict): the document parsed with ``yaml.FullLoader``.
    """
    # NOTE(review): FullLoader resolves more tags than SafeLoader; only use
    # this on trusted config files.
    with open(config_yaml, "r") as config_file:
        parsed = yaml.load(config_file, Loader=yaml.FullLoader)
    return parsed
| |
|
| |
|
def get_audioset632_id_to_lb(ontology_path: str) -> Dict:
    r"""Get AudioSet 632 classes ID to label mapping.

    Args:
        ontology_path (str): path to a JSON file containing a list of
            objects, each with at least ``"id"`` and ``"name"`` keys.

    Returns:
        Dict: ``{id: name}`` for every entry. If an ID occurs more than once,
        the last occurrence wins.
    """
    # JSON text is UTF-8 (RFC 8259): decode it explicitly rather than with
    # the platform's locale encoding, which breaks non-ASCII names elsewhere.
    with open(ontology_path, encoding="utf-8") as f:
        data_list = json.load(f)

    return {entry["id"]: entry["name"] for entry in data_list}
| |
|
| |
|
def load_pretrained_panns(
    model_type: str,
    checkpoint_path: str,
    freeze: bool
) -> nn.Module:
    r"""Load pretrained audio neural networks (PANNs).

    Args:
        model_type: str, "Cnn14" or "Cnn14_DecisionLevelMax"
        checkpoint_path: str, e.g., "Cnn14_mAP=0.431.pth". A falsy value
            (``None`` or ``""``) skips loading, leaving freshly built weights.
        freeze: bool, if True every parameter gets ``requires_grad = False``.

    Returns:
        model: nn.Module

    Raises:
        NotImplementedError: if ``model_type`` is not a supported name.
    """
    # NOTE(review): Cnn14 and Cnn14_DecisionLevelMax do not appear in this
    # module's visible imports; the supported branches will raise NameError
    # unless they are brought into scope elsewhere — confirm and import them.
    if model_type == "Cnn14":
        Model = Cnn14
    elif model_type == "Cnn14_DecisionLevelMax":
        Model = Cnn14_DecisionLevelMax
    else:
        raise NotImplementedError(
            "Unsupported PANNs model_type: {!r}".format(model_type)
        )

    # Fixed front-end / head hyper-parameters (32 kHz input, 64 mel bins,
    # 527 output classes); presumably those the checkpoints were trained
    # with — confirm before changing.
    model = Model(sample_rate=32000, window_size=1024, hop_size=320,
                  mel_bins=64, fmin=50, fmax=14000, classes_num=527)

    if checkpoint_path:
        # torch.load unpickles the file: only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        model.load_state_dict(checkpoint["model"])

    if freeze:
        for param in model.parameters():
            param.requires_grad = False

    return model
| |
|
| |
|
def energy(x):
    """Return the mean of the element-wise squares of tensor ``x``."""
    squared = x ** 2
    return torch.mean(squared)
| |
|
| |
|
def magnitude_to_db(x):
    """Convert linear magnitude to decibels, ``20 * log10(max(x, 1e-10))``.

    Values below ``1e-10`` (including 0 and negatives) are floored to
    ``1e-10``, so the output never drops below -200 dB and a zero input does
    not produce ``-inf``. The floor is applied element-wise with
    ``np.maximum``, so scalars and arrays are both accepted.
    """
    eps = 1e-10
    return 20. * np.log10(np.maximum(x, eps))
| |
|
| |
|
def db_to_magnitude(x):
    """Map decibels back to linear magnitude: ``10 ** (x / 20)``."""
    exponent = x / 20
    return 10. ** exponent
| |
|
| |
|
def ids_to_hots(ids, classes_num, device):
    """Build a multi-hot vector with ones at the given class indices.

    Args:
        ids: iterable of integer class indices to set to 1.
        classes_num: length of the output vector.
        device: torch device (or device string) to create the vector on.

    Returns:
        torch.Tensor of shape ``(classes_num,)`` in the default float dtype,
        1 at every index in ``ids`` and 0 elsewhere.
    """
    # Allocate directly on the target device instead of building on the CPU
    # and copying with ``.to(device)``.
    hots = torch.zeros(classes_num, device=device)
    for class_id in ids:
        hots[class_id] = 1
    return hots
| |
|
| |
|
def calculate_sdr(
    ref: np.ndarray,
    est: np.ndarray,
    eps=1e-10
) -> float:
    r"""Signal-to-distortion ratio of ``est`` against ``ref``, in dB.

    Computes ``10 * log10(mean(ref**2) / mean((est - ref)**2))``. Both mean
    powers are floored at ``eps`` so silent references or perfect estimates
    give finite values.

    Args:
        ref (np.ndarray), reference signal
        est (np.ndarray), estimated signal (broadcast-compatible with ``ref``)
        eps (float), lower bound applied to both powers
    """
    residual = est - ref

    signal_power = np.clip(a=np.mean(ref ** 2), a_min=eps, a_max=None)
    residual_power = np.clip(a=np.mean(residual ** 2), a_min=eps, a_max=None)

    return 10. * np.log10(signal_power / residual_power)
| |
|
| |
|
def calculate_sisdr(ref, est):
    r"""Calculate scale-invariant SDR (SI-SDR) between reference and estimation.

    Both signals are flattened to column vectors. ``est`` is projected onto
    ``ref`` with scale ``a = <ref, est> / <ref, ref>``; ``a * ref`` is the
    target component and ``est - a * ref`` the residual. The result is
    ``10 * log10(|target|^2 / |residual|^2)`` in dB. The machine epsilon of
    ``ref.dtype`` is added to numerator and denominator of both ratios.

    Args:
        ref (np.ndarray), reference signal (floating dtype, since
            ``np.finfo`` is taken of it)
        est (np.ndarray), estimated signal with as many elements as ``ref``
    """
    eps = np.finfo(ref.dtype).eps

    # Column vectors of shape (N, 1).
    reference = ref.copy().reshape(ref.size, 1)
    estimate = est.copy().reshape(est.size, 1)

    ref_energy = np.dot(reference.T, reference)
    scale = (eps + np.dot(reference.T, estimate)) / (ref_energy + eps)

    target = scale * reference
    residual = estimate - target

    target_energy = (target ** 2).sum()
    residual_energy = (residual ** 2).sum()

    return 10 * np.log10((eps + target_energy) / (eps + residual_energy))
| |
|
| |
|
class StatisticsContainer(object):
    """Accumulate per-split statistics dicts and pickle them to disk.

    Every flush writes the whole collection to ``statistics_path`` and to a
    backup next to it named ``<stem>_<YYYY-mm-dd_HH-MM-SS>.pkl``. The
    timestamp is fixed when the container is created.
    """

    def __init__(self, statistics_path):
        self.statistics_path = statistics_path

        self.backup_statistics_path = "{}_{}.pkl".format(
            os.path.splitext(self.statistics_path)[0],
            datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S"),
        )

        # One list of statistics dicts per split.
        self.statistics_dict = {"balanced_train": [], "test": []}

    def append(self, steps, statistics, split, flush=True):
        """Record ``statistics`` under ``split``, tagged with ``steps``.

        ``statistics`` is modified in place (a ``"steps"`` key is added) and
        stored by reference. ``split`` must be "balanced_train" or "test";
        other values raise KeyError. When ``flush`` is true, the files are
        rewritten immediately.
        """
        statistics["steps"] = steps
        self.statistics_dict[split].append(statistics)

        if flush:
            self.flush()

    def flush(self):
        """Pickle all statistics to the main path and the backup path."""
        # ``with`` closes (and flushes) each file right away, instead of
        # leaving that to garbage collection of an anonymous file object.
        for path in (self.statistics_path, self.backup_statistics_path):
            with open(path, "wb") as f:
                pickle.dump(self.statistics_dict, f)
        logging.info(" Dump statistics to {}".format(self.statistics_path))
        logging.info(" Dump statistics to {}".format(self.backup_statistics_path))
| |
|
| |
|
def get_mean_sdr_from_dict(sdris_dict):
    """Average the values of ``sdris_dict``, skipping NaNs (``np.nanmean``)."""
    values = list(sdris_dict.values())
    return np.nanmean(values)
| |
|
| |
|
def remove_silence(audio: np.ndarray, sample_rate: int) -> np.ndarray:
    r"""Remove silent frames.

    ``audio`` is cut into non-overlapping windows of ``int(sample_rate * 0.1)``
    samples with ``librosa.util.frame``. Only windows whose peak absolute
    amplitude exceeds 0.02 are kept (via ``get_active_frames``), and the
    survivors are concatenated back into one flat array.

    NOTE(review): presumably ``audio`` is mono 1-D, given the transpose and
    flatten below — confirm against callers.
    """
    frame_len = int(sample_rate * 0.1)

    # librosa returns (frame_length, n_frames) here; transpose to one row per frame.
    framed = librosa.util.frame(x=audio, frame_length=frame_len, hop_length=frame_len).T

    voiced = get_active_frames(framed, 0.02)
    return voiced.flatten()
| |
|
| |
|
def get_active_frames(frames: np.ndarray, threshold: float) -> np.ndarray:
    r"""Keep the frames whose peak absolute amplitude exceeds ``threshold``.

    Args:
        frames: array whose first axis indexes frames, reduced over its last
            axis, e.g. ``(frames_num, frame_length)``.
        threshold: strict lower bound on a frame's peak ``|sample|``.

    Returns:
        The selected frames, in their original order.
    """
    peak_amplitude = np.max(np.abs(frames), axis=-1)
    active_indexes = np.where(peak_amplitude > threshold)[0]
    return frames[active_indexes]
| |
|
| |
|
def repeat_to_length(audio: np.ndarray, segment_samples: int) -> np.ndarray:
    r"""Repeat ``audio`` along its last axis and cut it to ``segment_samples``.

    Works for 1-D ``(samples,)`` and multi-channel ``(channels, samples)``
    input alike. ``np.tile`` with an integer repeat count tiles the last
    axis, and the result is truncated on that same axis.

    Raises:
        ValueError: if ``audio`` has no samples along its last axis.
    """
    samples_num = audio.shape[-1]
    if samples_num == 0:
        raise ValueError("Cannot repeat audio that has zero samples.")

    repeats_num = (segment_samples // samples_num) + 1
    # Slice the last axis: a plain ``[0:segment_samples]`` would cut the
    # first (channel) axis of a 2-D array instead of its samples.
    return np.tile(audio, repeats_num)[..., 0:segment_samples]
| |
|
def calculate_segmentwise_sdr(ref, est, hop_samples, return_sdr_list=False):
    r"""Median SDR over consecutive, non-overlapping segments.

    ``ref`` and ``est`` are cut along their last axis into segments of
    ``hop_samples`` samples, up to the shorter of the two lengths.
    ``calculate_sdr`` is evaluated on every full segment; a trailing partial
    segment is ignored.

    Args:
        ref: reference signal, 2-D (sliced as ``ref[:, start:stop]``).
        est: estimated signal, same layout as ``ref``.
        hop_samples (int): segment length and hop in samples; must be > 0.
        return_sdr_list (bool): also return the per-segment SDR values.

    Returns:
        The NaN-ignoring median (``np.nanmedian``) of the per-segment SDRs,
        or ``(median, sdrs)`` when ``return_sdr_list`` is true. If no full
        segment fits, numpy yields NaN.

    Raises:
        ValueError: if ``hop_samples`` is not positive (previously this
            looped forever).
    """
    if hop_samples <= 0:
        raise ValueError("hop_samples must be positive, got {}".format(hop_samples))

    min_len = min(ref.shape[-1], est.shape[-1])

    # Segment starts satisfy ``start + hop_samples <= min_len``, so the last
    # full segment ending exactly at ``min_len`` is included.
    sdrs = [
        calculate_sdr(
            ref=ref[:, start : start + hop_samples],
            est=est[:, start : start + hop_samples],
        )
        for start in range(0, min_len - hop_samples + 1, hop_samples)
    ]

    sdr = np.nanmedian(sdrs)

    if return_sdr_list:
        return sdr, sdrs
    return sdr
| |
|
| |
|
def loudness(data, input_loudness, target_loudness):
    """Loudness-normalize a signal.

    Scales ``data`` by the linear gain ``10 ** ((target - input) / 20)``, so
    that a signal measured at ``input_loudness`` is brought to
    ``target_loudness``.

    Params
    -------
    data : torch.Tensor
        Input multichannel audio data.
    input_loudness : float or torch.Tensor
        Loudness of the input in dB LUFS.
    target_loudness : float or torch.Tensor
        Target loudness of the output in dB LUFS.

    Returns
    -------
    output : torch.Tensor
        Loudness normalized output data.
    """
    delta_loudness = target_loudness - input_loudness
    # ``**`` handles plain floats as well as tensors. ``torch.pow`` has no
    # overload for two Python numbers and raises TypeError on float inputs.
    gain = 10.0 ** (delta_loudness / 20.0)

    output = gain * data
    return output
| |
|
| |
|
def load_ss_model(
    configs: Dict,
    checkpoint_path: str,
    query_encoder: nn.Module
) -> nn.Module:
    r"""Build the separation network from ``configs`` and restore an AudioSep checkpoint.

    Args:
        configs (Dict): parsed config. Reads ``configs["model"]["model_type"]``,
            ``["input_channels"]``, ``["output_channels"]`` and
            ``["condition_size"]``.
        checkpoint_path (str): path of the checkpoint to load.
        query_encoder (nn.Module): query encoder passed to ``AudioSep``.

    Returns:
        pl_model: the object returned by ``AudioSep.load_from_checkpoint``,
        with tensors mapped to CPU. Waveform mixer, loss function, optimizer
        type, learning rate and LR lambda are passed as ``None``;
        ``strict=False`` tolerates checkpoint keys that do not match.
    """
    model_configs = configs["model"]
    ss_model_type = model_configs["model_type"]
    input_channels = model_configs["input_channels"]
    output_channels = model_configs["output_channels"]
    condition_size = model_configs["condition_size"]

    ss_model = get_model_class(model_type=ss_model_type)(
        input_channels=input_channels,
        output_channels=output_channels,
        condition_size=condition_size,
    )

    return AudioSep.load_from_checkpoint(
        checkpoint_path=checkpoint_path,
        strict=False,
        ss_model=ss_model,
        waveform_mixer=None,
        query_encoder=query_encoder,
        loss_function=None,
        optimizer_type=None,
        learning_rate=None,
        lr_lambda_func=None,
        map_location=torch.device('cpu'),
    )
| |
|
| |
|
def parse_yaml(config_yaml: str) -> Dict:
    r"""Read and parse a YAML file.

    NOTE(review): this is a second, identical definition of ``parse_yaml`` in
    this module. Being later in the file, it is the one bound at import
    time. Consider deleting one of the two.

    Args:
        config_yaml (str): path of the YAML file.

    Returns:
        yaml_dict (Dict): parsed contents (``yaml.FullLoader``; use only on
        trusted files).
    """
    with open(config_yaml, "r") as stream:
        document = yaml.load(stream, Loader=yaml.FullLoader)
    return document