본문 바로가기
DeepLearning

[PyTorch] torch.load 시, 학습 환경과 달라 모델 로딩이 되지 않는 에러 (map_location으로 해결)

by daewooki 2022. 6. 24.
반응형

에러 발생 내용

Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

 

CUDA 환경에서 학습한 모델을 CPU로 가져와서 모델을 사용할 때 load_state_dict 함수를 사용해서 모델을 로딩한다.

이 때, load_state_dict(torch.load(model_path)) 위의 에러 메시지와 함께 모델이 정상적으로 올라가지 않을 때가 있다. 

 

이 때 에러 발생 내용에 나와있듯이 map_location 파라미터를 추가해서 모델을 사용하고자 하는 장비의 device를 정해주어야한다. 

 

1
2
3
4
5
6
7
import torch
 
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
model = Model()
model.to(device)
model.load_state_dict(torch.load(model_path, map_location=device)
cs
반응형

댓글