PyO3: Add pytorch like .to() operator to candle.Tensor (#1100)

* add `.to()` operator

* Only allow each value to be provided once via `args` or `kwargs`
This commit is contained in:
Lukas Kreussel
2023-10-19 22:46:21 +02:00
committed by GitHub
parent 93c25e8844
commit 6684b7127a
3 changed files with 176 additions and 0 deletions

View File

@ -381,6 +381,11 @@ class Tensor:
Transposes the tensor.
"""
pass
def to(self, *args, **kwargs) -> Tensor:
"""
Performs Tensor dtype and/or device conversion.
"""
pass
def to_device(self, device: Union[str, Device]) -> Tensor:
"""
Move the tensor to a new device.