TorchScript & ONNX
Torchscript는 코드를 Eager mode에서 Script mode로 변환한다. Pytorch model을 TorchScript로 변환하는 방법에는 두가지 방법이 있다.
- Tracing 방식
- Script 방식
Tracing
Tracing 방법은 어떤 입력값(data instance)을 사용하여 모델의 구조를 파악하고, 이 입력값의 모델 안에서의 흐름을 통해 모델을 기록하는 방식이다.
조건문을 많이 사용하지 않는 모델의 경우 이 방식을 이용하여 변환하는 것이 적합하다.
또는 인풋 값에 따라 로직이 바뀌지 않는 경우 이 방법을 사용한다.
import torch
import torchvision
# An instance of your model.
model = torchvision.models.resnet18()
# An example input you would normally provide to your model's forward() method.
example = torch.rand(1, 3, 224, 224)
# Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
traced_script_module = torch.jit.trace(model, example)
Tracing 방식은 eager model의 코드를 재사용할 수 있는 효과적인 방법이다. 그러나 이 방식을 사용하면 Control-flow나 data structure, python construct가 보존되지 않는다. 따라서 이 방식을 사용하는 경우에는 항상 IR을 검사하여 pytorch model이 올바르게 동작하는지를 확인해줘야한다.
이러한 limitation을 해결하기 위해 Annotation 방식이 고안되었다.
Script
import torch
class MyModule(torch.nn.Module):
def __init__(self, N, M):
super(MyModule, self).__init__()
self.weight = torch.nn.Parameter(torch.rand(N, M))
def forward(self, input):
if input.sum() > 0:
output = self.weight.mv(input)
else:
output = self.weight + input
return output
input값에 따라 영향을 받는 Control-flow 를 사용하고 있기 때문에 tracing 기법은 적합하지 않다.
대신 torch.jit.script()
함수를 통해 모듈을 compile하여 ScriptModule
로 변환한다.
또한, 이 방식은 tracing mode와 다르게 data sample은 전달할 필요가 없다. 오직 model의 instance만 input으로 넣어주면 된다.
my_module = MyModule(10,20)
sm = torch.jit.script(my_module)
Mixing Tracing and Scripting
import torch
def foo(x, y):
return 2 * x + y
traced_foo = torch.jit.trace(foo, (torch.rand(3), torch.rand(3)))
@torch.jit.script
def bar(x):
return traced_foo(x, x)
Example : BERT
from transformers import BertTokenizer, BertModel
from time import perf_counter
import torch
import numpy as np
def timer(f,*args):
start = perf_counter()
f(*args)
return (1000 * (perf_counter() - start))
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', torchscript=True)
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = tokenizer.tokenize(text)
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = tokenizer.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
# torch CPU
native_model = BertModel.from_pretrained("bert-base-uncased")
print("Torch CPU")
print(np.mean([timer(native_model,tokens_tensor,segments_tensors) for _ in range(100)]))
# torch GPU
native_gpu = native_model.cuda()
tokens_tensor_gpu = tokens_tensor.cuda()
segments_tensors_gpu = segments_tensors.cuda()
print("Torch GPU")
print(np.mean([timer(native_gpu,tokens_tensor_gpu,segments_tensors_gpu) for _ in range(100)]))
# torch.jit.trace on CPU
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
print("TorchScript CPU")
print(np.mean([timer(traced_model, tokens_tensor, segments_tensors) for _ in range(100)]))
# torch.jit.trace on GPU
traced_model_gpu = torch.jit.trace(model.cuda(), [tokens_tensor.cuda(), segments_tensors.cuda()])
print("TorchScript GPU")
print(np.mean([timer(traced_model_gpu, tokens_tensor.cuda(), segments_tensors.cuda()) for _ in range(100)]))
ONNX
위의 설명된 방법으로 TorchScript를 만들었다면, torch.onnx.export
를 사용하여 onnx 모델로 변환할 수 있습니다.
torch.onnx.export
함수 안에서는 torch.nn.Module 보다 torch.jit.ScriptModule 을 인풋으로 요구하고 있습니다.
그래서 TorchScript 모델이 아닌 기본 torch 모델을 넣어도 내부적으로는 torch.jit.trace() 를 사용하여 ScriptModule로 변환하고, 다시 한번 onnx 변환을 하고 있습니다.
자세한 내용은 Tracing vs Scripting 을 참고해주세요.
참고
- PyTorch JIT and TorchScript
- How to Convert a Model from PyTorch to TensorRT and Speed Up Inference
- pytorch 모델 저장과 ONNX 사용
Leave a comment