星辰大海 AI星辰大海 AI
首页
  • ChatGPT
  • Claude
  • Midjourney
  • Stable Diffusion
  • 大语言模型
  • 图像生成模型
  • 语音模型
Demo 示例
开发笔记
GitHub
首页
  • ChatGPT
  • Claude
  • Midjourney
  • Stable Diffusion
  • 大语言模型
  • 图像生成模型
  • 语音模型
Demo 示例
开发笔记
GitHub
  • 开发笔记

    • 开发笔记
    • Prompt 工程笔记
    • 模型微调笔记
    • API 使用笔记

模型微调笔记

什么是模型微调

模型微调(Fine-tuning)是在预训练模型的基础上,使用特定领域的数据进行进一步训练,使模型适应特定任务的过程。

微调方法

1. 全参数微调(Full Fine-tuning)

训练模型的所有参数:

优点:

  • 效果最好
  • 完全适应任务

缺点:

  • 需要大量显存
  • 训练时间长
  • 容易过拟合

2. LoRA (Low-Rank Adaptation)

只训练少量低秩矩阵:

优点:

  • 显存需求低
  • 训练速度快
  • 可以保存多个 LoRA 权重

缺点:

  • 效果可能略低于全参数微调

3. QLoRA

量化 + LoRA,进一步降低显存需求:

优点:

  • 显存需求极低
  • 可以在消费级 GPU 上运行

缺点:

  • 可能略微影响精度

LoRA 微调实践

环境准备

pip install peft transformers datasets accelerate bitsandbytes

基础 LoRA 微调

from peft import LoraConfig, get_peft_model, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer

# 加载模型
model_name = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto"
)

# 配置 LoRA
lora_config = LoraConfig(
    r=16,  # 低秩维度
    lora_alpha=32,  # 缩放参数
    target_modules=["q_proj", "v_proj"],  # 目标模块
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM
)

# 应用 LoRA
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

准备数据集

from datasets import load_dataset

# 加载数据集
dataset = load_dataset("your_dataset")

# 格式化数据
def format_prompt(example):
    return {
        "text": f"### 指令:\n{example['instruction']}\n\n### 回答:\n{example['output']}"
    }

dataset = dataset.map(format_prompt)

# Tokenization
def tokenize_function(examples):
    return tokenizer(
        examples["text"],
        truncation=True,
        max_length=512,
        padding="max_length"
    )

tokenized_dataset = dataset.map(tokenize_function, batched=True)

训练配置

from transformers import TrainingArguments, Trainer

training_args = TrainingArguments(
    output_dir="./lora_model",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=2e-4,
    fp16=True,
    logging_steps=10,
    save_steps=100,
    evaluation_strategy="steps",
    eval_steps=100,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

trainer.train()

保存和加载

# 保存
model.save_pretrained("./lora_model")

# 加载
from peft import PeftModel

base_model = AutoModelForCausalLM.from_pretrained(model_name)
lora_model = PeftModel.from_pretrained(base_model, "./lora_model")

QLoRA 微调

配置量化

from transformers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quantization_config,
    device_map="auto"
)

应用 LoRA

from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "v_proj"],
    lora_dropout=0.05,
)

model = get_peft_model(model, lora_config)

数据集准备

格式要求

{
    "instruction": "任务指令",
    "input": "输入内容(可选)",
    "output": "期望输出"
}

数据增强

def augment_data(example):
    # 同义词替换
    # 句式变换
    # 添加噪声
    return augmented_example

数据清洗

def clean_data(text):
    # 去除特殊字符
    # 统一格式
    # 过滤低质量数据
    return cleaned_text

训练技巧

1. 学习率调度

from transformers import get_linear_schedule_with_warmup

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=100,
    num_training_steps=1000
)

2. 梯度累积

training_args = TrainingArguments(
    gradient_accumulation_steps=4,  # 累积 4 步梯度
    per_device_train_batch_size=2,  # 实际 batch size = 2 * 4 = 8
)

3. 混合精度训练

training_args = TrainingArguments(
    fp16=True,  # 使用半精度
)

4. 检查点保存

training_args = TrainingArguments(
    save_steps=100,
    save_total_limit=3,  # 只保留最近 3 个检查点
)

评估方法

1. 困惑度(Perplexity)

from transformers import pipeline

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
# 计算困惑度

2. 人工评估

  • 生成样本
  • 人工评分
  • 统计分析

3. 任务特定指标

  • 分类任务:准确率、F1
  • 生成任务:BLEU、ROUGE
  • 对话任务:响应相关性

常见问题

以下是我在模型微调过程中遇到的常见问题:

Q: 显存不足?

解决方案:

  1. 使用 QLoRA
  2. 减小 batch size
  3. 使用梯度累积
  4. 使用 CPU offload

Q: 过拟合?

解决方案:

  1. 增加数据量
  2. 使用数据增强
  3. 降低学习率
  4. 增加 dropout
  5. 早停(Early Stopping)

Q: 训练不收敛?

解决方案:

  1. 检查学习率
  2. 检查数据质量
  3. 检查模型初始化
  4. 使用学习率调度

Q: 效果不好?

解决方案:

  1. 增加训练数据
  2. 调整 LoRA 参数(r, alpha)
  3. 尝试不同的 target_modules
  4. 增加训练轮数

最佳实践

  1. 数据质量: 高质量的数据是关键
  2. 逐步调优: 从小数据集开始,逐步增加
  3. 参数选择: r=16-64, alpha=32-128 是常用范围
  4. 评估验证: 定期评估,避免过拟合
  5. 版本管理: 保存不同版本的模型和配置

工具推荐

  • PEFT: Hugging Face 的微调库
  • Axolotl: 微调框架
  • LLaMA Factory: 易用的微调工具

相关资源

  • PEFT 文档
  • LoRA 论文
  • QLoRA 论文
在 GitHub 上编辑此页
Prev
Prompt 工程笔记
Next
API 使用笔记