Audio-Text-to-Text
Transformers
Safetensors
Hindi
English
audio
speech
audio-language-model
whisper
sarvam-m
lora
projector
indic
hindi
Instructions to use Mayank022/Audio-Language-Model with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Mayank022/Audio-Language-Model with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("Mayank022/Audio-Language-Model", dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import dataclasses | |
| import torch | |
| import transformers | |
| from transformers import Trainer, TrainingArguments, TrainerCallback | |
| from peft import LoraConfig, get_peft_model, TaskType | |
| from huggingface_hub import HfApi, login | |
| import wandb | |
| from dotenv import load_dotenv | |
| from config import TrainConfig, ModelConfig | |
| from model import MultiModalModel | |
| from data import AudioTextDataset, DataCollator | |
class SamplePredictionCallback(TrainerCallback):
    """Every N steps, log ground-truth vs. model-generated text to a W&B table.

    On each log event whose global step is a positive multiple of
    ``sample_every_n_steps``, the first ``num_samples`` items of the training
    dataset are collated, passed to ``model.generate`` together with a fixed
    text prompt, and the decoded results are logged as a ``wandb.Table`` under
    the key ``"sample_predictions"``. Any failure while sampling is printed
    and swallowed so that it cannot interrupt training.

    Args:
        tokenizer: Tokenizer used to encode the prompt and decode token ids.
        data_collator: Callable turning a list of dataset items into a batch
            that exposes ``"audio_values"`` and ``"labels"`` (and optionally
            ``"continuation"``).
        train_dataset: Indexable dataset the samples are drawn from.
        sample_every_n_steps: Sampling interval in global steps. A falsy
            value (``0`` / ``None``) disables sampling.
        num_samples: Number of dataset items to generate for.
        prompt: Text prompt fed to the model alongside the audio.
    """

    def __init__(self, tokenizer, data_collator, train_dataset, sample_every_n_steps: int = 100, num_samples: int = 2, prompt: str = "Transcribe the following audio:"):
        self.tokenizer = tokenizer
        self.data_collator = data_collator
        self.train_dataset = train_dataset
        self.sample_every_n_steps = sample_every_n_steps
        self.num_samples = num_samples
        self.prompt = prompt

    def on_log(self, args, state, control, model=None, **kwargs):
        """Generate and log sample predictions when the current step is due."""
        # A zero interval would raise ZeroDivisionError in the modulo below and
        # abort training; treat it as "sampling disabled" instead.
        if not self.sample_every_n_steps:
            return
        if state.global_step == 0 or state.global_step % self.sample_every_n_steps != 0:
            return
        if model is None:
            return

        model.eval()
        try:
            device = next(model.parameters()).device

            # Always the same leading samples (wrapping around tiny datasets),
            # so tables from different steps are directly comparable.
            indices = [i % len(self.train_dataset) for i in range(self.num_samples)]
            batch = self.data_collator([self.train_dataset[i] for i in indices])
            audio_values = batch["audio_values"].to(device)
            labels_batch = batch["labels"]
            batch_size = audio_values.size(0)
            continuations = batch.get("continuation", [""] * batch_size)

            prompt_ids = self.tokenizer(self.prompt, return_tensors="pt", add_special_tokens=True).input_ids.to(device)
            prompt_ids = prompt_ids.expand(batch_size, -1)

            # Fall back to EOS only when no pad token is configured. An `or`
            # fallback would also discard a legitimate pad token id of 0.
            pad_token_id = self.tokenizer.pad_token_id
            if pad_token_id is None:
                pad_token_id = self.tokenizer.eos_token_id

            with torch.no_grad():
                gen_ids = model.generate(
                    input_ids=prompt_ids,
                    audio_values=audio_values,
                    max_new_tokens=120,
                    do_sample=False,
                    pad_token_id=pad_token_id,
                )
            # NOTE(review): assumes generate() returns prompt tokens followed by
            # the new tokens, so the prompt prefix is sliced off below — confirm
            # against MultiModalModel.generate.
            prompt_len = prompt_ids.size(1)

            # Create a wandb Table
            columns = ["Step", "Audio Index", "Ground Truth", "Prediction", "Continuation"]
            table = wandb.Table(columns=columns)
            print(f"\n[WandB] Logging sample predictions at step {state.global_step}")
            for i in range(batch_size):
                # -100 marks label positions that are not real target tokens.
                gt_tokens = [t for t in labels_batch[i].tolist() if t != -100]
                gt_text = self.tokenizer.decode(gt_tokens, skip_special_tokens=True).strip()
                pred_text = self.tokenizer.decode(gen_ids[i][prompt_len:], skip_special_tokens=True).strip()
                cont_ref = continuations[i] if i < len(continuations) else ""
                # Add row to table
                table.add_data(state.global_step, i, gt_text, pred_text, cont_ref)

            # Log the table to wandb
            if wandb.run is not None:
                wandb.log({"sample_predictions": table}, step=state.global_step)
            else:
                print("Warning: WandB run not active, skipping logging.")
        except Exception as e:
            # Sampling is diagnostic only; never let it take down training.
            print(f"[SamplePredictionCallback] Error: {e}\n")
        finally:
            # Put the model back into training mode whatever happened above.
            model.train()
| import shutil | |
| import glob | |
| from transformers.trainer_utils import get_last_checkpoint | |
class AggressiveDeleteCallback(TrainerCallback):
    """
    Remove every ``checkpoint-*`` entry under ``output_dir`` at the end of any
    step on which a step-based save is due, to keep disk usage minimal.

    A save is considered due when ``save_strategy == "steps"``, ``save_steps``
    is positive, and the global step is a positive multiple of ``save_steps``.
    Presumably the Trainer writes the new checkpoint after this hook runs, so
    the disk briefly holds zero checkpoints.
    WARNING: if that save then fails, NO checkpoint is left on disk. This risk
    was accepted by the user.
    """

    def __init__(self, output_dir):
        self.output_dir = output_dir

    def on_step_end(self, args, state, control, **kwargs):
        # Only act when step-based saving is configured.
        step_saving_enabled = args.save_strategy == "steps" and args.save_steps > 0
        if not step_saving_enabled:
            return

        # Only act on steps where a save is due.
        save_due = state.global_step > 0 and state.global_step % args.save_steps == 0
        if not save_due:
            return

        print(f"\n[AggressiveDeleteCallback] Step {state.global_step}: Deleting old checkpoints to free space before saving...")

        # Everything matching the pattern is removed, including a checkpoint
        # the run may have resumed from.
        pattern = os.path.join(self.output_dir, "checkpoint-*")
        for ckpt in glob.glob(pattern):
            try:
                shutil.rmtree(ckpt)
                print(f" Deleted {ckpt}")
            except Exception as e:
                # Best effort: report and move on to the next entry.
                print(f" Failed to delete {ckpt}: {e}")
def _push_to_hub(train_config):
    """Upload the output directory and supporting source files to the Hub.

    Logs in when ``train_config.hub_token`` is set, creates the target repo
    if needed, uploads ``train_config.output_dir`` and then any of the
    project's source files found in the current working directory. Every
    failure is printed rather than raised.
    """
    print(f"\n>>> Pushing model to Hugging Face Hub: {train_config.hub_model_id}")
    if train_config.hub_token:
        login(token=train_config.hub_token)
    api = HfApi()

    # Create repo if needed; visibility comes from train_config.hub_private_repo.
    # A failure here is only a warning because the repo may already exist.
    try:
        api.create_repo(repo_id=train_config.hub_model_id, private=train_config.hub_private_repo, exist_ok=True)
    except Exception as e:
        print(f"Warning: Could not create repo {train_config.hub_model_id}. Error: {e}")

    # Upload model folder
    try:
        api.upload_folder(
            folder_path=train_config.output_dir,
            repo_id=train_config.hub_model_id,
            repo_type="model",
        )
        # Upload code files to ensure custom model works
        for file in ["model.py", "config.py", "data.py", "inference.py"]:
            if os.path.exists(file):
                api.upload_file(
                    path_or_fileobj=file,
                    path_in_repo=file,
                    repo_id=train_config.hub_model_id,
                    repo_type="model",
                )
        print(f">>> Successfully pushed to {train_config.hub_model_id}")
    except Exception as e:
        print(f"Error pushing to hub: {e}")


def train():
    """Run the full training pipeline.

    Steps: load environment variables and configs, start a W&B run, build the
    tokenizer, processor and model (optionally wrapping the LLM with LoRA),
    train with the Hugging Face ``Trainer`` (resuming from the latest
    checkpoint in the output directory when one exists), save the model,
    tokenizer and processor, and optionally push everything to the Hub.
    """
    # Load environment variables
    load_dotenv()

    # Load Configs
    train_config = TrainConfig()
    model_config = ModelConfig()

    # Initialize WandB
    wandb.init(
        project=train_config.wandb_project,
        entity=train_config.wandb_entity,
        name=train_config.wandb_run_name,
        config=dataclasses.asdict(train_config),
    )

    # Initialize Tokenizer & Processor
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_config.text_model_id)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token
    processor = transformers.AutoProcessor.from_pretrained(model_config.audio_model_id)

    # Initialize Model
    model = MultiModalModel(model_config)

    # Apply LoRA if requested
    if train_config.use_lora:
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=train_config.lora_r,
            lora_alpha=train_config.lora_alpha,
            lora_dropout=train_config.lora_dropout,
            target_modules=["q_proj", "v_proj"]
        )
        # Only the LLM sub-module is wrapped with LoRA adapters.
        model.llm = get_peft_model(model.llm, peft_config)
        model.llm.print_trainable_parameters()

    # Dataset
    train_dataset = AudioTextDataset(train_config, processor, model_config, tokenizer)
    data_collator = DataCollator(processor, tokenizer)

    # Training Arguments (tuned for A100 80GB: bf16, larger batch, fast dataloader)
    training_args = TrainingArguments(
        output_dir=train_config.output_dir,
        per_device_train_batch_size=train_config.batch_size,
        gradient_accumulation_steps=train_config.accum_steps,
        learning_rate=train_config.learning_rate,
        lr_scheduler_type=train_config.lr_scheduler_type,
        num_train_epochs=train_config.num_epochs,
        max_steps=train_config.max_steps,
        bf16=train_config.use_bf16,
        gradient_checkpointing=train_config.gradient_checkpointing,
        dataloader_num_workers=train_config.dataloader_num_workers,
        dataloader_pin_memory=train_config.dataloader_pin_memory,
        logging_steps=train_config.log_steps,
        logging_first_step=True,
        logging_nan_inf_filter=True,
        save_steps=train_config.save_steps,
        save_total_limit=train_config.save_total_limit,
        eval_strategy="no",  # change if val set provided
        remove_unused_columns=False,  # Important because we have custom forward signature
        report_to="wandb",
        log_level="info",
        log_level_replica="info",
    )

    sample_callback = SamplePredictionCallback(
        tokenizer=tokenizer,
        data_collator=data_collator,
        train_dataset=train_dataset,
        sample_every_n_steps=train_config.sample_pred_every_steps,
        num_samples=2,
        prompt="Transcribe the following audio:",
    )
    aggressive_delete_callback = AggressiveDeleteCallback(train_config.output_dir)

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        data_collator=data_collator,
        callbacks=[sample_callback, aggressive_delete_callback],
    )

    total_steps = train_config.max_steps
    print(f"\n>>> Training: max_steps={total_steps}, batch_size={train_config.batch_size}, "
          f"grad_accum={train_config.accum_steps} (effective batch={train_config.batch_size * train_config.accum_steps})")
    print(f">>> Sample predictions (GT vs predicted transcript) every {train_config.sample_pred_every_steps} steps.\n")

    # Resume from checkpoint if exists. get_last_checkpoint lists the
    # directory, so only call it when the directory is actually there.
    last_checkpoint = None
    if os.path.isdir(train_config.output_dir):
        last_checkpoint = get_last_checkpoint(train_config.output_dir)

    if last_checkpoint is not None:
        print(f">>> Resuming from checkpoint: {last_checkpoint}")
        trainer.train(resume_from_checkpoint=last_checkpoint)
    else:
        trainer.train()

    # Save
    trainer.save_model(train_config.output_dir)
    tokenizer.save_pretrained(train_config.output_dir)
    processor.save_pretrained(train_config.output_dir)

    # Push to Hub
    if train_config.push_to_hub:
        _push_to_hub(train_config)
# Script entry point: run training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()