HiFi-GAN 논문 리뷰
Abstract
- HiFi-GAN으로 좋은 품질의 오디오를 빠르게 합성할 수 있다.
- 오디오의 퀄리티를 높이는데 오디오의 주기적인 패턴을 모델링 한 것이 중요했다.
- 학습 때 보지 못했던 화자의 mel-spectrogram도 오디오로 잘 합성한다.
Introduction
- 기존에는 Autoregressive 모델이나 flow-based model을 사용하였는데, 최근에는 GAN(generative adversarial networks) 구조를 활용해 보코더(vocoder)의 음성 합성 속도와 메모리 효율을 높이는 시도가 있었다.
- GAN 구조의 모델들은 합성속도가 빨랐지만, 기존 방법들에 비해 음성 퀄리티가 조금씩 떨어졌다.
- 본 논문에서 제안한 HiFi-GAN은 여러 개의 sub-discriminator를 사용해 음성 퀄리티도 높이고, 합성 속도도 빠르게 할 수 있었다.
Model
Overview
- HiFi-GAN은 하나의 generator와 두 개의 discriminators로 구성되어있다.
- multi-scale discriminator(MSD)
- multi-period discriminator(MPD)
Generator
- 인풋으로 mel-spectrogram을 받고, 아웃풋으로 오디오를 반환
- 노이즈를 추가적인 인풋으로 사용하지 않았다. (noise augment X)
Multi-Receptive Field Fusion
- 각각의 residual block이 다른 kernel sizes와 dilation rates를 가져서 다양한 길이를 관찰할 수 있다.
- 여러 개의 조절할 수 있는 파라미터를 가지고 있어, 합성 퀄리티와 합성 시간을 trade-off로 조절할 수 있다.
코드로 이해하기가 더 쉬운 분들을 위해 간략한 코드도 첨부합니다.
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from utils import init_weights, get_padding
class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
self.conv_pre = weight_norm(Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3))
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
self.ups.append(weight_norm(
ConvTranspose1d(h.upsample_initial_channel//(2**i), h.upsample_initial_channel//(2**(i+1)),
k, u, padding=(k-u)//2)))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel//(2**(i+1))
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
self.ups.apply(init_weights)
self.conv_post.apply(init_weights)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i*self.num_kernels+j](x)
else:
xs += self.resblocks[i*self.num_kernels+j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2])))
])
self.convs1.apply(init_weights)
self.convs2 = nn.ModuleList([
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1))),
weight_norm(Conv1d(channels, channels, kernel_size, 1, dilation=1,
padding=get_padding(kernel_size, 1)))
])
self.convs2.apply(init_weights)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
Discriminator
- MelGAN 논문에서 제안한 multi-scale discriminator (MSD)와 본 논문에서 제안한 multi-period discriminator (MPD)를 사용합니다.
Multi-Period Discriminator
- 길이가 T인 1D 오디오를 주기 p로 나눠서 (T / p, p) 인 2D 데이터로 shape을 변경하고 2D convolution을 적용합니다.
- 모든 convolution layer에서 kernel size의 width를 1로 하여 독립적으로 주기적인 샘플만 처리하도록 하였습니다.
- [2, 3, 5, 7, 11] 주기를 사용하여 최대한 겹치지 않고 오디오의 다양한 부분을 봅니다.
class DiscriminatorP(torch.nn.Module):
def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
super(DiscriminatorP, self).__init__()
self.period = period
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv2d(1, 32, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(32, 128, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(128, 512, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(512, 1024, (kernel_size, 1), (stride, 1), padding=(get_padding(5, 1), 0))),
norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))),
])
self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), "reflect")
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(self):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
DiscriminatorP(period=2),
DiscriminatorP(period=3),
DiscriminatorP(period=5),
DiscriminatorP(period=7),
DiscriminatorP(period=11),
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
Multi-Scale Discriminator
- MPD가 주기로 나눠서 샘플들을 따로따로 처리했기 때문에, MSD는 오디오를 시간 순서대로 쭉 평가한다.
- MSD는 input scale을 다르게 한 3개의 sub-discriminator가 혼합한 것이다.
- 하나는 그냥 raw audio
- 두번째는 시간 축으로 1/2 average pooling 된 audio
- 세번째는 시간 축으로 1/4 average pooling 된 audio
class DiscriminatorS(torch.nn.Module):
def __init__(self, use_spectral_norm=False):
super(DiscriminatorS, self).__init__()
norm_f = weight_norm if use_spectral_norm == False else spectral_norm
self.convs = nn.ModuleList([
norm_f(Conv1d(1, 128, 15, 1, padding=7)),
norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)),
norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)),
norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
])
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))
def forward(self, x):
fmap = []
for l in self.convs:
x = l(x)
x = F.leaky_relu(x, LRELU_SLOPE)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(self):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
DiscriminatorS(use_spectral_norm=True),
DiscriminatorS(),
DiscriminatorS(),
])
self.meanpools = nn.ModuleList([
AvgPool1d(4, 2, padding=2),
AvgPool1d(4, 2, padding=2)
])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for i, d in enumerate(self.discriminators):
if i != 0:
y = self.meanpools[i-1](y)
y_hat = self.meanpools[i-1](y_hat)
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
Training Loss Terms
GAN Loss
- Generator와 discriminator는 LS-GAN처럼 non-vanishing gradient flow를 위해 기본적인 GAN의 binary cross-entropy 부분을 least square loss function로 바꿔서 학습했다.
- Discriminator는 ground truth sample을 1, generator로 만들어진 sample을 0으로 구분하도록 학습한다.
- Generator는 sample quality를 높여서 discriminator를 속이도록 학습한다.
- x는 ground truth 오디오, s는 ground truth 오디오의 mel-spectrogram
Mel-Spectrogram Loss
- Generator의 학습 효율을 높이고, 오디오의 quality를 높이기 위해 mel-spectrogram loss를 추가하였다.
- Generator가 합성한 오디오의 mel-spectrogram과 ground truth의 mel-spectrogram의 L1 distance이다.
Feature Matching Loss
- Generator가 합성한 오디오를 discriminator에 통과시킨 중간 feature값, ground truth를 discriminator에 통과시킨 중간 feature값을 L1 distance로 계산한 loss이다.
- MelGAN 논문에서 처음으로 제안하였고 성공적으로 적용되어 generator를 학습하는데 추가적으로 사용된다.
- T는 discriminator layer 수, D는 각각 i 번째 layer의 feature값, N은 i번째 layer에서 feature 개수이다.
Final Loss
- 최종적으로 위는 generator loss, 아래는 discriminator loss 이다.
Experiments
- 학습 데이터로 LJSpeech 사용
- 학습 때 보지 못했던 화자의 오디오 생성을 평가하기 위해 VCTK multi-speaker 데이터셋 사용
- generator는 하이퍼파라미터에 따라 V1, V2, V3 버전이 있다.
Results
Audio Quality and Synthesis Speed
- V1은 조금 무거운 편이지만, ground truth 오디오랑 MOS가 0.09 밖에 차이가 나지 않았다.
- V2는 조금 더 라이트하지만 합성 속도가 빨라서 대부분 V2로 많이 사용하는 것 같다.
- 또는 trade-off로 선택해서 사용하면 될 것 같다.
Ablation Study
- MPD, MSD, MRF 그리고 mel-spectrogram loss에 대한 ablation study를 진행했다.
- MSD를 제거했을 때는 미세하게 떨어졌지만, MPD를 제거하면 상당히 많이 점수가 떨어졌다.
- MRF와 mel-spectrogram loss를 제거했을 때도 점수가 떨어졌다.
- 심지어 MelGAN에 MPD를 추가하니까 점수가 많이 올랐다.
Generalization to Unseen Speakers
- unseen speaker의 발화도 잘 합성해낸다.
End-to-End Speech Synthesis
text to mel-spectrogram
,mel-spectrogram to waveform
인 end-to-end 음성 합성을 분석- Tacotron2 + HiFi-GAN 구조로 실험 진행.
- 실제 음성의 mel-spectrogram과 tacotron2의 mel-spectrogram을 비교해보았을 때, tacotron2가 생성한 mel-spectrogram이 noisy가 많았다.
- 그래서 tacotron2가 생성한 mel-spectrogram으로 fine-tuning을 하니까 MOS가 많이 올랐다.
- 반면에 WaveGlow는 점수가 비슷했다.
- WaveGlow는 fine-tuning한 것과 안한 것의 차이가 별로 없다.
- 하지만 HiFi-GAN에서는 fine-tuning을 했을 때, mel-spectrogram이 조금 더 noisy해졌지만, 음성의 퀄리티는 더 높아진 것을 확인할 수 있었다.
Conclusion
- HiFi-GAN 모델이 공개 되어있는 모델 중 음성 합성 퀄리티가 가장 좋고 합성 속도도 빠르다.
- 그 이유로는 음성의 다양한 period를 고려하는 MPD가 효과적이었다.
- unseen speaker에 대해서도 잘 합성한다.
- 같은 discriminator와 학습 메커니즘으로 3개의 generator를 목적에 맞게 학습할 수 있다.
Leave a comment