2 minute read

Uniform Length Batching in PyTorch

  • 전체 토큰 길이가 비슷한 인풋끼리 배치를 이루어주는 방식
  • 그냥 랜덤하게 배치를 묶어주면 길이가 한 데이터를 제외하고는 평균 길이가 10인데 한 데이터 길이가 100인 경우 나머지 데이터들도 모두 100까지 패딩을 채워줘야하는 비효율이 발생함
  • 실제로 랜덤하게 배치를 묶어버리면 예시보다도 더 큰 비효율이 발생하는 경우가 많음
  • 학습 속도 개선 및 장비 활용도 측면에서 길이가 비슷한 인풋끼리 배치를 묶어주게 되면 많이 개선됨
  • 유일한 단점은 배치는 최대한 비빔밥처럼 골고루 섞여야 좋다고 하는데, 길이가 비슷한 놈들끼리 묶인다는 바이어스를 주게되므로 성능에 영향을 미칠수도 있다는 의견도 꽤 있으나 학습 속도 및 장비 활용도 측면에서 상당한 이득을 주기 때문에 많이 사용되는 방식

PyTorch Dataset Class

class ExampleDataset(torch.utils.data.Dataset):
    def __init__(self, datas):
        super(ExampleDataset, self).__init__()
        ...
        ...
        self.inputs = datas["inputs"]
        self.labels = datas["labels"]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]
  • PyTorch Dataset 클래스는 input, label 페어를 __getitem__() 메서드로 리턴해주는 역할을 수행
  • 여기서 __getitem__()idx를 파라미터로 받게 되는데, 넘겨지는 인덱스에 따라 다른 데이터를 리턴해주는 방식
  • 이러한 점을 이용해서 넘겨지는 idx 순서만 셔플해주면 쉽게 데이터를 셔플해서 가져올 수 있음 (이러한 역할을 해주는 놈이 Sampler 클래스)
# 이렇게 0부터 차례차례 순서로 idx를 건네줄수도 있지만
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]

# 단순하게 idx 순서만 셔플함으로써 데이터를 가져오는 순서를 쉽게 셔플할 수 있음
[2, 10, 11, 13, 1, 15, 4, 14, 9, 5, 12, 8, 0, 7, 6, 3]

Bucketing Sampler

  • 데이터를 완전 무작위하게 셔플하는게 아니라, 어떤 기준으로 버킷에 담아두고 셔플시, 이 버킷을 셔플하는 방식의 Sampler
  • 이렇게 하게 되면 묶이는 배치는 유지하면서도 모델 학습에 들어가는 순서는 셔플이 되므로 어떤 기준으로 배치를 묶는 경우 많이 사용되는 방식
  • 구현은 아래처럼 간단하게 가능
    • self.bins를 2차원 list로 만들어서 안에 담긴 list들이 버킷 역할을 함.
class BucketingSampler(Sampler):
    def __init__(self, data_source: torch.utils.data.Dataset, batch_size=1):
        super(BucketingSampler, self).__init__(data_source)
        self.data_source = data_source
        ids = list(range(0, len(data_source)))
        self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]

    def __iter__(self):
        for ids in self.bins:
            yield ids

    def __len__(self):
        return len(self.bins)

    def shuffle(self, epoch):
        np.random.shuffle(self.bins)

Uniform Length Batching Sampler

  • Uniform Length Batching은 위의 Bucketing Sampler에서 토큰 길이를 기준으로 sorting 로직만 추가해주면 됨.
class UniformLengthBatchingSampler(Sampler):
    def __init__(self, data_source: torch.utils.data.Dataset, batch_size=1):
        super(UniformLengthBatchingSampler, self).__init__(data_source)
        self.data_source = data_source
        #
        # 여기에 토큰 길이 기준으로 sorting하는 로직만 추가해주면 끝
        #
        ids = list(range(0, len(data_source)))
        self.bins = [ids[i:i + batch_size] for i in range(0, len(ids), batch_size)]

    def __iter__(self):
        for ids in self.bins:
            yield ids

    def __len__(self):
        return len(self.bins)

    def shuffle(self, epoch):
        np.random.shuffle(self.bins)

collate_fn

  • PyTorch에서 배치로 묶인 인풋을 처리하는 로직이 담긴 함수
  • 보통 여기서 길이가 다른 인풋들에 대하여 padding을 추가하고 텐서형으로 변환하는 코드가 수행됨.
  • Uniform Length Batching을 수행하기 위해서는 Sampler에서 묶일 때까지 패딩을 추가하면 안 되기 때문에 collate_fn을 정의해주어야함.
  • 구현 예시
def collate_fn(batch):
    def seq_length_(p):
        return len(p[0])

    max_seq_sample = max(batch, key=seq_length_)[0]
    max_seq_size = len(max_seq_sample)

    batch_size = len(batch)

    input_ids = torch.zeros(batch_size, max_seq_size).fill_(0).long()
    attention_masks = torch.zeros(batch_size, max_seq_size).fill_(0).long()
    labels = torch.zeros(batch_size, max_seq_size).fill_(0).long()

    for idx in range(batch_size):
        sample = batch[idx]
        sample_input_ids = sample[0]
        sample_attention_masks = sample[1]
        sample_labels = sample[2]

        input_ids[idx].narrow(0, 0, len(sample_input_ids)).copy_(torch.LongTensor(sample_input_ids))
        attention_masks[idx].narrow(0, 0, len(sample_attention_masks)).copy_(torch.LongTensor(sample_attention_masks))
        labels[idx].narrow(0, 0, len(sample_labels)).copy_(torch.LongTensor(sample_labels))

    return {
        "input_ids": input_ids,
        "attention_mask": attention_masks,
        "labels": labels,
    }

How to apply?

  • 이렇게 정의한 Samplercollate_fnDataLoader 클래스 생성시 넘겨주면 적용이 됨.
trainset = ExampleDataset(train_datas)
sampler = UniformLengthBatchingSampler(trainset, batch_size=32)

train_loader = DataLoader(
    dataset=trainset,
    sampler=sampler,
    shuffle=True,
    collate_fn=collate_fn,
)

Leave a comment