반응형
전체를 바꾸지 말고, 조금만 살짝 바꿔서 똑똑하게 만들자!
AI 모델은 원래 엄청나게 많은 숫자(파라미터)를 가지고 있어. 이걸 다 바꾸려면 시간이 오래 걸리고 컴퓨터 리소스도 많이 사용해야 한다.
그래서 LoRA는 이를 최소한으로 수정해서 최대의 효과를 내고자 하는 방법이라고 할 수 있다.
원래 LLM의 파인튜닝 에 대한 기본 수식은
h = x × W
결과=입력×W
x: 입력값 (예: "고양이 사진")
W: 원래 모델이 가진 숫자들 (무게라고도 해)
h: 결과값 (예: "이건 고양이야!")
LoRA를 쓰면 이렇게 바뀐다.
h = x × W + x × A × B
결과=입력×W+입력×A×B
W: 원래 고양이를 잘하는 AI의 지식
A, B: 강아지를 배운 작은 메모
LoRA의 장점
💾 메모리 절약 전체 모델이 아닌 일부 파라미터만 학습
⚡ 빠른 학습 적은 연산량으로 빠르게 파인튜닝 가능
🔄 재사용성 여러 태스크에 쉽게 적용 가능
🧩 모듈화 기존 모델 구조를 거의 변경하지 않음
# Comparing Standard Fine-Tuning and LoRA Fine-Tuning on SST-2 Dataset
In this notebook, we will compare the performance and training time of standard fine-tuning and LoRA fine-tuning using the 'distilbert-base-uncased' model on the SST-2 dataset.
We will use the Hugging Face Transformers, Datasets, and PEFT libraries for this task.
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType
import time
import numpy as np
import matplotlib.pyplot as plt
## Load Dataset
We will use the SST-2 dataset from the GLUE benchmark.
# Load dataset
dataset = load_dataset("glue", "sst2")
## Load Tokenizer
We will use the tokenizer from the 'distilbert-base-uncased' model.
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
## Tokenize Dataset
We will tokenize the dataset using the loaded tokenizer.
# Tokenize dataset
def tokenize_function(examples):
return tokenizer(examples["sentence"], padding="max_length", truncation=True)
tokenized_datasets = dataset.map(tokenize_function, batched=True)
## Load Model
We will load the 'distilbert-base-uncased' model for sequence classification.
# Load model
model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)
## Define Training Arguments
We will define the training arguments for both standard fine-tuning and LoRA fine-tuning.
# Define training arguments
training_args = TrainingArguments(
output_dir="./results",
evaluation_strategy="epoch",
per_device_train_batch_size=8,
per_device_eval_batch_size=8,
num_train_epochs=1,
weight_decay=0.01,
logging_dir="./logs",
logging_steps=10,
no_cuda=True # Use CPU
)
## Standard Fine-Tuning
We will perform standard fine-tuning on the model.
# Standard fine-tuning
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"]
)
start_time = time.time()
trainer.train()
standard_training_time = time.time() - start_time
# Evaluate standard model
standard_eval_results = trainer.evaluate()
## LoRA Fine-Tuning
We will perform LoRA fine-tuning on the model.
# LoRA fine-tuning
lora_config = LoraConfig(
r=8,
lora_alpha=32,
target_modules=["query", "value"],
lora_dropout=0.1,
bias="none",
task_type=TaskType.SEQ_CLS
)
lora_model = get_peft_model(model, lora_config)
lora_trainer = Trainer(
model=lora_model,
args=training_args,
train_dataset=tokenized_datasets["train"],
eval_dataset=tokenized_datasets["validation"]
)
start_time = time.time()
lora_trainer.train()
lora_training_time = time.time() - start_time
# Evaluate LoRA model
lora_eval_results = lora_trainer.evaluate()
## Training Time Comparison
We will compare the training time of standard fine-tuning and LoRA fine-tuning.
# Plot training time comparison
labels = ['Standard Fine-Tuning', 'LoRA Fine-Tuning']
training_times = [standard_training_time, lora_training_time]
x = np.arange(len(labels))
width = 0.35
fig, ax = plt.subplots()
rects1 = ax.bar(x - width/2, training_times, width, label='Training Time (s)')
ax.set_ylabel('Time (seconds)')
ax.set_title('Training Time Comparison')
ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.legend()
fig.tight_layout()
plt.show()
## Evaluation Results
We will print the evaluation results of both standard fine-tuning and LoRA fine-tuning.
# Print evaluation results
print("Standard Fine-Tuning Evaluation Results:", standard_eval_results)
print("LoRA Fine-Tuning Evaluation Results:", lora_eval_results)
반응형
'Bigdata' 카테고리의 다른 글
LangChain Messages 역활 (0) | 2025.06.10 |
---|---|
머신러닝 - 결정 트리(DecisionTree) 알고리즘 핵심 정리 (0) | 2024.12.29 |
머신러닝 - 로지스틱 회귀 분류 알고리즘 이해, 시그모이드 함수 (1) | 2024.12.28 |
머신러닝 - 데이터셋 표준화 (0) | 2024.12.27 |
머신러닝 - 선형 회귀 핵심 정리 (0) | 2024.12.27 |