mirror of https://github.com/huggingface/candle.git
synced 2025-06-15 18:28:24 +00:00

Sketch a quantized llama using the pyo3 api. (#715)

* Sketch a quantized llama using the pyo3 api.
* Add more ops.
* Expose a few more functions to use in the quantized model.
* Rope embeddings.
* Get the forward pass to work.
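A minimal driver sketch of what this change enables from Python, assuming a GGML weight file at ./model.bin (a hypothetical path) and the QuantizedLlama class from the quant-llama.py file added below:

    import candle  # the pyo3 extension module built by this crate

    # load_ggml now returns (tensors, hparams); model.bin is an assumed path.
    all_tensors, hparams = candle.load_ggml("./model.bin")
    model = QuantizedLlama(hparams, all_tensors)
    token = candle.tensor([1]).unsqueeze(0)  # BOS token, shape (1, 1)
    logits = model(token, 1)  # index_pos as used in main() below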
@@ -16,6 +16,7 @@ doc = false

[dependencies]
candle = { path = "../candle-core", version = "0.2.1", package = "candle-core" }
candle-nn = { path = "../candle-nn", version = "0.2.1" }
half = { workspace = true }
pyo3 = { version = "0.19.0", features = ["extension-module"] }
candle-pyo3/quant-llama.py (new file, 171 lines)

@@ -0,0 +1,171 @@
# This example shows how the candle Python api can be used to replicate llama.cpp.
import os
import sys

# The "import candle" statement below works if there is a "candle.so" file in sys.path.
# Here we check for shared libraries that can be used in the build directory.
BUILD_DIR = "./target/release-with-debug"
so_file = BUILD_DIR + "/candle.so"
if os.path.islink(so_file): os.remove(so_file)
for lib_file in ["libcandle.dylib", "libcandle.so"]:
    lib_file_ = BUILD_DIR + "/" + lib_file
    if os.path.isfile(lib_file_):
        os.symlink(lib_file, so_file)
        sys.path.insert(0, BUILD_DIR)
        break

import candle

MAX_SEQ_LEN = 4096

def masked_fill(on_false, mask, on_true):
    shape = mask.shape
    on_true = candle.tensor(on_true).broadcast_as(shape)
    return mask.where_cond(on_true, on_false)

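# RmsNorm scales x by 1 / sqrt(mean(x^2) + eps) and multiplies by a learned
# weight, dequantized once at load time.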
class RmsNorm:
    def __init__(self, qtensor):
        self.weight = qtensor.dequantize()

    def __call__(self, x):
        b_size, seq_len, hidden_size = x.shape
        norm_x = x.sqr().sum_keepdim(2) / hidden_size
        x_normed = x.broadcast_div((norm_x + 1e-5).sqrt())
        return x_normed.broadcast_mul(self.weight)

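# One transformer block: pre-norm attention and a pre-norm SwiGLU feed-forward,
# each wrapped in a residual connection; the projection weights stay quantized
# and are applied via matmul_t.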
class QuantizedLayer:
    def __init__(self, layer_idx, hparams, all_tensors, cos_sin):
        p = f"layers.{layer_idx}"
        self.attention_wq = all_tensors[f"{p}.attention.wq.weight"]
        self.attention_wk = all_tensors[f"{p}.attention.wk.weight"]
        self.attention_wv = all_tensors[f"{p}.attention.wv.weight"]
        self.attention_wo = all_tensors[f"{p}.attention.wo.weight"]
        self.ffw1 = all_tensors[f"{p}.feed_forward.w1.weight"]
        self.ffw2 = all_tensors[f"{p}.feed_forward.w2.weight"]
        self.ffw3 = all_tensors[f"{p}.feed_forward.w3.weight"]
        self.attn_norm = RmsNorm(all_tensors[f"{p}.attention_norm.weight"])
        self.ffn_norm = RmsNorm(all_tensors[f"{p}.ffn_norm.weight"])

        self.n_head = hparams["n_head"]
        self.n_kv_head = self.n_head
        self.head_dim = hparams["n_embd"] // self.n_head

        self.kv_cache = None
        self.cos = cos_sin[0]
        self.sin = cos_sin[1]

    def __call__(self, x, mask, index_pos):
        residual = x
        x = self.attn_norm(x)
        attn = self.forward_attn(x, mask, index_pos)
        x = attn + residual

        residual = x
        x = self.ffn_norm(x)
        w1 = self.ffw1.matmul_t(x)
        w3 = self.ffw3.matmul_t(x)
        mlp = self.ffw2.matmul_t(candle.nn.silu(w1) * w3)

        return mlp + residual

    def forward_attn(self, x, mask, index_pos):
        b_size, seq_len, n_embd = x.shape
        q = self.attention_wq.matmul_t(x)
        k = self.attention_wk.matmul_t(x)
        v = self.attention_wv.matmul_t(x)

        q = q.reshape((b_size, seq_len, self.n_head, self.head_dim)).transpose(1, 2)
        k = k.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)
        v = v.reshape((b_size, seq_len, self.n_kv_head, self.head_dim)).transpose(1, 2)

        q = self.apply_rotary_emb(q, index_pos)
        k = self.apply_rotary_emb(k, index_pos)

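        # When decoding token by token, reuse the keys/values computed on
        # previous steps instead of recomputing them over the whole sequence.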
        if self.kv_cache is not None and index_pos > 0:
            prev_k, prev_v = self.kv_cache
            k = candle.cat([prev_k, k], 2).contiguous()
            v = candle.cat([prev_v, v], 2).contiguous()

        self.kv_cache = (k, v)

        # TODO: maybe repeat k/v here if we start supporting MQA.

        att = q.matmul(k.t()) / self.head_dim**0.5
        mask = mask.broadcast_as(att.shape)
        att = masked_fill(att, mask, float("-inf"))
        att = candle.nn.softmax(att, -1)
        y = att.matmul(v.contiguous())
        y = y.transpose(1, 2).reshape((b_size, seq_len, n_embd))
        return self.attention_wo.matmul_t(y)

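    # Rotary embedding: view the last dim (head_dim, destructured as n_embd
    # below) as head_dim/2 pairs and rotate each pair by a position-dependent
    # angle taken from the precomputed cos/sin tables.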
    def apply_rotary_emb(self, x, index_pos):
        (b_size, n_head, seq_len, n_embd) = x.shape
        cos = self.cos.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
        sin = self.sin.narrow(0, index_pos, seq_len).reshape((seq_len, n_embd//2, 1))
        x = x.reshape((b_size, n_head, seq_len, n_embd//2, 2))
        x0 = x.narrow(-1, 0, 1)
        x1 = x.narrow(-1, 1, 1)
        y0 = x0.broadcast_mul(cos) - x1.broadcast_mul(sin)
        y1 = x0.broadcast_mul(sin) + x1.broadcast_mul(cos)
        rope = candle.cat([y0, y1], -1)
        return rope.flatten_from(-2)

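# Precompute the RoPE tables: frequencies freq_base ** (-i/head_dim) for even i
# in [0, head_dim), multiplied by every position in [0, MAX_SEQ_LEN), giving
# cos/sin tensors of shape (MAX_SEQ_LEN, head_dim // 2).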
def precompute_freqs_cis(hparams, freq_base):
    head_dim = hparams["n_embd"] // hparams["n_head"]
    theta = [1.0 / freq_base ** (i / head_dim) for i in range(0, head_dim, 2)]
    theta = candle.tensor(theta)
    idx_theta = [float(i) for i in range(MAX_SEQ_LEN)]
    idx_theta = candle.tensor(idx_theta).reshape((MAX_SEQ_LEN, 1))
    m = idx_theta.matmul(theta.unsqueeze(0))
    print(m.shape)
    return (m.cos(), m.sin())

class QuantizedLlama:
    def __init__(self, hparams, all_tensors):
        self.tok_embeddings = all_tensors["tok_embeddings.weight"].dequantize()
        self.norm = RmsNorm(all_tensors["norm.weight"])
        self.output = all_tensors["output.weight"]
        self.layers = []
        cos_sin = precompute_freqs_cis(hparams, 10000.)
        for layer_idx in range(hparams["n_layer"]):
            layer = QuantizedLayer(layer_idx, hparams, all_tensors, cos_sin)
            self.layers.append(layer)

    def __call__(self, token, index_pos):
        b_size, seq_len = token.shape
        vocab_size, hidden_size = self.tok_embeddings.shape
        token = token.reshape((b_size * seq_len,))
        x = self.tok_embeddings.index_select(token, 0)
        x = x.reshape((b_size, seq_len, hidden_size))

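        # Causal mask: 1 above the diagonal, so future positions get filled
        # with -inf before the softmax.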
        mask = [int(i > j) for j in range(seq_len) for i in range(seq_len)]
        mask = candle.tensor(mask).reshape((seq_len, seq_len))

        for layer in self.layers:
            x = layer(x, mask, index_pos)
        return x

def main():
    if len(sys.argv) < 2:
        raise ValueError("missing weight file argument")
    filename = sys.argv[1]
    if filename.endswith("gguf"):
        all_tensors = candle.load_gguf(filename)
        hparams = None
    else:
        all_tensors, hparams = candle.load_ggml(filename)
    print(hparams)
    model = QuantizedLlama(hparams, all_tensors)

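    # Generation loop; sampling is still a TODO, so this runs a single forward
    # pass starting from the BOS token (id 1).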
    tokens = [1]
    for token_idx in range(1):
        print(tokens)
        last_token = tokens[-1]
        lt = candle.tensor([last_token]).unsqueeze(0)
        logits = model(lt, len(tokens))
        print(logits)
        next_token = "TODO: sample"
        tokens.append(next_token)

if __name__ == '__main__':
    main()
candle-pyo3/src/lib.rs

@@ -145,6 +145,22 @@ pydtype!(bf16, f32::from);
pydtype!(f32, |v| v);
pydtype!(f64, |v| v);

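// Resolve a Python-style dimension index, where negative values count from the
// end, to an absolute index, erroring out when it is outside the tensor's rank.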
fn actual_dim(t: &Tensor, dim: i64) -> ::candle::Result<usize> {
    let rank = t.rank();
    if 0 <= dim {
        let dim = dim as usize;
        if rank <= dim {
            ::candle::bail!("dimension index {dim} is too large for tensor rank {rank}")
        }
        Ok(dim)
    } else {
        if (rank as i64) < -dim {
            ::candle::bail!("dimension index {dim} is too low for tensor rank {rank}")
        }
        Ok((rank as i64 + dim) as usize)
    }
}

// TODO: Something similar to this should probably be a part of candle core.
trait MapDType {
    type Output;
@@ -182,7 +198,10 @@ impl PyTensor {
        } else if let Ok(vs) = vs.extract::<Vec<f32>>(py) {
            Tensor::new(vs.as_slice(), &Cpu).map_err(wrap_err)?
        } else {
-            Err(PyTypeError::new_err("incorrect type for tensor"))?
+            let ty = vs.as_ref(py).get_type();
+            Err(PyTypeError::new_err(format!(
+                "incorrect type {ty} for tensor"
+            )))?
        };
        Ok(Self(tensor))
    }
@@ -295,10 +314,31 @@ impl PyTensor {
        Ok(PyTensor(self.0.powf(p).map_err(wrap_err)?))
    }

    fn index_select(&self, rhs: &Self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.index_select(rhs, dim).map_err(wrap_err)?))
    }

    fn matmul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.matmul(rhs).map_err(wrap_err)?))
    }

    fn broadcast_add(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_add(rhs).map_err(wrap_err)?))
    }

    fn broadcast_sub(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_sub(rhs).map_err(wrap_err)?))
    }

    fn broadcast_mul(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_mul(rhs).map_err(wrap_err)?))
    }

    fn broadcast_div(&self, rhs: &Self) -> PyResult<Self> {
        Ok(PyTensor(self.0.broadcast_div(rhs).map_err(wrap_err)?))
    }

    fn where_cond(&self, on_true: &Self, on_false: &Self) -> PyResult<Self> {
        Ok(PyTensor(
            self.0.where_cond(on_true, on_false).map_err(wrap_err)?,
@@ -346,6 +386,17 @@ impl PyTensor {
        Ok(Self(tensor))
    }

    fn __truediv__(&self, rhs: &PyAny) -> PyResult<Self> {
        let tensor = if let Ok(rhs) = rhs.extract::<Self>() {
            (&self.0 / &rhs.0).map_err(wrap_err)?
        } else if let Ok(rhs) = rhs.extract::<f64>() {
            (&self.0 / rhs).map_err(wrap_err)?
        } else {
            Err(PyTypeError::new_err("unsupported rhs for div"))?
        };
        Ok(Self(tensor))
    }

    fn reshape(&self, shape: PyShape) -> PyResult<Self> {
        Ok(PyTensor(self.0.reshape(shape).map_err(wrap_err)?))
    }
@@ -374,7 +425,8 @@ impl PyTensor {
        Ok(PyTensor(self.0.transpose(dim1, dim2).map_err(wrap_err)?))
    }

-    fn narrow(&self, dim: usize, start: usize, len: usize) -> PyResult<Self> {
+    fn narrow(&self, dim: i64, start: usize, len: usize) -> PyResult<Self> {
+        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.narrow(dim, start, len).map_err(wrap_err)?))
    }
@@ -400,6 +452,16 @@ impl PyTensor {
        Ok(PyTensor(mean))
    }

    fn flatten_from(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_from(dim).map_err(wrap_err)?))
    }

    fn flatten_to(&self, dim: i64) -> PyResult<Self> {
        let dim = actual_dim(self, dim).map_err(wrap_err)?;
        Ok(PyTensor(self.0.flatten_to(dim).map_err(wrap_err)?))
    }

    fn flatten_all(&self) -> PyResult<Self> {
        Ok(PyTensor(self.0.flatten_all().map_err(wrap_err)?))
    }
@@ -467,7 +529,11 @@ impl PyTensor {

/// Concatenate the tensors across one axis.
#[pyfunction]
-fn cat(tensors: Vec<PyTensor>, dim: usize) -> PyResult<PyTensor> {
+fn cat(tensors: Vec<PyTensor>, dim: i64) -> PyResult<PyTensor> {
+    if tensors.is_empty() {
+        return Err(PyErr::new::<PyValueError, _>("empty input to cat"));
+    }
+    let dim = actual_dim(&tensors[0], dim).map_err(wrap_err)?;
    let tensors = tensors.into_iter().map(|t| t.0).collect::<Vec<_>>();
    let tensor = Tensor::cat(&tensors, dim).map_err(wrap_err)?;
    Ok(PyTensor(tensor))
@@ -595,16 +661,27 @@ fn load_safetensors(path: &str, py: Python<'_>) -> PyResult<PyObject> {
}

#[pyfunction]
-fn load_ggml(path: &str, py: Python<'_>) -> PyResult<PyObject> {
+fn load_ggml(path: &str, py: Python<'_>) -> PyResult<(PyObject, PyObject)> {
    let mut file = std::fs::File::open(path)?;
    let ggml = ::candle::quantized::ggml_file::Content::read(&mut file).map_err(wrap_err)?;
-    let res = ggml
+    let tensors = ggml
        .tensors
        .into_iter()
        .map(|(key, qtensor)| Ok((key, PyQTensor(Arc::new(qtensor)).into_py(py))))
        .collect::<::candle::Result<Vec<_>>>()
        .map_err(wrap_err)?;
-    Ok(res.into_py_dict(py).to_object(py))
+    let tensors = tensors.into_py_dict(py).to_object(py);
+    let hparams = [
+        ("n_vocab", ggml.hparams.n_vocab),
+        ("n_embd", ggml.hparams.n_embd),
+        ("n_mult", ggml.hparams.n_mult),
+        ("n_head", ggml.hparams.n_head),
+        ("n_layer", ggml.hparams.n_layer),
+        ("n_rot", ggml.hparams.n_rot),
+        ("ftype", ggml.hparams.ftype),
+    ];
+    let hparams = hparams.into_py_dict(py).to_object(py);
+    Ok((tensors, hparams))
}

#[pyfunction]
@@ -651,11 +728,33 @@ fn candle_utils(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    Ok(())
}

#[pyfunction]
fn softmax(t: PyTensor, dim: i64) -> PyResult<PyTensor> {
    let dim = actual_dim(&t, dim).map_err(wrap_err)?;
    let sm = candle_nn::ops::softmax(&t.0, dim).map_err(wrap_err)?;
    Ok(PyTensor(sm))
}

#[pyfunction]
fn silu(t: PyTensor) -> PyResult<PyTensor> {
    let s = candle_nn::ops::silu(&t.0).map_err(wrap_err)?;
    Ok(PyTensor(s))
}

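// Expose these ops as a "candle.nn" submodule so the Python side can call
// candle.nn.silu / candle.nn.softmax as the quant-llama example does.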
fn candle_nn_m(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
    m.add_function(wrap_pyfunction!(silu, m)?)?;
    m.add_function(wrap_pyfunction!(softmax, m)?)?;
    Ok(())
}

#[pymodule]
fn candle(py: Python<'_>, m: &PyModule) -> PyResult<()> {
    let utils = PyModule::new(py, "utils")?;
    candle_utils(py, utils)?;
    m.add_submodule(utils)?;
    let nn = PyModule::new(py, "nn")?;
    candle_nn_m(py, nn)?;
    m.add_submodule(nn)?;
    m.add_class::<PyTensor>()?;
    m.add_class::<PyQTensor>()?;
    m.add_class::<PyDType>()?;