개발/python
[pytorch] Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
발전가
2022. 10. 31. 23:08
해당 오류는 model과 입력이 되는 data의 type이 달라서 발생하는 오류이다.
해당 오류의 대부분은 model이나 data가 하나는 gpu에 하나는 cpu에 올라와 있는 상태로 연산할 때 발생한다.
따라서 다음과 같은 사항을 체크하여 오류를 해결한다.
1. model을 gpu에 정확히 올렸는지,
model = torch.load("...")
model.to("cuda")
2. dataset의 data와 target(label)을 gpu에 정확히 올렸는지,
for data, target in trainloader:
data, target = data.to(device), target.to(device)
#...
#training
#...
3. gpu를 사용하지 않을 경우, model이나 dataset의 data와 target을 gpu에서 정확히 내렸는지,
model = model.detach().cpu()
for data, target in trainloader:
data = data.detach().cpu()
target = target.detach().cpu()