mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Always broadcast magic methods (#1101)
This commit is contained in:
@ -536,7 +536,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __add__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 + &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_add(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 + rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
@ -553,7 +553,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __mul__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 * &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_mul(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 * rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
@ -570,7 +570,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 - &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_sub(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 - rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
@ -583,7 +583,7 @@ impl PyTensor {
|
||||
/// &RETURNS&: Tensor
|
||||
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||
(&self.0 / &rhs.0).map_err(wrap_err)?
|
||||
self.0.broadcast_div(&rhs.0).map_err(wrap_err)?
|
||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||
(&self.0 / rhs).map_err(wrap_err)?
|
||||
} else {
|
||||
|
@ -1,5 +1,6 @@
|
||||
import candle
|
||||
from candle import Tensor
|
||||
import pytest
|
||||
|
||||
|
||||
def test_tensor_can_be_constructed():
|
||||
@ -72,3 +73,75 @@ def test_tensor_can_be_scliced_3d():
|
||||
assert t[:, 0, 0].values() == [1, 9]
|
||||
assert t[..., 0].values() == [[1, 5], [9, 13]]
|
||||
assert t[..., 0:2].values() == [[[1, 2], [5, 6]], [[9, 10], [13, 14]]]
|
||||
|
||||
|
||||
def test_tensor_can_be_added():
|
||||
t = Tensor(42.0)
|
||||
result = t + t
|
||||
assert result.values() == 84.0
|
||||
result = t + 2.0
|
||||
assert result.values() == 44.0
|
||||
a = candle.rand((3, 1, 4))
|
||||
b = candle.rand((2, 1))
|
||||
c_native = a.broadcast_add(b)
|
||||
c = a + b
|
||||
assert c.shape == (3, 2, 4)
|
||||
assert c.values() == c_native.values()
|
||||
with pytest.raises(ValueError):
|
||||
d = candle.rand((3, 4, 5))
|
||||
e = candle.rand((4, 6))
|
||||
f = d + e
|
||||
|
||||
|
||||
def test_tensor_can_be_subtracted():
|
||||
t = Tensor(42.0)
|
||||
result = t - t
|
||||
assert result.values() == 0
|
||||
result = t - 2.0
|
||||
assert result.values() == 40.0
|
||||
a = candle.rand((3, 1, 4))
|
||||
b = candle.rand((2, 1))
|
||||
c_native = a.broadcast_sub(b)
|
||||
c = a - b
|
||||
assert c.shape == (3, 2, 4)
|
||||
assert c.values() == c_native.values()
|
||||
with pytest.raises(ValueError):
|
||||
d = candle.rand((3, 4, 5))
|
||||
e = candle.rand((4, 6))
|
||||
f = d - e
|
||||
|
||||
|
||||
def test_tensor_can_be_multiplied():
|
||||
t = Tensor(42.0)
|
||||
result = t * t
|
||||
assert result.values() == 1764.0
|
||||
result = t * 2.0
|
||||
assert result.values() == 84.0
|
||||
a = candle.rand((3, 1, 4))
|
||||
b = candle.rand((2, 1))
|
||||
c_native = a.broadcast_mul(b)
|
||||
c = a * b
|
||||
assert c.shape == (3, 2, 4)
|
||||
assert c.values() == c_native.values()
|
||||
with pytest.raises(ValueError):
|
||||
d = candle.rand((3, 4, 5))
|
||||
e = candle.rand((4, 6))
|
||||
f = d * e
|
||||
|
||||
|
||||
def test_tensor_can_be_divided():
|
||||
t = Tensor(42.0)
|
||||
result = t / t
|
||||
assert result.values() == 1.0
|
||||
result = t / 2.0
|
||||
assert result.values() == 21.0
|
||||
a = candle.rand((3, 1, 4))
|
||||
b = candle.rand((2, 1))
|
||||
c_native = a.broadcast_div(b)
|
||||
c = a / b
|
||||
assert c.shape == (3, 2, 4)
|
||||
assert c.values() == c_native.values()
|
||||
with pytest.raises(ValueError):
|
||||
d = candle.rand((3, 4, 5))
|
||||
e = candle.rand((4, 6))
|
||||
f = d / e
|
||||
|
Reference in New Issue
Block a user