自定义数据 微调CLIP (结合paper)

CLIP 是 Contrastive Language-Image Pre-training 的缩写,是一个擅长理解文本和图像之间关系的模型,下面是一个简单的介绍:

优点: CLIP 在零样本学习方面特别强大,它可以(用自然语言)给出图像的描述,并在基于该描述对新图像进行分类方面表现良好,例如,您可以将图像描述为“a”。猫的黑白照片”,CLIP 可以准确地对猫的新照片进行分类,即使它以前没有见过这些特定图像。
训练: CLIP 在从互联网收集的大量文本图像对数据集上进行训练,这使得它能够学习视觉概念及其描述之间的联系。
局限性: CLIP 也有缺点,训练的计算成本可能很高,并且在需要非常具体或抽象概念的任务上,或者对于与训练所用的文本描述非常不同的数据时,可能表现不佳。训练可能会将社会偏见引入模型中。

paper:Learning Transferable Visual Models From Natural Language Supervision

本文用CLIP做一个零样本分类,
CLIP训练的时候用的是图片和文本描述对,并没有分类的标签,那如何让CLIP做零样本分类?
我们需要给出标签的文本,让图像和所有的文本标签进行匹配,得分高的就是匹配到的标签文本。

paper中提到预测哪个文本整体与哪个图像配对,而不是该文本的准确单词。

在这里插入图片描述

下面通过一个kaggle数据集来具体说明。

这里选用indo fashion dataset, 它有15种印度服饰。

在这里插入图片描述
类别如下:
在这里插入图片描述

数据集结构:
其中images文件夹下又有train, val, test文件夹。

在这里插入图片描述

再看一下json文件,
image_path指的是上面images文件夹下的路径,
product_title是和图片对应的文本描述,训练的时候就是用图片和这个文本进行匹配。
class_label训练的时候不需要,最后验证分类是否正确时会用到。

在这里插入图片描述

import需要的库,定义数据集的文件夹,读取json数据

import json
from PIL import Image
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import clip
from transformers import CLIPProcessor,CLIPModel
from tqdm import tqdm

json_path = 'your_path/train_data.json'
image_path = 'your_path/images/train/'

input_data = []
with open(json_path, 'r') as f:
	for line in f:
		obj = json.loads(line)
		input_data.append(obj)

CLIP模型,如果不能download, 手动下载走offline模式。

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
Setting our device to GPU (Cuda) and loading the pre-trained CLIP model.

device = "cuda:0" if torch.cuda.is_available() else "cpu" 

model, preprocess = clip.load("ViT-B/32", device=device, jit=False)

定义Dataloader

# Define a custom dataset
class image_title_dataset():
	def __init__(self, list_image_path, list_txt):
		self.image_path = list_image_path

		# Tokenize text using CLIP's tokenizer
		self.title = clip.tokenize(list_txt)
	
    def __len__(self):
		# Define the length of the dataset
		return len(self.title)

	def __getitem__(self, idx):
        image = preprocess(Image.open(self.image_path[idx]))
		title = self.title[idx]
		return image, title

这里的dataset需要传入list_image_path和list_txt,
格式是这种:
list_image_path = [‘folder/image1.jpg’,‘folder2/image2.jpg’]
list_txt = [‘description for image1.jpg’ , ‘description for image2.jpg’]
所以要把image_path和product_title都装进list里面。

注意,CLIP的最大序列长度限制在76, 而有些文本描述非常长,需要截掉一部分,
当然截到76长度也有很多种方法,这里简单粗暴就从开头取长度76.

实际代码中,indo数据集不限制长度会报错,而博主觉得这个76可能是text被tokenize之后的token的长度,而不是原文本的长度,
因为把文本截到长度>77也是可以的。
而token的长度是由tokenize的算法决定的。具体最大极限文本长度是多少没测,这里简单地截取到77.

在这里插入图片描述

list_image_path = []
list_txt = []
for item in input_data:
  img_path = image_path + item['image_path'].split('/')[-1]
  
  caption = item['product_title'][:77]
  list_image_path.append(img_path)
  list_txt.append(caption)

dataset = image_title_dataset(list_image_path, list_txt)
train_dataloader = DataLoader(dataset, batch_size=100, shuffle=True) 

# Function to convert model's parameters to FP32 format
#转精度省内存.
def convert_models_to_fp32(model): 
    for p in model.parameters(): 
        p.data = p.data.float() 
        p.grad.data = p.grad.data.float() 

if device == "cpu":
    model.float()  # Convert the model's parameters to float if using CPU

optimizer用Adam,参数按paper中的设置.
不过博主的机器容纳不了这么大的batch_size, 具体batch_size设多少合适,需要自己去验证。

在这里插入图片描述
由于数据集比较小,lr设得更小一些。

optimizer = torch.optim.Adam(
    model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6 ,weight_decay=0.2) 

训练

paper中的训练是这样的
在这里插入图片描述

    for epoch in range(num_epochs):
        pbar = tqdm(train_dataloader, total=len(train_dataloader))
        for batch in pbar:
            optimizer.zero_grad()

            images, texts = batch

            images = images.to(device)
            texts = texts.to(device)

            logits_per_image, logits_per_text = model(images, texts)

            ground_truth = torch.arange(len(images), dtype=torch.long, device=device)

            total_loss = (loss_img(logits_per_image, ground_truth) + loss_txt(logits_per_text, ground_truth)) / 2
            total_loss.backward()
            if device == "cpu":
                optimizer.step()
            else:
                convert_models_to_fp32(model)
                optimizer.step()
                clip.model.convert_weights(model)

            pbar.set_description(f"Epoch {epoch}/{num_epochs}, Loss: {total_loss.item():.4f}")
            if torch.isnan(total_loss).any():
                print("epoch {} loss is NaN".format(epoch))
                epoch = num_epochs
                break

训练中,遇到了这些问题:
loss出现了NaN, 调整batch_size能解决,batch_size不要太小。
loss降不下去了,看看paper中的参数,有哪些需要调整。

训练完之后,找来一张图片测试。
这里又有一些注意事项,
请看paper.
因为训练的时候是图片和一段文本描述匹配的,而不是图片和一个单词。
所以你做零样本分类时,类别文本最好不要只写一个单词,比如只写"Saree"。
你要写"A photo of Saree", 这就成了一个句子,效果就会好一些。

在这里插入图片描述

model, preprocess = clip.load("ViT-B/32", device=device)

checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint['model_state_dict'])

clothing_items = [
    "Saree",
    "Lehenga",
    "Women Kurta",
    "Dupatta",
    "Gown",
    "Nehru Jacket",
    "Sherwani",
    "Men Kurta",
    "Men Mojari",
    "Leggings and Salwar",
    "Blouse",
    "Palazzo",
    "Dhoti Pants",
    "Petticoat",
    "Women Mojari"
]

这里你可能要问,那json文件里面的标签不是这么写的,比如"Women Kurta",json文件的标签是"women_kurta",
为什么不写成"women_kurta"。
这个博主是测试过的,写成json文件里面的标签形式准确率会降低,可能是因为"Women Kurta"更接近自然语言,更贴合训练数据吧。

把15个类别的标签都写成"A photo of {label}" 进行测试。

#你想测的第几张图片
index_ = 500
image_json = input_data[index_]
image_path = os.path.join("indo-fashion-dataset", image_json['image_path'])
image_class = image_json['class_label']
image = preprocess(Image.open(image_path)).unsqueeze(0).to(device)
text = torch.cat([clip.tokenize(f"a photo of a {c}") for c in clothing_items]).to(device)

with torch.no_grad():
    # Encode image and text
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)
    
    # Calculate similarity scores between image and text
    logits_per_image, logits_per_text = model(image, text)
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

# Normalize image and text features
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# Calculate similarity scores
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

# Print the top predictions
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{clothing_items[index]:>16s}: {100 * value.item():.2f}%")

# Display the image with its class label
plt.imshow(plt.imread(image_path))
plt.title(f"Image for class: {image_class}")
plt.axis('off')
plt.show()

请添加图片描述
请添加图片描述

训练中并没有精调参数,也没有训练很多epoch. 效果如下。
统计了一下测试集中7450张图片的top1和top3准确率。
top1: 77.7%, top3: 93.57%

请添加图片描述

paper中说CLIP 模型的 Top-5 准确率明显高于其 Top-1 准确率, 本文虽测的是top3, 但也是明显高于top1的。

在这里插入图片描述

又试了一下这种方法,这里效果并没有变好。

在这里插入图片描述

参考资料1
参考资料2

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/572180.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

AI Agent新对决:LangGraph与AutoGen的技术角力

AI Agent变革未来,LangGraph对抗AutoGen ©作者|Blaze 来源|神州问学 引言 比尔.盖茨曾在他的博客上发表一篇文章:《AI is about to completely change how you use computers》。在文章中,比尔盖茨探讨AI Agent对我们未来生活的巨大影…

目标检测YOLO数据集的三种格式及转换

目标检测YOLO数据集的三种格式 在目标检测领域,YOLO(You Only Look Once)算法是一个流行的选择。为了训练和测试YOLO模型,需要将数据集格式化为YOLO可以识别的格式。以下是三种常见的YOLO数据集格式及其特点和转换方法。 1. YOL…

RabbitMQ, DelayQueue, Redis的介绍以及IDEA的实现

RabbitMQ RabbitMQ是一个开源的消息队列中间件,它实现了高效、可靠的消息传递机制。它支持多种消息传递模式,如发布/订阅、点对点、请求/回应等。RabbitMQ以其可靠性、灵活性和易用性受到广泛的关注和应用。 RabbitMQ基于AMQP(Advanced Mess…

【禅道客户案例】专访鸿泉物联研发副总监徐小倩,感受上市公司研发项目管理“知与行”

杭州鸿泉物联网技术股份有限公司(以下简称“鸿泉物联”、“公司”)成立于2009年6月11日,2019年11月6日登陆上海证券交易所科创板(股票代码:688288),注册资本10034.392万元,目前员工6…

「案例分享」DevExpress XAF (WinForms UI)赋能医疗管理系统,让操作更自动化!

DevExpress XAF是一款强大的现代应用程序框架,它采用模块化设计,开发人员可以选择内建模块,也可以自行创建,从而以更快的速度和比开发人员当前更强有力的方式创建应用程序。 获取DevExpress 新版正式版下载(Q技术交流&#xff1a…

CTFshow-PWN-栈溢出(pwn44)

64位的 system(); 但是好像没"/bin/sh" 上面的办法不行了,想想办法 检查: 是 64 位程序 ida 反编译 main 函数: 跟进 ctfshow 函数: 存在栈溢出 offset:0xAh8 在前面经验的基础上,这里我们直…

Redis第14讲——Redis实现分布式锁(Redission源码解析)

在多线程环境下,为了保证数据的线程安全,我们通常用加锁的方式,使同一时刻只有一个线程可以对这个共享资源进行操作,在单服务系统我们常用JVM锁——Synchronized、ReentrantLock等。然而在多台服务系统的情况下,JVM锁就…

数字科技助力垃圾分类展厅,增强内容交互新体验!

如今,许多行业都开始运用数字技术,探索其在展览展示领域中的应用,其中垃圾分类展厅作为现代城市文明建设的重要一环,也通过这些技术的运用,打造出了更加生动且富有科技感的展示空间,它不仅提升公众对垃圾分…

原生微信小程序中案例--仿boss区域树选择列多选功能

1. 需求描述: 区域三级列表, 有添加,编辑,删除功能。 选择父级分类,其下子类全部选中,当前分类后加标志显示全字样取消选中子类,其父类分类后标志显示选中数量若子类全部选中,除当…

神经网络鸢尾花分类

⚠申明: 未经许可,禁止以任何形式转载,若要引用,请标注链接地址。 全文共计3077字,阅读大概需要3分钟 🌈更多学习内容, 欢迎👏关注👀【文末】我的个人微信公众号&#xf…

【超简单实用】Zotero 7 内置pdf背景颜色更改插推荐以及安装

Zotero beta7 pdf 内置颜色更换 zetore 6 很多成熟的插件在 zetore 7都不能用了。版本回退看起来内置文章的注释会被消除,所以又不想退回去。前几个月在找beta 7 的pdf 护眼色的插件一直没有,今天终于发现了!!!&#…

《架构风清扬-Java面试系列第26讲》聊聊的LinkedBlockingQueue的特点及使用场景

LinkedBlockingQueue也是BlockingQueue接口的一个实现类之一 这个属于基础性问题,老规矩,我们将从使用场景和代码示例来进行讲解 来,思考片刻,给出你的答案 1,使用场景 实现:基于链表实现的阻塞队列&#…

Django5框架之多重继承

在Django模型中也支持使用多重继承,这点与Python语法中的继承是一致的。Django模型多重继承就是同时继承多个父类模型,父类中第一个出现的基类(如:Meta类)是默认被使用的。如果存在多个父类包含Meta类的情况&#xff0…

Redis - Set 集合

前言 集合类型可保存多个字符串类型的元素,但和列表类型不同的是,集合中的元素之间是⽆序的(顺序不重要,变换一下集合中的数据顺序,集合不会发生改变) 的并且元素不允许重复 ⼀个集合中最多可以存储 2^32-1…

echarts 堆叠柱状图 顶部添加合计

堆叠有3个,后面加了一个对象显示顶部的数据, 其实主要的代码还是在series 的第四项,需要注意的是 series的第四项中的data需要为 data: [0, 0, 0] 顶部的统计才能显示出来 增加的代码如下 {name: 综合,type: bar,stack: total,label: {sh…

基于springboot+vue+Mysql的篮球竞赛预约平台

开发语言:Java框架:springbootJDK版本:JDK1.8服务器:tomcat7数据库:mysql 5.7(一定要5.7版本)数据库工具:Navicat11开发软件:eclipse/myeclipse/ideaMaven包:…

STM32 HAL库F103系列之DAC实验(二)

DAC输出正弦波实验 实验简要 1,功能描述 通过DAC1通道1(PA4)输出正弦波,然后通过DS100示波器查看波形 2,使用定时器7 TRGO事件触发转换 TEN1位置1、TSEL1[2:0]010 3,关闭输出缓冲 BOFF1位置1 4,使用DMA模式 DMAE…

无人机+遥控器:工业级手持地面站(支持安卓系统)功能技术详解

手持地面站是一种专为无人机设计的便携式设备,用于实现飞行控制、任务规划、数据链路通信等功能。由于支持安卓系统,这种地面站设备在软件生态上具有极大的灵活性,能够兼容并运行众多基于安卓平台的无人机控制应用程序。 在硬件方面&#xff…

vue中的mixin(局部混入、全局混入)

一、mixin是什么 Mixin是面向对象程序设计语言中的类,提供了方法的实现。其他类可以访问mixin类的方法而不必成为其子类;Mixin类通常作为功能模块使用,在需要该功能时“混入”,有利于代码复用又避免了多继承的复杂 Vue中的mixin…

如何在 Ubuntu 14.04 上配置 StatsD 以收集 Graphite 的任意统计数据

介绍 Graphite 是一个图形库,允许您以灵活和强大的方式可视化不同类型的数据。它通过其他统计收集应用程序发送给它的数据进行图形化。 在之前的指南中,我们讨论了如何安装和配置 Graphite 本身,以及如何安装和配置 collectd 以编译系统和服…
最新文章