NumPyからPyTorchへの変換時におけるデータ型の統一方法

x = np.random.rand(5)
print(x)
torch_x = torch.tensor(x)
print(torch_x)
[0.07875406 0.96178392 0.64953209 0.12122955 0.92958997]
tensor([0.0788, 0.9618, 0.6495, 0.1212, 0.9296], dtype=torch.float64)

NumPyの配列をそのままPyTorchのテンソルに変換すると、元のdtypeが保持されます。

torch_x = torch.tensor(x.astype(np.float32))
print(torch_x)
tensor([0.0788, 0.9618, 0.6495, 0.1212, 0.9296])

dtype を指定するには、NumPyの配列の型を float32 に変換してからPyTorchのテンソルに変換する必要があります。