Sketch a quantized llama using the pyo3 api. (#715)

* Sketch a quantized llama using the pyo3 api.

* Add more ops.

* Expose a few more functions to use in the quantized model.

* Rope embeddings.

* Get the forward pass to work.
This commit is contained in:
Laurent Mazare
2023-09-02 12:26:05 +02:00
committed by GitHub
parent dabaa479b9
commit e8e33752f4
3 changed files with 277 additions and 6 deletions

View File

@ -145,6 +145,22 @@ pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);
fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
let rank = t.rank();
if 0 <= dim {
let dim = dim as usize;
if rank <= dim {
::candle::bail!("dimension index {dim} is too large for tensor rank {rank}")
}
Ok(dim)
} else {
if (rank as i64) < -dim {
::candle::bail!("dimension index {dim} is too low for tensor rank {rank}")
}
Ok((rank as i64 + dim) as usize)
}
}
// TODO: Something similar to this should probably be a part of candle core.
trait MapDType {
type Output;
@ -182,7 +198,10 @@ impl PyTensor {
} else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
} else {
Err(PyTypeError::new_err("incorrect type for tensor"))?
let ty = vs.as_ref(py).get_type();
Err(PyTypeError::new_err(format!(
"incorrect type {ty} for tensor"
)))?
};
Ok(Self(tensor))
}
@ -295,10 +314,31 @@ impl PyTensor {
Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
}
fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
}
fn matmul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
}
fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
}
fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
}
fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
}
fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
}
fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
Ok(PyTensor(
self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
@ -346,6 +386,17 @@ impl PyTensor {
Ok(Self(tensor))
}
fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
(&self.0 / &rhs.0).map_err(wrap_err)?
} else if let Ok(rhs) = rhs.extract::<f64>() {
(&self.0 / rhs).map_err(wrap_err)?
} else {
Err(PyTypeError::new_err("unsupported rhs for div"))?
};
Ok(Self(tensor))
}
fn reshape(&self, shape: PyShape) -> PyResult<Self> {
Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
}
@ -374,7 +425,8 @@ impl PyTensor {
Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
}
fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
}
@ -400,6 +452,16 @@ impl PyTensor {
Ok(PyTensor(mean))
}
fn flatten_from(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
}
fn flatten_to(&self, dim: i64) -> PyResult<Self> {
let dim = actual_dim(self, dim).map_err(wrap_err)?;
Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
}
fn flatten_all(&self) -> PyResult<Self> {
Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
}
@ -467,7 +529,11 @@ impl PyTensor {
/// Concatenate the tensors across one axis.
#[pyfunction]
fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
if tensors.is_empty() {
return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
}
let dim = actual_dim(&tensors[0], dim).map_err(wrap_err)?;
let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
Ok(PyTensor(tensor))
@ -595,16 +661,27 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}
#[pyfunction]
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<PyObject> {
fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
let mut file = std::fs::File::open(path)?;
let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
let res = ggml
let tensors = ggml
.tensors
.into_iter()
.map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
.collect::<::candle::Result<Vec<_>>>()
.map_err(wrap_err)?;
Ok(res.into_py_dict(py).to_object(py))
let tensors = tensors.into_py_dict(py).to_object(py);
let hparams = [
("n_vocab", ggml.hparams.n_vocab),
("n_embd", ggml.hparams.n_embd),
("n_mult", ggml.hparams.n_mult),
("n_head", ggml.hparams.n_head),
("n_layer", ggml.hparams.n_layer),
("n_rot", ggml.hparams.n_rot),
("ftype", ggml.hparams.ftype),
];
let hparams = hparams.into_py_dict(py).to_object(py);
Ok((tensors, hparams))
}
#[pyfunction]
@ -651,11 +728,33 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
Ok(())
}
#[pyfunction]
fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
let dim = actual_dim(&t, dim).map_err(wrap_err)?;
let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
Ok(PyTensor(sm))
}
#[pyfunction]
fn silu(t: PyTensor) -> PyResult<PyTensor> {
let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
Ok(PyTensor(s))
}
fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
m.add_function(wrap_pyfunction!(silu, m)?)?;
m.add_function(wrap_pyfunction!(softmax, m)?)?;
Ok(())
}
#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
let utils = PyModule::new(py, "utils")?;
candle_utils(py, utils)?;
m.add_submodule(utils)?;
let nn = PyModule::new(py, "nn")?;
candle_nn_m(py, nn)?;
m.add_submodule(nn)?;
m.add_class::<PyTensor>()?;
m.add_class::<PyQTensor>()?;
m.add_class::<PyDType>()?;