1 분 소요

1. 프루닝 구현 방법

1.1 PyTorch에서의 프루닝 구현

PyTorch에서는 torch.nn.utils.prune 모듈을 통해 다양한 프루닝 기법을 구현할 수 있다. 주요 함수들은 다음과 같다:

  • prune.random_unstructured: 무작위로 파라미터를 프루닝
  • prune.l1_unstructured: L1 norm 기준으로 프루닝
  • prune.ln_structured: Ln norm 기준으로 구조적 프루닝
  • prune.global_unstructured: 전역 기준으로 프루닝

1.2 기본 프루닝 예시 코드

import torch
import torch.nn.utils.prune as prune

# 모델 정의
model = torch.nn.Linear(10, 10)

# 가중치의 30%를 L1 norm 기준으로 프루닝
prune.l1_unstructured(model, name='weight', amount=0.3)

# 프루닝된 가중치 확인
print(f"프루닝된 가중치 비율: {float(torch.sum(model.weight == 0)) / float(model.weight.nelement()):.2f}")

2. 고급 프루닝 기법

2.1 Iterative Magnitude Pruning (IMP) 구현

IMP는 다음과 같은 단계로 구현할 수 있다:

  1. 초기 모델 학습
  2. 프루닝 수행
  3. 가중치 재초기화
  4. 재학습
  5. 반복
def iterative_pruning(model, train_loader, prune_ratio, num_iterations):
    # 초기 가중치 저장
    initial_state_dict = copy.deepcopy(model.state_dict())
    
    for iteration in range(num_iterations):
        # 모델 학습
        train_model(model, train_loader)
        
        # 프루닝 수행
        prune.global_unstructured(
            parameters_to_prune,
            pruning_method=prune.L1Unstructured,
            amount=prune_ratio
        )
        
        # 가중치 재초기화
        reset_weights(model, initial_state_dict)

2.2 Structured Pruning 예시

채널 단위 프루닝의 구현 예시:

def structured_channel_pruning(model, amount):
    for name, module in model.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            prune.ln_structured(
                module,
                name='weight',
                amount=amount,
                n=2,
                dim=0  # 출력 채널 방향으로 프루닝
            )

3. 프루닝 성능 평가

3.1 주요 평가 지표

  • 모델 크기 감소율: (프루닝 후 크기) / (원본 크기)
  • 연산량 감소율: (프루닝 후 FLOPs) / (원본 FLOPs)
  • 정확도 변화: (프루닝 후 정확도) - (원본 정확도)
  • 추론 시간: 실제 디바이스에서의 속도 향상 정도

3.2 성능 측정 코드

def evaluate_pruning(original_model, pruned_model, test_loader):
    # 모델 크기 비교
    original_size = get_model_size(original_model)
    pruned_size = get_model_size(pruned_model)
    
    # 정확도 비교
    original_acc = evaluate(original_model, test_loader)
    pruned_acc = evaluate(pruned_model, test_loader)
    
    print(f"크기 감소율: {pruned_size/original_size:.2f}")
    print(f"정확도 변화: {pruned_acc - original_acc:.2f}%")

4. 실전 팁과 주의사항

4.1 효과적인 프루닝을 위한 팁

  1. 점진적 프루닝: 한 번에 많은 양을 프루닝하지 않고 조금씩 진행
  2. 레이어별 차등: 중요한 레이어는 프루닝 비율을 낮게 설정
  3. 데이터 특성 고려: 데이터셋의 특성에 맞는 프루닝 방식 선택
  4. 하드웨어 고려: 목표 디바이스의 특성에 맞는 프루닝 구조 선택

4.2 주의사항

  • 과도한 프루닝 주의: 성능 저하가 급격히 발생할 수 있다
  • 재학습 시간 고려: 프루닝 후 재학습에 상당한 시간이 필요하다
  • 메모리 효율성: 구조적 프루닝이 실제 속도 향상에 더 효과적일 수 있다

댓글남기기