본문 바로가기
논문 리뷰

[논문 리뷰] Structural Pruning for Diffusion Models

by donyy 2024. 5. 22.

Abstract

  • 확산 확률 모델(Diffusion Probabilistic Models, DPMs)은 효과적이지만 훈련과 추론 동안 상당한 계산 비용을 수반함.
  • 이를 해결하기 위해, 기존의 확산 모델로부터 경량 확산 모델을 학습하는 효율적인 압축 방법인 Diff-Pruning을 제안: 비기여 확산 단계를 무시하고 중요한 가중치를 식별하기 위해 정보가 풍부한 그라디언트를 조합하는 가지치기된 시간 단계에 대한 테일러 전개(Taylor expansion)
  • Diff-Pruning의 이점
    1. 효율성: 기존 학습 비용의 10~20%로 FLOP 약 50% 감소
    2. 일관성: 가지치기된 확산 모델의 사전 학습된 모델과 일치하는 생성 행동 유지
  • 코드: https://github.com/VainF/Diff-Pruning

1. Introduction

DPMs의 성능은 뛰어나지만, 학습과 추론 시 상당한 계산 오버헤드를 수반하는 경우↑ = 성능과 효율성 사이의 trade-off

→ 자원이 제한된 환경에서 이러한 모델의 광범위한 적용에 있어 중요한 도전 과제

 

기존의 해결책

  • 아키텍처 개선, 학습 방법 최적화, 샘플링 가속화 등 세 가지 주요 주제를 중심으로 확산 모델을 개선하고자 함.
  • 일부 개선 효과를 보였으나, 여전히 문제 문제가 있음: 무거운 재훈련 없이 사전 학습된 모델을 효율적으로 재사용하고 사용자 지정할 수 있는 일반적인 압축 방법이 없다는 것 (the absence of a general compression method that enables the efficient reuse and customization of these pre-existing models without heavy re-training)
  • 과제 해결의 중요성: 사전 학습된 확산 모델의 잠재력 활용 + 다양한 도메인/작업에서의 광범위한 응용 촉진

연구의 의의: Diff-Pruning

  • 구조적 가지치기
    • 네트워크에서 불필요한 매개변수와 하위 구조를 제거 → 모델 크기 효과적으로 ↓
    • but DPMs에 구조적 가지치기를 적용하는 것은 어려움.
  • Diff-Pruning
    • 동기부여: 확산 과정의 다른 단계가 생성된 샘플에 가변적으로 기여한다는 것.
    • 핵심 내용: 가지치기된 시간 단계에 대한 테일러 전개로, 이미지 내용, 세부 사항, 가지치기하는 동안의 노이즈 확산 단계에서의 부정적인 영향을 균형 있게 맞춤.
      • 확산모델의 목표: 후기 시간 단계(t→T): 생성된 이미지의 고수준 내용(=퀄리티) 우선시 → 초기 시간 단계(t→0): 더 세부적인 내용을 정제
      • but 가지치기를 위한 테일러 전개 사용 시, 큰 t를 가진 노이즈 단계는 중요도 추정에 정보가 풍부한 그래디언트 제공X 압축 성능에 해 끼칠 수 있음 ⇒ 내용, 세부 사항 및 노이즈 간의 트레이드오프를 모델링하여 효율적이고 유연한 확산 모델 가지치기 알고리즘 제안
    • 결과: 모델의 생성 품질을 유지/개선하면서도 상당한 압축률을 달성함.
      • 효율성: 원래 모델의 10%의 학습 비용으로 FLOP 50% 압축함.
      • 일관성: 동일한 입력이 주어졌을 때, 압축된 모델과 사전 학습된 모델이 일관된 출력을 생성할 수 있음.

2. Ralted Works

Efficient Diffusion Models

  • 효율적인 확산 모델 기존 방법론
    1. 네트워크 아키텍처의 개선
    2. 일반적으로 U-Net 모델을 노이즈 제거기로 사용, 효율성은 계층적 디자인을 도입 or 새로운 잠재 공간 내에서 학습을 실행하여 향상
    3. 학습 절차의 향상
    4. 다른 시간 단계에 할당된 가중치를 조정하여 확산 학습을 가속화 패치 수준에서 확산 모델을 학습 → 학습 효율성↑
    5. 샘플링 가속화 일반적으로 확산 모델의 재학습을 필요로 하지 않는 샘플링의 효율성 강조 조기 종료, 증류 등의 방법으로 필요한 단계↓

Network Pruning

가지치기 방법론의 분류: 구조적 가지치기 & 비구조적 가지치기

  • 구조적 가지치기: 네트워크에서 매개변수와 하위 구조를 물리적으로 제거할 수 있는 능력
  • 비구조적 가지치기: 본질적으로 매개변수를 제로로 마스킹

네트워크 가지치기 연구의 대부분은 주로 판별 작업(특히 분류 작업)에 집중, 생성 작업에서의 압축은 많이 연구되지 않았음.

3. Diffusion Model Objectives

확산 모델 정의

데이터 분포 q(x)가 주어졌을 때, 확산 모델은 q(x)를 근사하기 위해 생성 분포 $\small p_\theta(x)$를 모델링하는 것을 목표로 함.

$$ p_{\theta}(x) = \int p_{\theta}(x_0:T) \, dx_1:T, \text{ where } p_{\theta}(x_0:T) := p(x_T) \prod_{t=1}^T p_{\theta}(x_{t-1} | x_t)

$$

  • 확산 모델의 전체적인 구조
  • $\small x_1, \ldots, x_T$: latent variables (잠재 변수) 확산 프로세스에서 시간 $\small t$ 에 따라 변하는 데이터의 상태를 나타냄.
  • $\small p(x_T)$: 초기 데이터 분포
  • $\small p_0(x_{t-1} | x_t)$: 조건부 분포 시간 $\small t$ 에서 $\small t-1$ 로의 변화를 모델링하며, 주로 가우스 분포를 사용하여 정의됨.

노이즈 예측 목표

$$ \mathcal{L}(\theta) := \mathbb{E}{t, x_0 \sim q(x), \epsilon \sim \mathcal{N}(0,1)} \left[ \| \epsilon - \epsilon{\theta}(\sqrt{\hat{\alpha}_t} x_0 + \sqrt{1 - \hat{\alpha}_t} \epsilon, t) \|^2 \right] $$

  • 노이즈 예측기의 학습 목표를 나타내는 수식
  • $\small \epsilon$ : 고정된 가우스 분포에서 추출된 랜덤 노이즈
  • $\small \epsilon_\theta$ : 학습된 노이즈 예측기
  • 목표는 원본 데이터 $\small x_0$ 에서 노이즈를 제거하는 것 → 이를 통해 학습 과정에서 정확한 데이터 재구성을 달성하고자 함

역확산 과정

$$ x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \alpha_t}} \epsilon_{\theta}(x_t, t) \right) + \sigma_t z, \text{ where } z \sim \mathcal{N}(0, I) \text{ for steps } t > 1 \text{ and } z = 0 \text{ for } t = 1. $$

  • $\small \alpha_t = 1 - \beta_t$: 시간에 따라 변하는 스케일 매개변수
  • $\small z$: 표준 정규 분포에서 추출된 노이즈
  • $\small x_0$에서 시작하여 점차적으로 노이즈를 제거함으로써 $\small x_0$을 재구성

4. Structrual Pruning for Diffusion Models

목표 설정

사전 학습된 확산 모델의 매개변수 $\small \theta$가 주어졌을 때 네트워크에서 하위 구조를 제거하여 경량화된 $\small \theta'$을 만드는 것

$$ \min_{\theta'} \left( L(\theta') - L(\theta) \right) \quad \text{s.t.} \quad \|\theta'\|_0 \leq s

$$

  • 매개변수 $\small \theta$ : 각 하위 구조 $\small \theta_i = [\theta_{i0}, \theta_{i1}, \dots, \theta_{iK}, ]$가 K scalar parameters를 포함하는 row 벡터인 간단한 2-D 행렬로 가정.
  • 구조적 가지치기는 원래 성능을 최대한 보존하는 $\small \theta'$의 희소 매개변수 행렬을 찾는 것을 목표로 함. = 가지치기를 통해 파라미터를 최소화하면서 손실 함수 𝐿에서 발생하는 변화를 최소화하는 문제
  • $\small |\theta'|_0$: 비제로(non-zero) 행 벡터 수를 세는 매개변수들의 L-0 norm을 의미
  • $\small s$ : 가지치기된 모델의 희소성을 나타냄.
  • 확산 모델의 반복적인 본성 때문에 훈련 목표 $\small L$은 $\small T$개의 상호 연결된 작업 $\small \{ L_1, L_2, \dots, L_T \}$의 조합으로 인식될 수 있음.
  • 각 작업은 서로에게 영향을 미치고 의존함 ⇒ 주로 단일 목표를 최적화하는 전통적인 가지치기 문제와는 다른 새로운 challenge 제시
  • 초반에 가지치기에서 각 손실 구성요소 $\small L_T$의 개별 기여 조사 → 확산 모델 가지치기를 위해 Diff-Pruning 제안

Taylor Expansion at $L_t$

구조적 가지치기를 위해 $\small L_T$의 기여를 모델링

$\small L_T$에서의 손실 방해를 선형적으로 근사하기 위해 테일러 전개 활용

$$ Lt(\theta') = Lt(\theta) + \nabla Lt(\theta)(\theta' - \theta) + O(\|\theta' - \theta\|^2) \\ \Rightarrow Lt(\theta') - Lt(\theta) = \nabla Lt(\theta)(\theta' - \theta) + O(\|\theta' - \theta\|^2) $$

  • Diff-Pruning은 가중치의 중요성을 평가하기 위해 가지치기된 시간 단계에서 테일러 전개 활용
  • 초기 단계에서는 색상과 가장자리와 같은 로컬 디테일에 초점 → 후반 단계에서는 객체와 형태와 같은 콘텐츠에 더 초점
  • 이를 위해 이진 가중치 αt ∈ {0, 1}을 사용하여 이러한 요소들 사이의 균형을 맞추는 임계값 방법 제안
  • $\small \theta_{ik}$ 의 중요성을 평가하려면 위의 식에서 $\small \theta'_{ik}$ = 0으로 설정하면 됨.

$$ I_t(\theta_{ik}, x) = |L_t(\theta|θ_{ik}=0) - L_t(\theta)| $$

$$ = |(θ_{i0} - θ_{i0}) \cdot \nabla θ_{i0} + \cdots + (0 - θ_{ik}) \cdot \nabla θ_{ik} + \cdots + (θ_{iK} - θ_{iK}) \cdot \nabla θ_{iK}| $$

$$ = |θ_{ik} \cdot \nabla θ_{ik} L_t(\theta, x)| $$

$$ I_t(\theta_{i}, x) = \sum_k |L_t(\theta|{\theta{ik}=0}) - L_t(\theta)| = \sum_k |θ_{ik} \cdot \nabla θ_{ik} L_t(\theta, x)| $$

The Contribution of $L_t$

  • 시간 단계 t에서 동일한 입력에 대해 예측 오류 $\small \delta_t = \epsilon_{\theta'}(x,t)$ 고려
  • Reverse process는 $\small \epsilon_{\theta'}(x,t)=\epsilon_\theta(x,t)+\delta_t$ 에서 시작해서 3번 수식의 과정 반복 → 시간 단계 t에서 생성된 이미지 $\small x_{0}$에서 $\small \theta'$에서 $\small \delta_{t\rightarrow0}$ 의 오류를 검토
  • 이 오류는 $\small x'{t-1}=x{t-1}+\delta_{t-1}$ 에 의해 주어진 다음 입력값에 직접적인 영향을 미침 ⇒ 왜곡이 점차 확대
  • t ↑: 체인 효과에 의해 이미지에 큰 영향
  • t ↓: 적은 수정을 통한 이미지 수정 (= 작은 영향0
  • 시간 단계 t에 대한 가중 변수로 작용하는 $\small \alpha_t$를 도입하여 서로 다른 시간 단계에 대한 절충안 도입 (model the pruning problem as a weighted trade-off between contents and details by introducing α_{t})
  • $\small \alpha_t$ 는 0 또는 1의 값을 갖도록 하여 모든 시간 단계에 대한 부분 단계를 허용

Taylor Score over Pruned Timesteps

  • 확산 프로세스에서 일부 "중요하지 않은" 시간 단계를 제거하고자 함.
  • 실험 결과
    • 콘텐츠 생성을 담당하는 시간 단계가 확산 프로세스의 끝에만 있는 것이 아님.
    • 전체 단계 목표를 사용하면 최적이 아닌 결과가 도출될 수 있음. (← 노이즈 단계에 수렴된 gradients 때문.)
  • 손실 $\small L_{t}$가 수렴하면 손실 곡선은 주로 고차 항목에 의해 영향을 받음.
  • $\small L_{t}$ 손실 항목이 t→T로 접근할 때 빠르게 감소
  • t가 클수록 예측 오류가 더 큰 영향을 미침 ⇒ 중요한 시간 단계를 포함하는 것이 필요

5. Experiments

5.1 Settings

Datasets and Models

  • 여섯 가지 다양한 데이터셋 사용: CIFAR-10(32×32), CelebA-HQ(64×64), LSUN Church(256×256), LSUN Bedroom(256×256), ImageNet-1K(256×256) 등
  • 사용된 DPMs 유형: Denoising Diffusion Probability Models (DDPMs)와 Latent Diffusion Models (LDMs)
  • (재현 가능성을 위해 [18]과 [41]에서 사전 훈련된 DPM들을 사용 → 이 모델들을 한 번에 가지치기)

Evaluation Metrics

  1. 효율성 메트릭: 매개변수 수(#Params)와 곱셈-추가 누적(Multiply-Add Accumulation, MACs) 포함
  2. 품질 메트릭: Frechet Inception Distance (FID)
  3. 일관성 메트릭: Structural Similarity (SSIM)
    • 동일한 노이즈 입력으로 사전 훈련된 모델과 가지치기된 모델이 생성한 이미지 간의 유사성을 평가
    • ImageNet에 대해서는 250단계 DDIM 샘플러, 다른 실험에는 100단계 DDIM 샘플러 사용

5.2 An Simple Benchmark for Diffusion Pruning

Scratch Training vs. Pruning

  • 스크래치 트레이닝
  • 수렴에 상대적으로 더 많은 단계 요구 ⇒ 스크래치에서의 훈련 비용이 사전 훈련된 모델과 유사 ⇒ 스크레치에서 경량 모델을 훈련하는 거이 그닥 효율적/경제적이지 않을 수 있음.
  • (모든) 가지치기 방법
  • 약 100K 단계 내에 수렴 가능 & FID 및 SSIM 점수가 스크래치 트레이닝보다 ↑

Pruning Criteria

네트워크 가지치기에서 중요한 것 = 네트워크 내 불필요한 매개변수를 식별하는 가지치기 기준을 설정하는 것

세 가지 기본 가지치기 방법 적용:

  1. Random 무작위 가지치기편견이 없는 간단한 기준선을 제공하고 가지치기에 대한 경쟁적인 기준선 제공.
  2. 각 $\small \theta_i$ 에 대해 균일 분포에서 파생된 importance score 무작위 할당
  3. Magnitude 크기 기반 가지치기cf) 수식 $\small I(\theta) = |\theta|$ 의 의미
    • 매개변수 $\small \theta$ 의 중요도를 그 매개변수의 절대값으로 정의 즉, 매개변수의 크기가 클수록 그 매개변수는 모델에 있어 더 중요하다고 간주함.
    • 매우 간단하다는 장점, 매개변수의 크기만을 고려하여 중요도 평가함.
  4. "norm이 작을수록 정보가 적다"는 가설 → 중요도를 $\small I(\theta) = |\theta|$ 로 모델링
  5. Taylor 테일러 기반 가지치기cf) 수식 $\small I(\theta) = |\theta|$ 의 의미
    • 매개변수 $\small \theta$ 의 중요도를 평가할 때 매개변수의 크기뿐만 아니라, 그 매개변수가 손실 함수 $\small L(x,\theta)$의 변화율에 미치는 영향도 함께 고려함.
    • $\small \nabla_\theta L(x, \theta)$: 손실 함수에 대한 $\small \theta$ 의 그래디언트(미분 값) 의미 매개변수가 모델의 성능에 얼마나 영향을 미치는지를 나타냄.
  6. 중요도를 $\small I(\theta, x) = |\theta| \cdot \nabla_\theta L(x, \theta)$ 로 측정 Diff-Pruning 방법에서 논의된 손실 변화를 최소화하는 데이터 중심 기준
  • 세 가지 기준 방법:
  • 두 데이터셋에서 일관된 순위를 유지X - Magnitude는 CIFAR-10에서 최고의 FID 성능 달성 but CelebA 데이터셋에서는 성능↓
  • 해당 연구 방법: 기존 방법들에 비해 두 데이터셋 모두에서 안정적인 개선 CelebA-HQ에서는 100K 최적화만으로도 사전 훈련 모델 능가 but CIFAR-10에서는 더 복잡한 장면과 더 많은 범주로 인해 성능 저하 발생

5.3 Pruning at Higher Resolutions

DDPMs on LSUM

LSUN Church와 LSUN Bedroom이라는 두 개의 256×256 해상도 장면 데이터셋으로Diff-Pruning의 효율성 및 효과 입증을 위한 가지치기 실험 수행

  • 사전 훈련된 모델: 각각 약 2.4M과 4.4M 훈련 단계 필요 (=많은 시간 소요)
  • 가지치기된 모델: 표준 단계의 10%에 해당하는 Steps에서 통과할 만한 FID 점수로 수렴
  • Scratch Training: 심각하게 부적합
  • but 대규모 데이터셋에서 훈련된 모델을 압축하는 것은 여전히 어려움.
  • 더 많은 훈련 단계로 FID 점수를 추가로 개선 가능.
  • 가지치기된 모델은 대부분의 매개변수를 사전 훈련된 모델로부터 상속받으므로 유사한 이미지 생성 능력 보존 가능.

Conditional LDMs on ImageNet

ImageNet-1K에서 사전 훈련된 LDM의 가지치기 결과

  • LDM은 인코더, 디코더, 그리고 U-Net 모델로 구성 → 약 400M의 매개변수는 U-Net 아키텍처에서, 55M개는 오토인코더에서 나옴 ⇒ 주로 U-Net 모델의 가지치기에 초점
  • 수렴된 레이어 무시 & 가지치기 과정의 효율성 증대 위해 임계값 T = 0.1 사용. ⇒ 534단계만이 가지치기 과정에 참여
  • 중요도 추정 → 모든 레이어에 30%의 pre-defined channel sparsity (사전 정의적 채널 희소성?) 적용 → 189.43M 매개변수를 가진 경량화된 U-Net
  • 가지치기가 된 모델 fine-tuning

5.4 Ablation Study

Pruned Timesteps

더 큰 gradients를 가진 단계 우선시 + 가능한 한 많은 내용과 세부 정보 보전 노력

→ 더 정확하고 효율적인 pruning을 가능하게 함.

The impacts of timestep pruning

  • 사전 훈련된 확산 모델을 50단계에서 1000단계에 걸쳐 가지치기하며, 가지치기로 인한 왜곡을 측정하기 위해 SSIM 지표를 사용
  • 초기 단계(t → 0)가 나중 단계(t → T)에 비해 일반적으로 더 큰 그라디언트 (더 큰 기울기) 를 가지며, 이러한 기울기는 단계 수가 많을수록 수렴에 도달
  • CIFAR-10 데이터셋에서는 250단계에서 최적의 SSIM 점수가 나타났으며, 더 많은 단계를 추가할 경우 이미지 품질이 소폭 저하 ← 수렴 지점에서 일차 테일러 전개의 부정확성 때문으로 이때의 기울기는 더 이상 유용한 정보 제공xX,누적을 통해 유익한 기울기 왜곡
  • CelebA 데이터셋에서는 상황이 약간 다르며, 더 많은 단계가 중요도 추정에 도움이 될 수 있음을 관찰함.

Pruning Ratios

  • 판별 모델에서 사용되는 CNN과 달리, 확산 모델은 모델 크기 변화에 상당히 민감 16%라는 적은 절단 비율에도 FID 점수에 눈에 띄는 저하 초래(4.19 → 4.62)
  • 분류 작업에서는 손실의 변동이 예측 신뢰도를 악화시킬 뿐, 반드시 최종 정확도에 영향을 미치지는 것은X
  • but 생성 모델에서는 FID 점수가 매우 민감 ⇒ 도메인 이동에 더 취약

Thresholding

  • $\tau$ = 0: 모든 단계에서 완전한 테일러 전개를 의미 T > 0: 중요도 추정 동안 특정 타임스텝 절단을 의미
  • 표 5에서 제시된 정량적 결과는 그림 5에서 보여진 SSIM 결과와 일치
  • Diff-Pruning은 생성된 이미지의 품질이 최상일 때 최적의 성능 발휘
  • CIFAR-10과 같은 데이터셋에서는 약 200단계의 테일러 전개가 만족스러운 결과를 달성하기에 충분 → 이 경우 오히려 전체 테일러 전개를 사용하는게 ⊖: 약 700단계에 걸친 잡음 그라디언트를 누적시켜 초기 단계의 정확한 그라디언트 정보를 흐리게 하기 때문.

Visualization of Different Importance Criteria

T > 0을 사용한 Diff-Pruning 방법: 가지치기 후 SSIM=0.905으로 뛰어난 시각적 품질 달성

Diff-Pruning에서는 더 많은 시간 단계를 사용하는 것이 오히려 부정적인 영향 - 질감과 내용 모두에서 더 큰 왜곡을 초래할 수 있음.

6. Conclusion

Diff-Pruning

  • 특징: 가지치기된 시간 단계에 대한 테일러 전개를 활용하여 중요하지 않은 매개변수를 식별하고 제거
  • 결과:
    • 사전 훈련된 모델로부터 가벼우면서도 일관된 모델 생성 가능
    • 사전 훈련에 비해 약 10%~20%의 적은 비용 발생
  •