convert pytorch's tensor in Python API (#1172)

* convert pytorch's tensor

* separate tests for convert pytorch tensor
This commit is contained in:
andrew
2023-10-26 01:39:14 +07:00
committed by GitHub
parent 0acd16751d
commit 6a446d9d73
3 changed files with 43 additions and 0 deletions

View File

@ -0,0 +1,14 @@
import candle
import torch
# convert from candle tensor to torch tensor
t = candle.randn((3, 512, 512))
torch_tensor = t.to_torch()
print(torch_tensor)
print(type(torch_tensor))
# convert from torch tensor to candle tensor
t = torch.randn((3, 512, 512))
candle_tensor = candle.Tensor(t)
print(candle_tensor)
print(type(candle_tensor))