본문 바로가기

개발/python

[pytorch] Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same

해당 오류는 model과 입력이 되는 data의 type이 달라서 발생하는 오류이다.

 

해당 오류의 대부분은 model이나 data가 하나는 gpu에 하나는 cpu에 올라와 있는 상태로 연산할 때 발생한다.

 

따라서 다음과 같은 사항을 체크하여 오류를 해결한다.

 

1. model을 gpu에 정확히 올렸는지,

model = torch.load("...")
model.to("cuda")

 

2. dataset의 data와 target(label)을 gpu에 정확히 올렸는지,

for data, target in trainloader:
    data, target = data.to(device), target.to(device)
    #...
    #training
    #...

 

3. gpu를 사용하지 않을 경우, model이나 dataset의 data와 target을 gpu에서 정확히 내렸는지,

model = model.detach().cpu()
for data, target in trainloader:
    data = data.detach().cpu()
    target = target.detach().cpu()