基于 llama-Factory 动手实践 Llama 全参数 SFT 和 LoRA SFT

news/2025/2/8 20:01:08 标签: llama, AIGC, chatgpt, 深度学习

llamaFactory_Llama__SFT__0">一、llama-Factory:你的 Llama 模型 SFT 工厂

llama-Factory 是一个开源的、用户友好的工具,专门用于对 Llama 系列模型进行微调。它提供了简洁的界面和强大的功能,让你无需复杂的代码编写,就能轻松完成 Llama 模型的 SFT 任务,无论是 全参数微调 还是 参数高效的 LoRA 微调llama-Factory 都能轻松搞定。

llama-Factory 的优势:

  • 易于上手: 简洁的命令行界面,配置简单,即使是新手也能快速上手。
  • 功能强大: 支持全参数 SFT 和 LoRA SFT,满足不同场景的需求。
  • 高效便捷: 集成了常用的训练技巧,优化训练流程,提高效率。
  • 开源免费: 开放源代码,可以自由使用和定制。

二、准备环境:搭建你的 SFT 工作站

在开始实践之前,我们需要先搭建好运行 llama-Factory 的环境。

1. 硬件需求:

  • GPU: 建议使用 NVIDIA GPU,显存越大越好,不同模型大小对显存要求不同 (见经验信息)
  • 内存: 至少 16GB 内存。
  • 硬盘: 足够的硬盘空间,用于存放模型、数据集和训练结果。

2. 软件环境:

  • Python: 建议使用 Python 3.8 或更高版本。
  • CUDA: 如果使用 NVIDIA GPU,需要安装 CUDA 驱动和 Toolkit。
  • PyTorch: llama-Factory 基于 PyTorch 框架,需要安装 PyTorch。

llamaFactory_27">3. 安装 llama-Factory:

使用 pip 命令即可轻松安装 llama-Factory:

pip install llama-factory

安装完成后,可以通过以下命令验证是否安装成功:

llama-factory --version

三、数据准备:打造你的专属 SFT 数据集

SFT 的效果很大程度上取决于数据集的质量。我们需要准备一个符合 (Instruction, Input, Output) 格式的数据集。

1. 数据格式:

llama-Factory 接受 JSON 格式的数据集。每个数据样本包含以下字段:

  • instruction (必须): 指令描述,例如 “Summarize the following article”。
  • input (可选): 输入内容,例如文章内容。
  • output (必须): 期望的输出结果,例如文章摘要。

数据示例 (JSON 格式):

[
  {
    "instruction": "Summarize the following article.",
    "input": "Large language models (LLMs) are a type of artificial intelligence (AI) algorithm that uses deep learning techniques and massively large data sets to understand, summarize, generate and predict new content. LLMs are based on a type of neural network architecture called a transformer network, and are pre-trained on massive quantities of text data.",
    "output": "Large language models (LLMs) are AI algorithms using deep learning and large datasets to understand, summarize, generate, and predict content. They are based on transformer networks and pre-trained on vast text data."
  },
  {
    "instruction": "Translate the following English sentence to Chinese.",
    "input": "Hello, world!",
    "output": "你好,世界!"
  },
  {
    "instruction": "Write a Python function to calculate the Fibonacci sequence.",
    "input": "",
    "output": "```python\ndef fibonacci(n):\n  if n <= 1:\n    return n\n  else:\n    return fibonacci(n-1) + fibonacci(n-2)\n```"
  }
]

将你的数据集保存为 data.json 文件。

2. 数据加载:

llama-Factory 可以直接加载 JSON 格式的数据集。在训练命令中,通过 --dataset 参数指定数据集路径即可。

经验信息:SFT 数据量要求

SFT 数据集的大小对微调效果至关重要。数据量不足可能导致模型过拟合到训练数据,泛化能力不足;数据量过大则会增加训练时间和成本。以下是一些经验性的数据量建议,仅供参考,实际情况可能因任务复杂度、模型大小和数据质量而异

  • LoRA SFT:

    • 简单任务 (例如:情感分类、关键词提取): 几百到几千条样本可能就足够获得不错的效果。
    • 中等任务 (例如:问答、摘要): 建议至少几千到几万条样本。
    • 复杂任务 (例如:代码生成、创意写作): 可能需要数万甚至数十万条样本才能获得理想的效果。
  • 全参数 SFT: 由于全参数 SFT 会更新模型的所有参数,通常需要更多的数据才能充分发挥其潜力,并避免灾难性遗忘 (Catastrophic Forgetting)。在数据量充足的情况下,全参数 SFT 理论上可以达到更好的效果。

更小的 Llama 模型 (例如 7B, 13B) 相对于更大的模型 (例如 70B) 而言,通常对数据量的需求会相对少一些。

数据质量比数据量更重要! 高质量、多样化的 SFT 数据集是获得良好微调效果的关键。

四、全参数 SFT 实战:火力全开,深度微调

全参数 SFT 指的是更新模型的所有参数,这种方法能够充分利用 SFT 数据集的知识,让模型更好地适应特定任务。但全参数 SFT 需要更多的计算资源和时间。

1. 训练命令:

使用 llama-Factory 进行全参数 SFT 的基本命令如下:

llama-factory train \
    --model_name_or_path <你的 Llama 模型路径> \
    --dataset <你的数据集路径> \
    --output_dir <模型保存路径> \
    --template default \
    --finetuning_type full \
    --lora_rank 0 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-5 \
    --logging_steps 10 \
    --save_steps 100 \
    --fp16

参数解释:

  • --model_name_or_path: 必须,指定你的 Llama 模型的 Hugging Face 模型名称或本地路径。例如 meta-llama/Llama-2-7b-hf/path/to/your/llama-model
  • --dataset: 必须,指定你的 SFT 数据集 JSON 文件路径,例如 data.json
  • --output_dir: 必须,指定微调后模型保存的路径,例如 output/llama-full-sft
  • --template: 必须,指定 Prompt 模板,default 模板适用于大多数场景。
  • --finetuning_type: 必须,设置为 full,表示进行全参数 SFT。
  • --lora_rank: 设置为 0,禁用 LoRA。
  • --num_train_epochs: 训练轮数,根据数据集大小和收敛情况调整。
  • --per_device_train_batch_size: 每个 GPU 的 Batch Size,根据 GPU 显存调整。
  • --gradient_accumulation_steps: 梯度累积步数,用于增大有效 Batch Size。
  • --learning_rate: 学习率,通常设置为 2e-51e-5
  • --logging_steps: 日志打印间隔。
  • --save_steps: 模型保存间隔。
  • --fp16: 使用混合精度训练,加速训练并减少显存占用 (需要 GPU 支持 FP16)。

示例命令 (假设你使用 Llama-2-7b-hf 模型,数据集为 data.json,输出路径为 output/llama-full-sft):

llama-factory train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset data.json \
    --output_dir output/llama-full-sft \
    --template default \
    --finetuning_type full \
    --lora_rank 0 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 4 \
    --learning_rate 2e-5 \
    --logging_steps 10 \
    --save_steps 100 \
    --fp16

2. 资源需求:

全参数 SFT 对 GPU 显存要求较高,不同 Llama 模型大小对显存需求差异很大

  • Llama 7B: 至少 16GB 显存 (FP16 混合精度)。
  • Llama 13B: 至少 32GB 显存 (FP16 混合精度)。
  • Llama 30B/33B: 至少 64GB 显存 (FP16 混合精度)。
  • Llama 65B/70B: 100GB 甚至更多显存 (FP16 混合精度),可能需要多卡并行训练。

经验信息:模型大小选择

模型越大,效果理论上越好,但也更吃资源。 在选择 Llama 模型大小时,需要权衡你的任务需求和硬件资源:

  • 资源有限 (例如:消费级 GPU): 建议选择 Llama 7B 或 13B,并使用 LoRA SFT。
  • 资源充足 (例如:专业级 GPU 或 GPU 集群): 可以尝试 Llama 30B/33B 甚至更大的模型,并考虑全参数 SFT 或 LoRA SFT。

对于大多数常见的 SFT 任务,Llama 7B 或 13B 模型配合 LoRA SFT 已经能够取得不错的效果。 更大模型通常在需要更强推理能力和知识储备的复杂任务中才能体现出优势。

五、LoRA SFT 实战:参数高效,轻量微调

LoRA (Low-Rank Adaptation) 是一种参数高效的微调方法,它冻结预训练模型的参数,只训练少量新增的 LoRA 适配器层。LoRA SFT 能够在资源有限的情况下,快速有效地对 Llama 模型进行微调。

1. 训练命令:

使用 llama-Factory 进行 LoRA SFT 的基本命令与全参数 SFT 类似,只需要修改 --finetuning_type--lora_rank 参数:

llama-factory train \
    --model_name_or_path <你的 Llama 模型路径> \
    --dataset <你的数据集路径> \
    --output_dir <模型保存路径> \
    --template default \
    --finetuning_type lora \
    --lora_rank 8 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --learning_rate 1e-4 \
    --logging_steps 10 \
    --save_steps 100 \
    --fp16

参数修改:

  • --finetuning_type: 设置为 lora,表示进行 LoRA SFT。
  • --lora_rank: 设置为 LoRA 的秩 (rank),通常设置为 816。秩越大,模型参数越多,效果可能更好,但也更耗资源。
  • --learning_rate: LoRA SFT 通常可以使用更大的学习率,例如 1e-42e-4
  • --per_device_train_batch_size: LoRA SFT 对显存要求较低,可以适当增大 Batch Size。

示例命令 (假设你使用 Llama-2-7b-hf 模型,数据集为 data.json,输出路径为 output/llama-lora-sft):

llama-factory train \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --dataset data.json \
    --output_dir output/llama-lora-sft \
    --template default \
    --finetuning_type lora \
    --lora_rank 8 \
    --num_train_epochs 3 \
    --per_device_train_batch_size 8 \
    --gradient_accumulation_steps 2 \
    --learning_rate 1e-4 \
    --logging_steps 10 \
    --save_steps 100 \
    --fp16

2. 资源需求:

LoRA SFT 对 GPU 显存要求较低,即使是 8GB 显存的 GPU,例如 NVIDIA RTX 2080 Ti 或 RTX 3070,也可以进行 LoRA SFT。 对于 Llama 7B 和 13B 模型,LoRA SFT 通常可以在消费级 GPU 上流畅运行。

六、模型评估与推理

训练完成后,你可以在 llama-Factory 中加载微调后的模型进行评估和推理。

推理命令:

llama-factory gradio \
    --model_name_or_path <你的微调模型路径> \
    --template default

<你的微调模型路径> 替换为你全参数 SFT 或 LoRA SFT 的输出路径,例如 output/llama-full-sftoutput/llama-lora-sft

运行命令后,llama-Factory 会启动一个 Gradio Web UI,你可以在 Web 界面上与微调后的模型进行交互,测试模型的性能。

七、小结

选择哪种 SFT 方式?

  • 全参数 SFT: 效果更好,但资源需求高,训练时间长。适用于追求极致性能,且资源充足的场景。
  • LoRA SFT: 参数高效,资源需求低,训练速度快。适用于资源有限,需要快速迭代,或者需要部署在资源受限设备上的场景。

http://www.niftyadmin.cn/n/5845246.html

相关文章

【DeepSeek论文精读】3. DeepSeekMoE:迈向混合专家语言模型的终极专业化

欢迎关注[【AIGC论文精读】](https://blog.csdn.net/youcans/category_12321605.html&#xff09;原创作品 【DeepSeek论文精读】1. 从 DeepSeek LLM 到 DeepSeek R1 【DeepSeek论文精读】2. DeepSeek LLM&#xff1a;以长期主义扩展开源语言模型 【DeepSeek论文精读】3. DeepS…

使用DeepSeek的技巧笔记

来源&#xff1a;新年逼自己一把&#xff0c;学会使用DeepSeek R1_哔哩哔哩_bilibili 前言 对于DeepSeek而言&#xff0c;我们不再需要那么多的提示词技巧&#xff0c;但还是要有两个注意点&#xff1a;你需要理解大语言模型的工作原理与局限,这能帮助你更好的知道AI可完成任务…

Office/WPS接入DS等多个AI工具,开启办公新模式!

在现代职场中&#xff0c;Office办公套件已成为工作和学习的必备工具&#xff0c;其功能强大但复杂&#xff0c;熟练掌握需要系统的学习。为了简化操作&#xff0c;使每个人都能轻松使用各种功能&#xff0c;市场上涌现出各类办公插件。这些插件不仅提升了用户体验&#xff0c;…

全志A133 android10 thermal温控策略配置调试

一&#xff0c;功能介绍 Thermal简称热控制系统&#xff0c;其功能是通过temperature sensor&#xff08;温度传感器&#xff09;测量当前CPU、GPU等设备的温度值&#xff0c;然后根据此温度值&#xff0c;影响CPU、GPU等设备的调频策略&#xff0c;对CPU、GPU等设备的最大频率…

Python----Python高级(并发编程:协程Coroutines,事件循环,Task对象,协程间通信,协程同步,将协程分布到线程池/进程池中)

一、协程 1.1、协程 协程&#xff0c;Coroutines&#xff0c;也叫作纤程(Fiber) 协程&#xff0c;全称是“协同程序”&#xff0c;用来实现任务协作。是一种在线程中&#xff0c;比线程更加轻量级的存在&#xff0c;由程序员自己写程序来管理。 当出现IO阻塞时&#xff0c;…

了解 ALV 中的 field catalog (ABAP List Viewer)

在 ABAP 中&#xff0c;字段目录是使用 ALV &#xff08;ABAP List Viewer&#xff09; 定义内部表中的数据显示方式的关键元素。它提供对 ALV 中显示的字段的各种属性的控制&#xff0c;例如列标题、对齐方式、可见性、可编辑性等。关键概念&#xff1a; Field Catelog 字段目…

DeepSeek关联WPS使用指南与案例解析

在数字化办公时代&#xff0c;人工智能&#xff08;AI&#xff09;技术正深刻地改变着我们处理文档、分析数据和进行创意表达的方式。DeepSeek作为新兴的AI技术代表&#xff0c;与办公软件巨头WPS的结合&#xff0c;为用户带来了前所未有的高效办公体验。本教程将深入探讨如何将…

基于Flask的全国海底捞门店数据可视化分析系统的设计与实现

【FLask】基于Flask的全国海底捞门店数据可视化分析系统的设计与实现&#xff08;完整系统源码开发笔记详细部署教程&#xff09;✅ 目录 一、项目简介二、项目界面展示三、项目视频展示 一、项目简介 该系统系统采用Python语言结合Flask框架开发&#xff0c;利用Pandas、NumP…