LLM.int8() 개요
LLM.int8()은 Large Language Model (LLM)의 계산 성능을 개선하기 위한 8-bit 양자화 방법이다. 기존의 8-bit 양자화 방법은 성능 저하가 발생하는 문제점이 있었는데, LLM.int8()은 이를 해결하여 LLM의 성능을 유지하면서도 계산 성능을 크게 향상시킬 수 있다
LLM.int8()의 핵심 요소는 vector-wise quantization과 mixed-precision decomposition이다.
vector-wise quantization은 텐서 당 여러 개의 scaling constant를 사용하여 outlier의 영향력을 줄이는 방법이다.
mixed-precision decomposition은 0.1%의 outlier만 16-bit로 나타내어지고, 99.9%의 값들은 8-bit로 matmul 계산이 되는 방법으로 성능에 영향을 최소화 한다.
LLM.int8()은 bitsandbytes 라이브러리를 통해 구현할 수 있다. bitsandbytes는 transformers, accelerate 등 여러 다른 라이브러리를 통해서 쓸 수 있도록 되어있기 때문에 확장성도 좋다고 할 수 있으므로, 본인에게 편한 라이브러리를 사용해보자.
간단한 모델을 int8로 변환하기
bitsandbytes를 사용하여 간단한 모델을 int8로 변환하는 방법은 다음과 같다.
pip install torch bitsandbytes
필요한 라이브러리를 import한다.
import torch
import torch.nn as nn
import bitsandbytes as bnb
테스트를 위해 간단한 linear 모델을 정의하도록 하자.
fp16_model = nn.Sequential(
nn.Linear(64, 64),
nn.Linear(64, 64)
)
정의했던 FP16 모델을 기반으로, Linear8bitLt을 활용하여 int8 모델을 재정의한다.
int8_model = nn.Sequential(
bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False),
bnb.nn.Linear8bitLt(64, 64, has_fp16_weights=False)
)
저장해둔 가중치를 int8 모델에 로드하면 끝이다.
int8_model.load_state_dict(torch.load("model.pt"))
int8_model = int8_model.to(0) # 여기에서 양자화가 진행
이를 통해 int8로 양자화된 모델에 input을 넣어 inference를 진행할 수 있다.
input_ = torch.randn((1, 64), dtype=torch.float16)
hidden_states = int8_model(input_.to(torch.device('cuda', 0)))
전체 코드는 다음과 같다.
이처럼 LLM.int8()은 LLM의 계산 성능을 크게 향상시킬 수 있는 효과적인 방법으로ㅡ bitsandbytes 라이브러리를 사용하여 간단하게 구현할 수 있다.
https://colab.research.google.com/drive/1v4m4uB0Q5rntkWWbj2E02Fuo8HXTEJgC?usp=sharing
'Bigdata' 카테고리의 다른 글
RNN - Python numpy 기초 코드 실습 (0) | 2024.06.08 |
---|---|
2023년 빅데이터 오픈소스 플랫폼 Top 3 (0) | 2024.06.08 |
LLM - Llama2(라마2) 모델 개인 노트북으로 실행하기(CPU기반) (1) | 2024.06.08 |
Σ σ, ς / 시그마(sigma) - 뜻과 읽는법 (0) | 2024.06.07 |
Hugging Face - model(허깅페이스 모델) download 3가지 방법 (0) | 2024.04.12 |