| | |
| | import os |
| | from PIL import Image |
| | from transformers import BlipProcessor, BlipForConditionalGeneration |
| |
|
def process_dataset(zip_path, output_dir, generate_captions=True):
    """Extract an image dataset from a zip archive, downscale every image,
    and optionally write a BLIP-generated caption next to each image.

    Parameters
    ----------
    zip_path : str
        Path to the zip archive containing the images.
    output_dir : str
        Directory the archive is extracted into (created if missing).
        Images are resized in place; captions are written as sibling
        ``.txt`` files.
    generate_captions : bool
        When True, load the BLIP captioning model and write one caption
        file per image. When False, the model is never loaded.

    Returns
    -------
    str
        ``output_dir``, for caller convenience.
    """
    import zipfile  # local import kept at function top, mirroring original style

    os.makedirs(output_dir, exist_ok=True)

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(output_dir)

    # Lazy-load the captioning model: only pay the (large) download/load
    # cost when captions were actually requested.
    processor = model = None
    if generate_captions:
        processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
        model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

    # Walk recursively: extractall may create subdirectories, and the
    # original top-level-only listdir silently skipped nested images.
    for root, _dirs, files in os.walk(output_dir):
        for img_name in files:
            if not img_name.lower().endswith(('.png', '.jpg', '.jpeg')):
                continue
            img_path = os.path.join(root, img_name)

            # Close the source file handle before overwriting the same
            # path; convert() forces a full load, so the handle is safe
            # to release afterwards.
            with Image.open(img_path) as src:
                image = src.convert('RGB')

            # thumbnail() resizes in place, preserving aspect ratio and
            # never upscaling beyond the original dimensions.
            image.thumbnail((512, 512), Image.LANCZOS)
            image.save(img_path)

            if generate_captions:
                inputs = processor(image, return_tensors="pt")
                outputs = model.generate(**inputs, max_new_tokens=50)
                caption = processor.decode(outputs[0], skip_special_tokens=True)

                # Caption lives next to the image: same stem, .txt suffix.
                txt_path = os.path.splitext(img_path)[0] + ".txt"
                with open(txt_path, "w", encoding="utf-8") as f:
                    f.write(caption)

    return output_dir