RSSC-LLAVA:基于XTuner用遥感数据微调LLAVA模型-程序员宅基地

技术标签: 人工智能  

RSSC-LLAVA:基于XTuner用遥感数据微调LLAVA模型

项目介绍:

基于UCM场景分类数据集构建简单的对话文本,随后利用Xtuner微调LLAVA模型,实现遥感图文问答,主要是场景分类。
代码分享:https://github.com/biscuit279/RSSC-LLAVA

实现步骤

1.数据准备

下载UCM场景分类数据集,包含21个类别,每个类别有100张图片
LAVVA模型的微调分为两步,第一步是做文本和图像的特征对齐,第二步是图文问答
第一步的时候需要准备两个json文件:

1.图像的文本描述,原作者用GPT生成的数据集,场景分类的话可以用脚本模拟一套

在这里插入图片描述
使用GPT3.5生成代码,prompt如下:

User
帮我写一段python代码:输入是一个Images文件夹,Images包含多个子文件夹,子文件夹的名称是类别名,子文件夹中装有多张图片。输出是一个json文件,格式如下:
[
  {
    "id": "004539375",
    "image": "00453/004539375.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "Render a clear and concise summary of the photo.\n<image>"
      },
      {
        "from": "gpt",
        "value": "select luxury furniture 3 - inch gel memory foam mattress topper"
      }
    ]
  },
  {
    "id": "002239345",
    "image": "00223/002239345.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "Write a terse but informative summary of the picture.\n<image>"
      },
      {
        "from": "gpt",
        "value": "a grey watch with an army style strap"
      }
    ]
  }
id字段是图片的名称,image字段是文件夹名称/图片名称,conversion字段的from不变,第一个value改成"What is it",第二个value改成"This is a photo of {class}"

生成的代码为

import os
import json

def process_images(input_folder):
    output_data = []
    for root, dirs, files in os.walk(input_folder):
        for file in files:
            if file.endswith('.jpg'):
                image_id = os.path.splitext(file)[0]
                image_path = os.path.join(root, file)
                class_name = os.path.basename(root)
                
                conversation_human = {
                    "from": "human",
                    "value": "Write a terse but informative summary of the picture.\n<image>"
                }
                
                conversation_gpt = {
                    "from": "gpt",
                    "value": f"This is a photo of {class_name}"
                }
                
                output_data.append({
                    "id": image_id,
                    "image": f"{class_name[:5]}/{image_id}.jpg",
                    "conversations": [conversation_human, conversation_gpt]
                })
    return output_data

def save_to_json(data, output_file):
    with open(output_file, 'w') as f:
        json.dump(data, f, indent=2)

if __name__ == "__main__":
    input_folder = "Images"
    output_file = "output.json"

    images_data = process_images(input_folder)
    save_to_json(images_data, output_file)

相关脚本已上传至github
2.meta数据集,记录图片的基本信息,主要是名称,URL,以及blip生成的caption
在这里插入图片描述
这个可能用不到,暂时不管

2.修改代码

主要修改config文件中的数据路径以及evaluation案例
随便选了一个数据集中的图片作为evaluation案例,未修改evaluattion_inputs
在这里插入图片描述
在这里插入图片描述

3.环境安装

从github下载xtuner项目并安装

git clone https://github.com/InternLM/xtuner.git
pip install -e '.[all]'

4.pretrain

修改好了以后,运行xtuner train llava_internlm_chat_7b_clip_vit_large_p14_336_e1_gpu8_pretrain --deepspeed deepspeed_zero2
会自动下载internlm7b模型
但是遇到了一个报错,似乎是网络问题
在这里插入图片描述
再次运行相同的指令,竟然无法复现这个错误,变成了一堆莫名其妙的提示
在这里插入图片描述
未完待续。。

2.4更新:
上次的报错其实就是因为那个warning,pandas包没装好,卸载重装即可解决

pip uninstall pandas
pip install pandas

注意在训练时一定要加` --deepspeed deepspeed_zero2``否则会报数据类型不匹配的错误
成功运行:
在这里插入图片描述
2100个图片文本对,一个epoch只需要5分钟左右,占31980m显存

换成56k,70个类别的数据集,用同样的脚本生成json,再次运行

5.finetune

finetune阶段每张图片需要有五组对应的问答,将生成pretrain数据的代码稍作修改,添加其他几种template即可。
总共选择11种问题的模板,9种回答的模板,每组对话的QA都是从模板中随机挑选的。
注意第一组QA中应该有"\n“

import os
import json
import random

Q_templates = ["Describe the image concisely.",
"Provide a brief description of the given image.",
"Offer a succinct explanation of the picture presented.",
"Summarize the visual content of the image.",
"Give a short and clear explanation of the subsequent image.",
"Share a concise interpretation of the image provided.",
"Present a compact description of the photo's key features.",
"Relay a brief, clear account of the picture shown.",
"Render a clear and concise summary of the photo.",
"Write a terse but informative summary of the picture.",
"Create a compact narrative representing the image presented."]

A_templates = ["This is a photo of a {}.",
"This is a satellite image of a {}. ",
"This is a land use image of a {}. ",
"This is a remote sensing image of a {}.", 
"Here is an aerial picture depicting {}.",
"Displayed is an aerial photo illustrating {}.",
"This image captures the aerial perspective of {}.",
"Presented is an aerial view of {}.",
"This picture shows {} from an aerial vantage point."]


def process_images(input_folder):
    output_data = []
    for root, dirs, files in os.walk(input_folder):
        for file in files:
            if file.endswith('.jpg'):
                image_id = os.path.splitext(file)[0]
                image_path = os.path.join(root, file)
                class_name = os.path.basename(root)
                
                conversation_human = {
                    "from": "human",
                    "value": "Write a terse but informative summary of the picture./n<image>"
                }
                
                conversation_gpt = {
                    "from": "gpt",
                    "value": f"This is an aerial image of {class_name}."
                }
                
                conversations = [conversation_human, conversation_gpt]
                Q_samples = random.sample(Q_templates, 4)
                A_samples = random.sample(A_templates, 4)

                for i in range(4):
                    conversation_human = {
                    "from": "human",
                    "value": f"{Q_samples[i]}"
                    }
                    conversation_gpt = {
                    "from": "gpt",
                    "value": f"{A_samples[i]}".format(class_name)
                    
                    }
                    # conversation_gpt['value'].replace('{class_name}', class_name)
                    # import ipdb;ipdb.set_trace()
                    conversations.append(conversation_human)
                    conversations.append(conversation_gpt)

                output_data.append({
                    "id": image_id,
                    "image": f"{image_id}.jpg",
                    "conversations": conversations
                })

    return output_data

def save_to_json(data, output_file):
    with open(output_file, 'w') as f:
        json.dump(data, f, indent=2)

得到的数据例如:

[
  {
    "id": "airplane1",
    "image": "airplane1.jpg",
    "conversations": [
      {
        "from": "human",
        "value": "Write a terse but informative summary of the picture./n<image>"
      },
      {
        "from": "gpt",
        "value": "This is an aerial image of airplane."
      },
      {
        "from": "human",
        "value": "Describe the image concisely."
      },
      {
        "from": "gpt",
        "value": "Presented is an aerial view of airplane."
      },
      {
        "from": "human",
        "value": "Provide a brief description of the given image."
      },
      {
        "from": "gpt",
        "value": "Displayed is an aerial photo illustrating airplane."
      },
      {
        "from": "human",
        "value": "Offer a succinct explanation of the picture presented."
      },
      {
        "from": "gpt",
        "value": "This is a land use image of a airplane. "
      },
      {
        "from": "human",
        "value": "Summarize the visual content of the image."
      },
      {
        "from": "gpt",
        "value": "This is a photo of a airplane."
      }
    ]
  }
  ]

将生成好的数据按以下格式组织:
在这里插入图片描述
随后修改finetune阶段的配置文件llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune.py
仍然是需要修改数据、模型的路径和测试图片的位置
随后运行NPROC_PER_NODE=8 xtuner train llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune --deepspeed deepspeed_zero2
将8改成实际的显卡数量
这一步使用qlora方法同时微调llm和vit,将会得到两个adapter

发现同一个工作空间,运行finetune的时候还需要重新下载一遍internlm7b模型,猜测应该是下载完之后直接加载,加载结束后就删除了本地文件。更正:默认会下载到/root/.cache/huggingface/hub/下
可以使用以下代码手动下载文件,这样下次运行就不需要重复下载了。

cd ~/RSSC
apt install git git-lfs -y
git lfs install
git lfs clone https://modelscope.cn/Shanghai_AI_Laboratory/internlm-chat-7b.git -b v1.0.3

如果是手动下载的,需要修改模型位置参数llm_name_or_path为模型存放的路径
在这里插入图片描述
但手动下载可能会报一个KeyError,可能是internlm模型的版本问题,暂时未找到合适的解决方案

仍然采用自动下载的方式。8卡A800跑56k张图片数据要一个小时
在这里插入图片描述

6.模型合并与部署

转换成huggingface格式:

xtuner convert pth_to_hf llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune ./work_dirs/llava_internlm_chat_7b_qlora_clip_vit_large_p14_336_lora_e1_gpu8_finetune/iter_875.pth/ ./work_dirs/iter_875_hf

可以把config文件的指定成本地py文件,然后修改py文件中的模型路径,从而避免重复缓存模型
格式转换完成后,将会得到llm_adapter,projector,visual_encoder_adapter,可以分别与llm和vit合并
在这里插入图片描述

转换完成后,就可以进行对话了,分别输入llm模型,视觉模型,hf格式的llava模型,以及图片的路径,即可开始对话
xtuner chat internlm/internlm-chat-7b
–visual-encoder openai/clip-vit-large-patch14-336
–llava xtuner/llava-internlm-7b
–prompt-template internlm_chat
–image $IMAGE_PATH

发现用不同的问题,都可以输出正确的场景分类结果。
在这里插入图片描述

版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/qq_46212981/article/details/135972166

智能推荐

FTP命令字和返回码_ftp 登录返回230-程序员宅基地

文章浏览阅读3.5k次,点赞2次,收藏13次。为了从FTP服务器下载文件,需要要实现一个简单的FTP客户端。FTP(文件传输协议) 是 TCP/IP 协议组中的应用层协议。FTP协议使用字符串格式命令字,每条命令都是一行字符串,以“\r\n”结尾。客户端发送格式是:命令+空格+参数+"\r\n"的格式服务器返回格式是以:状态码+空格+提示字符串+"\r\n"的格式,代码只要解析状态码就可以了。读写文件需要登陆服务器,特殊用..._ftp 登录返回230

centos7安装rabbitmq3.6.5_centos7 安装rabbitmq3.6.5-程序员宅基地

文章浏览阅读648次。前提:systemctl stop firewalld 关闭防火墙关闭selinux查看getenforce临时关闭setenforce 0永久关闭sed-i'/SELINUX/s/enforcing/disabled/'/etc/selinux/configselinux的三种模式enforcing:强制模式,SELinux 运作中,且已经正确的开始限制..._centos7 安装rabbitmq3.6.5

idea导入android工程,idea怎样导入Android studio 项目?-程序员宅基地

文章浏览阅读5.8k次。满意答案s55f2avsx2017.09.05采纳率:46%等级:12已帮助:5646人新版Android Studio/IntelliJ IDEA可以直接导入eclipse项目,不再推荐使用eclipse导出gradle的方式2启动Android Studio/IntelliJ IDEA,选择 import project3选择eclipse 项目4选择 create project f..._android studio 项目导入idea 看不懂安卓项目

浅谈AI大模型技术:概念、发展和应用_ai大模型应用开发-程序员宅基地

文章浏览阅读860次,点赞2次,收藏6次。AI大模型技术已经在自然语言处理、计算机视觉、多模态交互等领域取得了显著的进展和成果,同时也引发了一系列新的挑战和问题,如数据质量、计算效率、知识可解释性、安全可靠性等。城市运维涉及到多个方面,如交通管理、环境监测、公共安全、社会治理等,它们需要处理和分析大量的多模态数据,如图像、视频、语音、文本等,并根据不同的场景和需求,提供合适的决策和响应。知识搜索有多种形式,如语义搜索、对话搜索、图像搜索、视频搜索等,它们可以根据用户的输入和意图,从海量的数据源中检索出最相关的信息,并以友好的方式呈现给用户。_ai大模型应用开发

非常详细的阻抗测试基础知识_阻抗实部和虚部-程序员宅基地

文章浏览阅读8.2k次,点赞12次,收藏121次。为什么要测量阻抗呢?阻抗能代表什么?阻抗测量的注意事项... ...很多人可能会带着一系列的问题来阅读本文。不管是数字电路工程师还是射频工程师,都在关注各类器件的阻抗,本文非常值得一读。全文13000多字,认真读完大概需要2小时。一、阻抗测试基本概念阻抗定义:阻抗是元器件或电路对周期的交流信号的总的反作用。AC 交流测试信号 (幅度和频率)。包括实部和虚部。​图1 阻抗的定义阻抗是评测电路、元件以及制作元件材料的重要参数。那么什么是阻抗呢?让我们先来看一下阻抗的定义。首先阻抗是一个矢量。通常,阻抗是_阻抗实部和虚部

小学生python游戏编程arcade----基本知识1_arcade语言 like-程序员宅基地

文章浏览阅读955次。前面章节分享试用了pyzero,pygame但随着想增加更丰富的游戏内容,好多还要进行自己编写类,从今天开始解绍一个新的python游戏库arcade模块。通过此次的《连连看》游戏实现,让我对swing的相关知识有了进一步的了解,对java这门语言也有了比以前更深刻的认识。java的一些基本语法,比如数据类型、运算符、程序流程控制和数组等,理解更加透彻。java最核心的核心就是面向对象思想,对于这一个概念,终于悟到了一些。_arcade语言 like

随便推点

【增强版短视频去水印源码】去水印微信小程序+去水印软件源码_去水印机要增强版-程序员宅基地

文章浏览阅读1.1k次。源码简介与安装说明:2021增强版短视频去水印源码 去水印微信小程序源码网站 去水印软件源码安装环境(需要材料):备案域名–服务器安装宝塔-安装 Nginx 或者 Apachephp5.6 以上-安装 sg11 插件小程序已自带解析接口,支持全网主流短视频平台,搭建好了就能用注:接口是公益的,那么多人用解析慢是肯定的,前段和后端源码已经打包,上传服务器之后在配置文件修改数据库密码。然后输入自己的域名,进入后台,创建小程序,输入自己的小程序配置即可安装说明:上传源码,修改data/_去水印机要增强版

verilog进阶语法-触发器原语_fdre #(.init(1'b0) // initial value of register (1-程序员宅基地

文章浏览阅读557次。1. 触发器是FPGA存储数据的基本单元2. 触发器作为时序逻辑的基本元件,官方提供了丰富的配置方式,以适应各种可能的应用场景。_fdre #(.init(1'b0) // initial value of register (1'b0 or 1'b1) ) fdce_osc (

嵌入式面试/笔试C相关总结_嵌入式面试笔试c语言知识点-程序员宅基地

文章浏览阅读560次。本该是不同编译器结果不同,但是尝试了g++ msvc都是先计算c,再计算b,最后得到a+b+c是经过赋值以后的b和c参与计算而不是6。由上表可知,将q复制到p数组可以表示为:*p++=*q++,*优先级高,先取到对应q数组的值,然后两个++都是在后面,该行运算完后执行++。在电脑端编译完后会分为text data bss三种,其中text为可执行程序,data为初始化过的ro+rw变量,bss为未初始化或初始化为0变量。_嵌入式面试笔试c语言知识点

57 Things I've Learned Founding 3 Tech Companies_mature-程序员宅基地

文章浏览阅读2.3k次。57 Things I've Learned Founding 3 Tech CompaniesJason Goldberg, Betashop | Oct. 29, 2010, 1:29 PMI’ve been founding andhelping run techn_mature

一个脚本搞定文件合并去重,大数据处理,可以合并几个G以上的文件_python 超大文本合并-程序员宅基地

文章浏览阅读1.9k次。问题:先讲下需求,有若干个文本文件(txt或者csv文件等),每行代表一条数据,现在希望能合并成 1 个文本文件,且需要去除重复行。分析:一向奉行简单原则,如无必要,绝不复杂。如果数据量不大,那么如下两条命令就可以搞定合并:cat a.txt >> new.txtcat b.txt >> new.txt……去重:cat new...._python 超大文本合并

支付宝小程序iOS端过渡页DFLoadingPageRootController分析_类似支付宝页面过度加载页-程序员宅基地

文章浏览阅读489次。这个过渡页是第一次打开小程序展示的,点击某个小程序前把手机的开发者->network link conditioner->enable & very bad network 就会在停在此页。比如《支付宝运动》这个小程序先看这个类的.h可以看到它继承于DTViewController点击左上角返回的方法- (void)back;#import "DTViewController.h"#import "APBaseLoadingV..._类似支付宝页面过度加载页

推荐文章

热门文章

相关标签