DeepLearning
[PyTorch] torch.load 시, 학습 환경과 달라 모델 로딩이 되지 않는 에러 (map_location으로 해결)
daewooki
2022. 6. 24. 19:47
반응형
에러 발생 내용
Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU. |
CUDA 환경에서 학습한 모델을 CPU로 가져와서 모델을 사용할 때 load_state_dict 함수를 사용해서 모델을 로딩한다.
이 때, load_state_dict(torch.load(model_path)) 위의 에러 메시지와 함께 모델이 정상적으로 올라가지 않을 때가 있다.
이 때 에러 발생 내용에 나와있듯이 map_location 파라미터를 추가해서 모델을 사용하고자 하는 장비의 device를 정해주어야한다.
1
2
3
4
5
6
7
|
import torch
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Model()
model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device)
|
cs |
반응형