모델 경량화 기법: Pruning(가지치기) - 실전편
1. 프루닝 구현 방법
1.1 PyTorch에서의 프루닝 구현
PyTorch에서는 torch.nn.utils.prune 모듈을 통해 다양한 프루닝 기법을 구현할 수 있다. 주요 함수들은 다음과 같다:
- prune.random_unstructured: 무작위로 파라미터를 프루닝
- prune.l1_unstructured: L1 norm 기준으로 프루닝
- prune.ln_structured: Ln norm 기준으로 구조적 프루닝
- prune.global_unstructured: 전역 기준으로 프루닝
1.2 기본 프루닝 예시 코드
import torch
import torch.nn.utils.prune as prune
# 모델 정의
model = torch.nn.Linear(10, 10)
# 가중치의 30%를 L1 norm 기준으로 프루닝
prune.l1_unstructured(model, name='weight', amount=0.3)
# 프루닝된 가중치 확인
print(f"프루닝된 가중치 비율: {float(torch.sum(model.weight == 0)) / float(model.weight.nelement()):.2f}")
2. 고급 프루닝 기법
2.1 Iterative Magnitude Pruning (IMP) 구현
IMP는 다음과 같은 단계로 구현할 수 있다:
- 초기 모델 학습
- 프루닝 수행
- 가중치 재초기화
- 재학습
- 반복
def iterative_pruning(model, train_loader, prune_ratio, num_iterations):
# 초기 가중치 저장
initial_state_dict = copy.deepcopy(model.state_dict())
for iteration in range(num_iterations):
# 모델 학습
train_model(model, train_loader)
# 프루닝 수행
prune.global_unstructured(
parameters_to_prune,
pruning_method=prune.L1Unstructured,
amount=prune_ratio
)
# 가중치 재초기화
reset_weights(model, initial_state_dict)
2.2 Structured Pruning 예시
채널 단위 프루닝의 구현 예시:
def structured_channel_pruning(model, amount):
for name, module in model.named_modules():
if isinstance(module, torch.nn.Conv2d):
prune.ln_structured(
module,
name='weight',
amount=amount,
n=2,
dim=0 # 출력 채널 방향으로 프루닝
)
3. 프루닝 성능 평가
3.1 주요 평가 지표
- 모델 크기 감소율: (프루닝 후 크기) / (원본 크기)
- 연산량 감소율: (프루닝 후 FLOPs) / (원본 FLOPs)
- 정확도 변화: (프루닝 후 정확도) - (원본 정확도)
- 추론 시간: 실제 디바이스에서의 속도 향상 정도
3.2 성능 측정 코드
def evaluate_pruning(original_model, pruned_model, test_loader):
# 모델 크기 비교
original_size = get_model_size(original_model)
pruned_size = get_model_size(pruned_model)
# 정확도 비교
original_acc = evaluate(original_model, test_loader)
pruned_acc = evaluate(pruned_model, test_loader)
print(f"크기 감소율: {pruned_size/original_size:.2f}")
print(f"정확도 변화: {pruned_acc - original_acc:.2f}%")
4. 실전 팁과 주의사항
4.1 효과적인 프루닝을 위한 팁
- 점진적 프루닝: 한 번에 많은 양을 프루닝하지 않고 조금씩 진행
- 레이어별 차등: 중요한 레이어는 프루닝 비율을 낮게 설정
- 데이터 특성 고려: 데이터셋의 특성에 맞는 프루닝 방식 선택
- 하드웨어 고려: 목표 디바이스의 특성에 맞는 프루닝 구조 선택
4.2 주의사항
- 과도한 프루닝 주의: 성능 저하가 급격히 발생할 수 있다
- 재학습 시간 고려: 프루닝 후 재학습에 상당한 시간이 필요하다
- 메모리 효율성: 구조적 프루닝이 실제 속도 향상에 더 효과적일 수 있다
댓글남기기