mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00
convert pytorch's tensor in Python API (#1172)
* convert pytorch's tensor * separate tests for convert pytorch tensor
This commit is contained in:
14
candle-pyo3/test_pytorch.py
Normal file
14
candle-pyo3/test_pytorch.py
Normal file
@ -0,0 +1,14 @@
|
||||
import candle
|
||||
import torch
|
||||
|
||||
# convert from candle tensor to torch tensor
|
||||
t = candle.randn((3, 512, 512))
|
||||
torch_tensor = t.to_torch()
|
||||
print(torch_tensor)
|
||||
print(type(torch_tensor))
|
||||
|
||||
# convert from torch tensor to candle tensor
|
||||
t = torch.randn((3, 512, 512))
|
||||
candle_tensor = candle.Tensor(t)
|
||||
print(candle_tensor)
|
||||
print(type(candle_tensor))
|
Reference in New Issue
Block a user