개발/python

[pytorch] Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

발전가 2022. 10. 31. 23:08

해당 오류는 model과 입력이 되는 data의 type이 달라서 발생하는 오류이다.

 

해당 오류의 대부분은 model이나 data가 하나는 gpu에 하나는 cpu에 올라와 있는 상태로 연산할 때 발생한다.

 

따라서 다음과 같은 사항을 체크하여 오류를 해결한다.

 

1. model을 gpu에 정확히 올렸는지,

model = torch.load("...")
model.to("cuda")

 

2. dataset의 data와 target(label)을 gpu에 정확히 올렸는지,

for data, target in trainloader:
    data, target = data.to(device), target.to(device)
    #...
    #training
    #...

 

3. gpu를 사용하지 않을 경우, model이나 dataset의 data와 target을 gpu에서 정확히 내렸는지,

model = model.detach().cpu()
for data, target in trainloader:
    data = data.detach().cpu()
    target = target.detach().cpu()