mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Expose a couple more ops.
This commit is contained in:
@ -155,7 +155,7 @@ impl PyTensor {
|
|||||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
(&self.0 + rhs).map_err(wrap_err)?
|
(&self.0 + rhs).map_err(wrap_err)?
|
||||||
} else {
|
} else {
|
||||||
Err(PyTypeError::new_err("unsupported for add"))?
|
Err(PyTypeError::new_err("unsupported rhs for add"))?
|
||||||
};
|
};
|
||||||
Ok(Self(tensor))
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
@ -170,7 +170,7 @@ impl PyTensor {
|
|||||||
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
(&self.0 * rhs).map_err(wrap_err)?
|
(&self.0 * rhs).map_err(wrap_err)?
|
||||||
} else {
|
} else {
|
||||||
Err(PyTypeError::new_err("unsupported for mul"))?
|
Err(PyTypeError::new_err("unsupported rhs for mul"))?
|
||||||
};
|
};
|
||||||
Ok(Self(tensor))
|
Ok(Self(tensor))
|
||||||
}
|
}
|
||||||
@ -179,21 +179,98 @@ impl PyTensor {
|
|||||||
self.__mul__(rhs)
|
self.__mul__(rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn __sub__(&self, rhs: &PyAny) -> PyResult<Self> {
|
||||||
|
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
|
||||||
|
(&self.0 - &rhs.0).map_err(wrap_err)?
|
||||||
|
} else if let Ok(rhs) = rhs.extract::<f64>() {
|
||||||
|
(&self.0 - rhs).map_err(wrap_err)?
|
||||||
|
} else {
|
||||||
|
Err(PyTypeError::new_err("unsupported rhs for sub"))?
|
||||||
|
};
|
||||||
|
Ok(Self(tensor))
|
||||||
|
}
|
||||||
|
|
||||||
// TODO: Add a PyShape type?
|
// TODO: Add a PyShape type?
|
||||||
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
|
fn reshape(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||||
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn broadcast_as(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.broadcast_as(shape).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn broadcast_left(&self, shape: Vec<usize>) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.broadcast_left(shape).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn squeeze(&self, dim: usize) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.squeeze(dim).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unsqueeze(&self, dim: usize) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.unsqueeze(dim).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get(&self, index: usize) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.get(index).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn transpose(&self, dim1: usize, dim2: usize) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sum_all(&self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.sum_all().map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn flatten_all(&self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn t(&self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.t().map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn contiguous(&self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.contiguous().map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_contiguous(&self) -> bool {
|
||||||
|
self.0.is_contiguous()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_fortran_contiguous(&self) -> bool {
|
||||||
|
self.0.is_fortran_contiguous()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn detach(&self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.detach().map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn copy(&self) -> PyResult<Self> {
|
||||||
|
Ok(PyTensor(self.0.copy().map_err(wrap_err)?))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Concatenate the tensors across one axis.
|
||||||
|
#[pyfunction]
|
||||||
|
fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
|
||||||
|
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
|
||||||
|
let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
|
||||||
|
Ok(PyTensor(tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pyfunction]
|
#[pyfunction]
|
||||||
fn add(tensor: &PyTensor, f: f64) -> PyResult<PyTensor> {
|
fn stack(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
|
||||||
let tensor = (&tensor.0 + f).map_err(wrap_err)?;
|
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
|
||||||
|
let tensor = Tensor::stack(&tensors, dim).map_err(wrap_err)?;
|
||||||
Ok(PyTensor(tensor))
|
Ok(PyTensor(tensor))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[pymodule]
|
#[pymodule]
|
||||||
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
fn candle(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
|
||||||
m.add_class::<PyTensor>()?;
|
m.add_class::<PyTensor>()?;
|
||||||
m.add_function(wrap_pyfunction!(add, m)?)?;
|
m.add_function(wrap_pyfunction!(cat, m)?)?;
|
||||||
|
m.add_function(wrap_pyfunction!(stack, m)?)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user