numpyからpytorchの変換で,dtypeを消す
x = np.random.rand(5)
print(x)
torch_x = torch.tensor(x)
print(torch_x)
[0.07875406 0.96178392 0.64953209 0.12122955 0.92958997]
tensor([0.0788, 0.9618, 0.6495, 0.1212, 0.9296], dtype=torch.float64)
numpyの配列を,そのままtorchに変換するとdtypeが残る.
torch_x = torch.tensor(x.astype(np.float32))
print(torch_x)
tensor([0.0788, 0.9618, 0.6495, 0.1212, 0.9296])
消すためには,numpyの型をfloat32にしてからtorchに変換する必要がある.
Tags: