4 minute read


ASR 모델은 크게 CTC, Seq2Seq (AED), RNN-T 세 가지로 분류할 수 있습니다.
해당 포스트에서는 각 모델들의 beam search 방법을 정리하고 기록합니다.


CTC

  • prefix search 라고도 부른다.
  • 마지막이 blank token 인지 아닌지로 나눠서 계산한다.
  • 예를 들어 <b>를 blank token이라고 했을 때, greedy에서는 a <b> <b>가 가장 확률값이 높을 수 있지만 prefix search에서는 a <b> b , a b b, a b <b> 등 확률 값을 모두 더 했더니 3번째 타임스텝에서 greedy의 결과인 a 보다 prefix search 결과인 ab라는 확률값이 더 높을 수 있다.
  • 이처럼 blank에 따라 결과값이 바뀔 수 있기 때문에 blank를 잘 고려해줘야 되는 것이 CTC 모델의 특징이다.

Seq2Seq (AED)

  • encoder가 먼저 한번에 계산되고 이를 이용하여 decoder에서 각 타임 스텝마다 하나씩 인퍼런스 하는 구조이다.
  • beam size=5 라고 가정하고 t=a 일때, topk 5개를 뽑고 t=(a+1)일 때 뽑았던 topk 5개를 모두 input으로 사용하여 또 topk를 뽑는다. 그러면 총 25개가 나올텐데 그중에서 누적확률로 가장 확률값이 높은 5개를 또 뽑는다. 그리고 또 t=(a+2)를 진행하는 형태이다.
  • 이렇게 beam search는 누적 확률 중에 가장 높은 beam 개수를 가지고 inference를 하기 때문에 deterministic 한 특징을 갖는다. non-deterministic 한 디코딩을 하기 위해서는 sampling 방법을 사용한다.
  • top-k sampling 같은 경우는 확률이 높은 순서대로 k개의 인덱스만 남겨두고 나머지는 -inf로 바꾼 다음에 softmax를 해서 0으로 바꾼다. 그러면 k개의 인덱스만 어떤 확률 값을 가지고 있을 것이고 (topk가 20일 때를 예로 들면 softmax를 했기 때문에 20개의 합이 1이 될 것이고 나머지는 0의 확률을 가지고 있을 것이다), 이때 idx_next = torch.multinomial(probs, num_samples=1) 를 하면 num_samples 만큼 확률값 기반으로 뽑는다.
  • 그러니까 top-k나 아래에 설명할 top-p sampling을 사용하면 확률 값을 weight로 하여 샘플링을 하기 때문에 랜덤성이 존재하게 되고, 그렇기 때문에 deterministic 하지 않은 것이다. 나도 처음에는 top-k개를 뽑는거면 beam search랑 똑같은 것 같은데 어떻게 다른지 의문을 가지고 있었는데 코드를 보면서 이해하게 되었다.
  • top-p sampling은 top-k sampling 에서 조금 더 발전한 것이다. 정적으로 k개를 정하지 말고, 확률 p값으로 계산하여 각 타임스텝에서 뽑힐 k개의 개수가 달라지도록 동적으로 sampling 하는 방법이다. 타임스텝 t에서 각 확률 값을 sort하고 torch.cumsum 을 사용하여 높은 확률부터 누적해서 더한다. 그 값들과 자기 확률과의 차이를 구해서 그 차가 p 이상이 되면 top-k에서 했던 방법처럼 0으로 바꿔주고 그렇지 않으면 그대로 유지한다. 결국 top-p sampling도 일정 확률값 차이가 나면 0으로 바꿔서 고려 대상에서 제외시키고 idx_next = torch.multinomial(probs, num_samples=1)을 사용해서 weight 기반 샘플링을 해주는 것이다.

RNN-T

  • 인코더의 타임스텝과 디코더의 타임스텝을 모두 고려해줘야하는 것이 특징이다.
  • 아래의 espnet 코드를 기준으로 설명하면, 인코더의 타임스텝을 기준으로 계산되고 blank가 나오면 B_hyps에 저장하고 다른 토큰들은 A_hyps에 저장한다. B_hyps에 beam size가 될 때까지 인코더의 타임스텝은 고정하고 디코더의 타임스텝만 돌아가면서 뽑는다. 이때 A_hyps의 가장 큰 확률값을 가진 토큰을 넣어 topk를 뽑고 그 중에 blank가 나오면 B_hyps에 저장하고 그렇지 않으면 위의 과정을 반복한다.
  • 그렇게 해서 B_hyps에 beam size까지 채워지면 중단하고 encoder 타임스텝을 한칸 이동하고, B_hyps에 있던 것을 A_hyps로 옮기고, B_hyps는 다시 빈 리스트로 만든다.
  • 즉, 인코더의 다음 스텝으로 갈 놈들만 남겨놓고 나머지는 디코더의 스텝만 돌려서 beam size 만큼을 찾아가는 과정이다.
  • espnet 구현체

      def recognize_beam(self, h, recog_args, rnnlm=None):
          """Beam search implementation for transformer-transducer.
          Args:
              h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
              recog_args (Namespace): argument Namespace containing options
              rnnlm (torch.nn.Module): language model module
          Returns:
              nbest_hyps (list of dicts): n-best decoding results
          """
          beam = recog_args.beam_size
          k_range = min(beam, self.odim)
          nbest = recog_args.nbest
          normscore = recog_args.score_norm_transducer
    
          B_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None}]
    
          for i, hi in enumerate(h):
              A_hyps = B_hyps
              B_hyps = []
    
              while True:
                  new_hyp = max(A_hyps, key=lambda x: x['score'])
                  A_hyps.remove(new_hyp)
    
                  ys = to_device(self, torch.tensor(new_hyp['yseq']).unsqueeze(0))
                  ys_mask = to_device(self, subsequent_mask(len(new_hyp['yseq'])).unsqueeze(0))
                  y, c = self.forward_one_step(ys, ys_mask, new_hyp['cache'])
    
                  ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)
    
                  for k in six.moves.range(self.odim):
                      beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                                  'yseq': new_hyp['yseq'][:],
                                  'cache': new_hyp['cache']}
    
                      if k == self.blank:
                          B_hyps.append(beam_hyp)
                      else:
                          beam_hyp['yseq'].append(int(k))
                          beam_hyp['cache'] = c
    
                          A_hyps.append(beam_hyp)
    
                  if len(B_hyps) >= k_range: // beam_size (W)
                      break
    
          if normscore:
              nbest_hyps = sorted(
                  B_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
          else:
              nbest_hyps = sorted(
                  B_hyps, key=lambda x: x['score'], reverse=True)[:nbest]
    
          return nbest_hyps
    

RNN-T beam search 개선

  1. state_beam을 도입했다. log space에서 hypothesis setB(다음 타임스텝으로 넘어갈 놈들)의 best hypothesis보다 state_beam + hypothesis setA의 best hypothesis이 더 확률이 낮은 경우, “hypothesis set B의 이미 존재하는 hypos” 들이, “hypothesis setA 로부터 확장될 hypothesises들”보다 이미 더욱 좋은 hypothesises이라고 가정해서 while loop 를 끝내 버려서 그 때까지의 hypothesis를 가지고 다음 타임스텝으로 간다.
  2. 또한 A에 추가되는 hypothesises의 개수를 제한하는 expand_beam을 도입했다. 기존에는 time=t에서 hypothesises setA에 추가할 때 beam size만큼 for loop를 돌려서 추가를 했다. 하지만 더 pruning을 하기 위해 해당 토큰의 확률값 >= 타임스텝에서 가장 높은 확률값(greedy) (blank 제외) - expand_beam 을 성립하는 것만 hypothesises setA에 추가한다. greedy한 확률값에서 expand_beam 정도까지만 허락해주고 그 밑은 그냥 넘어가겠다 라는 뜻이다. 두 개의 하이퍼파라미터를 도입하여 pruning을 진행했다.
  • espnet 구현체

      def recognize_beam_facebook(self, h, recog_args, rnnlm=None, state_beam, expand_beam):
          """facebook Beam search implementation for rnn-transducer.
          Args:
              h (torch.Tensor): encoder hidden state sequences (maxlen_in, Henc)
              recog_args (Namespace): argument Namespace containing options
              rnnlm (torch.nn.Module): language model module
              state_beam: ...
              expand_beam: ...
          Returns:
              nbest_hyps (list of dicts): n-best decoding results
          """
          beam = recog_args.beam_size
          k_range = min(beam, self.odim)
          nbest = recog_args.nbest
          normscore = recog_args.score_norm_transducer
    
          B_hyps = [{'score': 0.0, 'yseq': [self.blank], 'cache': None}]
    
          for i, hi in enumerate(h):
              A_hyps = B_hyps
              B_hyps = []
    
              while True:
                  new_hyp = max(A_hyps, key=lambda x: x['score'])
                  a_best_hyp = max(A_hyps, key=lambda x: x['score'])
                  b_best_hyp = max(B_hyps, key=lambda x: x['score'])
                    
                  if log(b_best_hyp) >= state_beam + log(a_best_hyp):
                      break
                    
                  A_hyps.remove(new_hyp)
    
                  ys = to_device(self, torch.tensor(new_hyp['yseq']).unsqueeze(0))
                  ys_mask = to_device(self, subsequent_mask(len(new_hyp['yseq'])).unsqueeze(0))
                  y, c = self.forward_one_step(ys, ys_mask, new_hyp['cache'])
    
                  ytu = torch.log_softmax(self.joint(hi, y[0]), dim=0)
    
                  best_prob = max(ytu[1:])
                    
                  for k in six.moves.range(self.odim):
                      if k == self.blank:
                          beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                                  'yseq': new_hyp['yseq'][:],
                                  'cache': new_hyp['cache']}
    
                          B_hyps.append(beam_hyp)
            
                      else:
                          if float(ytu[k]) >= log(best_prob) - expand_beam:
                              beam_hyp = {'score': new_hyp['score'] + float(ytu[k]),
                                  'yseq': new_hyp['yseq'][:],
                                  'cache': new_hyp['cache']}
    
                              beam_hyp['yseq'].append(int(k))
                              beam_hyp['cache'] = c
    
                              A_hyps.append(beam_hyp)
    
                  if len(B_hyps) >= k_range: // beam_size (W)
                      break
    
          if normscore:
              nbest_hyps = sorted(
                  B_hyps, key=lambda x: x['score'] / len(x['yseq']), reverse=True)[:nbest]
          else:
              nbest_hyps = sorted(
                  B_hyps, key=lambda x: x['score'], reverse=True)[:nbest]
    
          return nbest_hyps
    

Leave a comment