mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Compare commits
18 Commits
Author | SHA1 | Date | |
---|---|---|---|
17313a4226 | |||
0224a749f0 | |||
cd7b877d6b | |||
5aed817f1b | |||
1a183c988a | |||
cac51fe16a | |||
61ddb9535e | |||
9a62c91643 | |||
92106c8762 | |||
9ce4fe6194 | |||
450a49ed1a | |||
6bd61727bc | |||
485ddf2996 | |||
36508a2c93 | |||
3d05f5cf3d | |||
637473cb5e | |||
e27b4700ad | |||
1fdfb58de5 |
@ -43,7 +43,7 @@ candle-onnx = { path = "./candle-onnx", version = "0.9.1" }
|
|||||||
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
|
candle-transformers = { path = "./candle-transformers", version = "0.9.1" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.16.3", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
fancy-regex = "0.13.0"
|
fancy-regex = "0.13.0"
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.4.1"
|
hf-hub = "0.4.1"
|
||||||
|
@ -483,17 +483,22 @@ impl<I: IntDType> Map1 for Gather<'_, I> {
|
|||||||
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
let start_dst_idx = start_dst_idx + i * dst_right_len;
|
||||||
for right_i in 0..dst_right_len {
|
for right_i in 0..dst_right_len {
|
||||||
let dst_idx = start_dst_idx + right_i;
|
let dst_idx = start_dst_idx + right_i;
|
||||||
let index = ids[dst_idx].as_usize();
|
let index = ids[dst_idx];
|
||||||
if index >= src_dim_len {
|
if index == I::max_value() {
|
||||||
Err(Error::InvalidIndex {
|
dst[dst_idx] = T::zero();
|
||||||
index,
|
} else {
|
||||||
size: src_dim_len,
|
let index = index.as_usize();
|
||||||
op: "gather",
|
if index >= src_dim_len {
|
||||||
|
Err(Error::InvalidIndex {
|
||||||
|
index,
|
||||||
|
size: src_dim_len,
|
||||||
|
op: "gather",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
}
|
}
|
||||||
.bt())?
|
let src_idx = start_src_idx + index * src_right_len + right_i;
|
||||||
|
dst[dst_idx] = src[src_idx]
|
||||||
}
|
}
|
||||||
let src_idx = start_src_idx + index * src_right_len + right_i;
|
|
||||||
dst[dst_idx] = src[src_idx]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -535,19 +540,24 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
|
|||||||
let start_src_idx = left_i * right_len * src_dim;
|
let start_src_idx = left_i * right_len * src_dim;
|
||||||
let start_dst_idx = left_i * right_len * n_ids;
|
let start_dst_idx = left_i * right_len * n_ids;
|
||||||
for i in 0..n_ids {
|
for i in 0..n_ids {
|
||||||
let index = self.ids[self.ids_l.start_offset() + stride_ids * i].as_usize();
|
|
||||||
if index >= src_dim {
|
|
||||||
Err(Error::InvalidIndex {
|
|
||||||
index,
|
|
||||||
size: src_dim,
|
|
||||||
op: "index-select",
|
|
||||||
}
|
|
||||||
.bt())?
|
|
||||||
}
|
|
||||||
let start_src_idx = start_src_idx + index * right_len;
|
|
||||||
let start_dst_idx = start_dst_idx + i * right_len;
|
let start_dst_idx = start_dst_idx + i * right_len;
|
||||||
dst[start_dst_idx..start_dst_idx + right_len]
|
let index = self.ids[self.ids_l.start_offset() + stride_ids * i];
|
||||||
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
if index == I::max_value() {
|
||||||
|
dst[start_dst_idx..start_dst_idx + right_len].fill(T::zero());
|
||||||
|
} else {
|
||||||
|
let index = index.as_usize();
|
||||||
|
if index >= src_dim {
|
||||||
|
Err(Error::InvalidIndex {
|
||||||
|
index,
|
||||||
|
size: src_dim,
|
||||||
|
op: "index-select",
|
||||||
|
}
|
||||||
|
.bt())?
|
||||||
|
}
|
||||||
|
let start_src_idx = start_src_idx + index * right_len;
|
||||||
|
dst[start_dst_idx..start_dst_idx + right_len]
|
||||||
|
.copy_from_slice(&src[start_src_idx..start_src_idx + right_len])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(dst)
|
Ok(dst)
|
||||||
@ -631,7 +641,11 @@ impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
|
|||||||
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
let start_ids_idx = start_ids_idx + i * ids_right_len;
|
||||||
for right_i in 0..dst_right_len {
|
for right_i in 0..dst_right_len {
|
||||||
let ids_idx = start_ids_idx + right_i;
|
let ids_idx = start_ids_idx + right_i;
|
||||||
let index = ids[ids_idx].as_usize();
|
let index = ids[ids_idx];
|
||||||
|
if index == I::max_value() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let index = index.as_usize();
|
||||||
if index >= dst_dim_len {
|
if index >= dst_dim_len {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
index,
|
index,
|
||||||
@ -674,6 +688,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
|
|||||||
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
let post_dim = src_l.dims()[dim + 1..].iter().product::<usize>();
|
||||||
if dim == 0 {
|
if dim == 0 {
|
||||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||||
|
if *dst_idx == I::max_value() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let dst_idx = dst_idx.as_usize();
|
let dst_idx = dst_idx.as_usize();
|
||||||
if dst_idx >= max_idx {
|
if dst_idx >= max_idx {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
@ -692,6 +709,9 @@ impl<I: IntDType> Map2 for IndexAdd<'_, I> {
|
|||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
for (src_idx, dst_idx) in self.ids.iter().enumerate() {
|
||||||
|
if *dst_idx == I::max_value() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let dst_idx = dst_idx.as_usize();
|
let dst_idx = dst_idx.as_usize();
|
||||||
if dst_idx >= max_idx {
|
if dst_idx >= max_idx {
|
||||||
Err(Error::InvalidIndex {
|
Err(Error::InvalidIndex {
|
||||||
|
@ -180,7 +180,7 @@ with_dtype!(bf16, BF16, bf16::from_f64, bf16::to_f64);
|
|||||||
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
with_dtype!(f32, F32, |v: f64| v as f32, |v: f32| v as f64);
|
||||||
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
|
with_dtype!(f64, F64, |v: f64| v, |v: f64| v);
|
||||||
|
|
||||||
pub trait IntDType: WithDType {
|
pub trait IntDType: WithDType + num_traits::Bounded {
|
||||||
fn is_true(&self) -> bool;
|
fn is_true(&self) -> bool;
|
||||||
fn as_usize(&self) -> usize;
|
fn as_usize(&self) -> usize;
|
||||||
}
|
}
|
||||||
|
@ -226,8 +226,8 @@ where
|
|||||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
||||||
///
|
///
|
||||||
/// let d = a.i((2.., ..))?;
|
/// let d = a.i((2.., ..))?;
|
||||||
/// assert_eq!(c.shape().dims(), &[2]);
|
/// assert_eq!(d.shape().dims(), &[1, 3]);
|
||||||
/// assert_eq!(c.to_vec1::<f32>()?, &[1., 4.]);
|
/// assert_eq!(d.to_vec2::<f32>()?, &[[6., 7., 8.]]);
|
||||||
/// # Ok::<(), candle_core::Error>(())
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
/// ```
|
/// ```
|
||||||
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
|
fn i(&self, (a, b): (A, B)) -> Result<Tensor, Error> {
|
||||||
|
@ -1235,6 +1235,83 @@ impl Tensor {
|
|||||||
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
Ok(from_storage(storage, (n, c, h_out, w_out), op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Computes the dot product of two 1D tensors.
|
||||||
|
///
|
||||||
|
/// - If inputs are 1D vectors (`[n]`), returns their scalar dot product.
|
||||||
|
/// - Panics if shapes are not compatible
|
||||||
|
/// - Not supported for integer dtypes
|
||||||
|
///
|
||||||
|
/// # Example (vectors)
|
||||||
|
/// ```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let t1 = Tensor::new(&[1.0, 2.0, 3.0], &Device::Cpu)?;
|
||||||
|
/// let t2 = Tensor::new(&[4.0, 5.0, 6.0], &Device::Cpu)?;
|
||||||
|
/// let res = t1.dot(&t2)?;
|
||||||
|
/// assert_eq!(res.to_scalar::<f64>()?, 32.);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn dot(&self, rhs: &Self) -> Result<Self> {
|
||||||
|
if self.dims().len() != 1 || rhs.dims().len() != 1 {
|
||||||
|
return Err(Error::ShapeMismatchBinaryOp {
|
||||||
|
lhs: self.shape().clone(),
|
||||||
|
rhs: rhs.shape().clone(),
|
||||||
|
op: "dot",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
(self * rhs).and_then(|ret| ret.sum_all())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Computes the **Frobenius norm** (L2 norm of all elements) of the tensor.
|
||||||
|
/// - Output is `sqrt(sum(x^2))`.
|
||||||
|
/// - Always returns a scalar (`[]` shape).
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
|
||||||
|
/// let norm = t.norm()?;
|
||||||
|
/// assert_eq!(norm.to_scalar::<f64>()?, 5.);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn norm(&self) -> Result<Self> {
|
||||||
|
if self.dtype().is_int() {
|
||||||
|
bail!("norm not supported for integer dtypes");
|
||||||
|
}
|
||||||
|
|
||||||
|
self.sqr().and_then(|x| x.sum_all()).and_then(|x| x.sqrt())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Performs strict matrix-vector multiplication (`[m, n] * [n] = [m]`).
|
||||||
|
///
|
||||||
|
/// - If `self` is a matrix (`[m, n]`) and `rhs` is a vector (`[n]`), returns a vector (`[m]`).
|
||||||
|
/// - **No broadcasting**: Panics if `self` is not 2D or if `rhs` is not 1D with matching size.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// ```rust
|
||||||
|
/// use candle_core::{Tensor, Device};
|
||||||
|
/// let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||||
|
/// let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
|
||||||
|
/// let res = mat.mv(&vec)?;
|
||||||
|
/// assert_eq!(res.to_vec1::<f64>()?, [6., 15.]);
|
||||||
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
|
/// ```
|
||||||
|
pub fn mv(&self, rhs: &Self) -> Result<Self> {
|
||||||
|
// Strict shape checks
|
||||||
|
let lhs_dims = self.dims();
|
||||||
|
let rhs_dims = rhs.dims();
|
||||||
|
if lhs_dims.len() != 2 || rhs_dims.len() != 1 || lhs_dims[1] != rhs_dims[0] {
|
||||||
|
return Err(Error::ShapeMismatchBinaryOp {
|
||||||
|
lhs: self.shape().clone(),
|
||||||
|
rhs: rhs.shape().clone(),
|
||||||
|
op: "mv",
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct matmul after ensuring rhs is column vector
|
||||||
|
self.matmul(&rhs.unsqueeze(1)?)?.squeeze(1)
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
/// Returns the matrix-multiplication of the input tensor with the other provided tensor.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
@ -82,6 +82,26 @@ fn broadcast_matmul(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tensor_dot() -> Result<()> {
|
||||||
|
let lhs = Tensor::new(&[1., 2., 3.], &Device::Cpu)?;
|
||||||
|
let rhs = Tensor::new(&[4., 5., 6.], &Device::Cpu)?;
|
||||||
|
let expected = Tensor::new(32., &Device::Cpu)?;
|
||||||
|
let dot_ret = lhs.dot(&rhs)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&dot_ret, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tensor_mv() -> Result<()> {
|
||||||
|
let mat = Tensor::new(&[[1., 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||||
|
let vec = Tensor::new(&[1., 1., 1.], &Device::Cpu)?;
|
||||||
|
let expected = Tensor::new(&[6., 15.], &Device::Cpu)?;
|
||||||
|
let mv_ret = mat.mv(&vec)?;
|
||||||
|
candle_core::test_utils::assert_tensor_eq(&mv_ret, &expected)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
// https://github.com/huggingface/candle/issues/1948
|
// https://github.com/huggingface/candle/issues/1948
|
||||||
fn squeeze_mm(device: &Device) -> Result<()> {
|
fn squeeze_mm(device: &Device) -> Result<()> {
|
||||||
let seq_len = 8_usize;
|
let seq_len = 8_usize;
|
||||||
|
@ -845,6 +845,9 @@ fn embeddings(device: &Device) -> Result<()> {
|
|||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
let hs = t.index_select(&ids.to_dtype(DType::I64)?, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 1.0], [4.0, 5.0], [2.0, 3.0]]);
|
||||||
|
let ids = Tensor::new(&[u32::MAX, 2u32, u32::MAX], device)?;
|
||||||
|
let hs = t.index_select(&ids, 0)?;
|
||||||
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 0.0], [4.0, 5.0], [0.0, 0.0]]);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1087,6 +1090,31 @@ fn scatter(device: &Device) -> Result<()> {
|
|||||||
[1.0, 1.0, 1.0]
|
[1.0, 1.0, 1.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let hs = {
|
||||||
|
let ids = Tensor::new(
|
||||||
|
&[
|
||||||
|
[0u32, u32::MAX, 2],
|
||||||
|
[3, 4, u32::MAX],
|
||||||
|
[3, 3, 1],
|
||||||
|
[u32::MAX, u32::MAX, 4],
|
||||||
|
],
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
|
init.scatter(&ids, &t, 0)?
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
hs.to_vec2::<f32>()?,
|
||||||
|
&[
|
||||||
|
[0.0, 1.0, 1.0],
|
||||||
|
[1.0, 1.0, 8.0],
|
||||||
|
[1.0, 1.0, 2.0],
|
||||||
|
[6.0, 7.0, 1.0],
|
||||||
|
[1.0, 4.0, 11.0],
|
||||||
|
[1.0, 1.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
init.scatter_set(&ids, &t, 0)?;
|
init.scatter_set(&ids, &t, 0)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
init.to_vec2::<f32>()?,
|
init.to_vec2::<f32>()?,
|
||||||
@ -1099,6 +1127,7 @@ fn scatter(device: &Device) -> Result<()> {
|
|||||||
[1.0, 1.0, 1.0]
|
[1.0, 1.0, 1.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1132,6 +1161,23 @@ fn gather(device: &Device) -> Result<()> {
|
|||||||
let hs = t.gather(&ids, 0)?;
|
let hs = t.gather(&ids, 0)?;
|
||||||
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
assert_eq!(hs.to_vec2::<f32>()?, &[[0.0, 7.0, 2.0], [0.0, 4.0, 5.0]]);
|
||||||
|
|
||||||
|
let hs = {
|
||||||
|
let ids = Tensor::new(
|
||||||
|
&[
|
||||||
|
[0u32, 0u32],
|
||||||
|
[2u32, u32::MAX],
|
||||||
|
[u32::MAX, 1u32],
|
||||||
|
[0u32, 2u32],
|
||||||
|
],
|
||||||
|
device,
|
||||||
|
)?;
|
||||||
|
t.gather(&ids, 1)?
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
hs.to_vec2::<f32>()?,
|
||||||
|
&[[0.0, 0.0], [5.0, 0.0], [0.0, 7.0], [9.0, 11.0]]
|
||||||
|
);
|
||||||
|
|
||||||
// Random data
|
// Random data
|
||||||
|
|
||||||
// Dim: 0
|
// Dim: 0
|
||||||
@ -1834,3 +1880,11 @@ fn tensor_new() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tensor_norm() -> Result<()> {
|
||||||
|
let t = Tensor::new(&[[3., 4.], [0., 0.]], &Device::Cpu)?;
|
||||||
|
let norm = t.norm()?;
|
||||||
|
assert_eq!(norm.to_scalar::<f64>()?, 5.);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -16,10 +16,9 @@ fn read_u32<T: Read>(reader: &mut T) -> std::io::Result<u32> {
|
|||||||
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
|
fn check_magic_number<T: Read>(reader: &mut T, expected: u32) -> Result<()> {
|
||||||
let magic_number = read_u32(reader)?;
|
let magic_number = read_u32(reader)?;
|
||||||
if magic_number != expected {
|
if magic_number != expected {
|
||||||
Err(io::Error::new(
|
Err(io::Error::other(format!(
|
||||||
io::ErrorKind::Other,
|
"incorrect magic number {magic_number} != {expected}"
|
||||||
format!("incorrect magic number {magic_number} != {expected}"),
|
)))?;
|
||||||
))?;
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -84,6 +84,10 @@ required-features = ["pyo3"]
|
|||||||
name = "onnx"
|
name = "onnx"
|
||||||
required-features = ["onnx"]
|
required-features = ["onnx"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "onnx-llm"
|
||||||
|
required-features = ["onnx"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "onnx_basics"
|
name = "onnx_basics"
|
||||||
required-features = ["onnx"]
|
required-features = ["onnx"]
|
||||||
|
@ -20,8 +20,8 @@ use hf_hub::{api::sync::Api, Repo, RepoType};
|
|||||||
use tokenizers::{Encoding, PaddingParams, Tokenizer};
|
use tokenizers::{Encoding, PaddingParams, Tokenizer};
|
||||||
|
|
||||||
enum TaskType {
|
enum TaskType {
|
||||||
Ner(DebertaV2NERModel),
|
Ner(Box<DebertaV2NERModel>),
|
||||||
TextClassification(DebertaV2SeqClassificationModel),
|
TextClassification(Box<DebertaV2SeqClassificationModel>),
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug, Clone, ValueEnum)]
|
#[derive(Parser, Debug, Clone, ValueEnum)]
|
||||||
@ -169,21 +169,16 @@ impl Args {
|
|||||||
|
|
||||||
match self.task {
|
match self.task {
|
||||||
ArgsTask::Ner => Ok((
|
ArgsTask::Ner => Ok((
|
||||||
TaskType::Ner(DebertaV2NERModel::load(
|
TaskType::Ner(DebertaV2NERModel::load(vb, &config, Some(id2label.clone()))?.into()),
|
||||||
vb,
|
|
||||||
&config,
|
|
||||||
Some(id2label.clone()),
|
|
||||||
)?),
|
|
||||||
config,
|
config,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
id2label,
|
id2label,
|
||||||
)),
|
)),
|
||||||
ArgsTask::TextClassification => Ok((
|
ArgsTask::TextClassification => Ok((
|
||||||
TaskType::TextClassification(DebertaV2SeqClassificationModel::load(
|
TaskType::TextClassification(
|
||||||
vb,
|
DebertaV2SeqClassificationModel::load(vb, &config, Some(id2label.clone()))?
|
||||||
&config,
|
.into(),
|
||||||
Some(id2label.clone()),
|
),
|
||||||
)?),
|
|
||||||
config,
|
config,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
id2label,
|
id2label,
|
||||||
|
@ -16,8 +16,8 @@ use std::path::PathBuf;
|
|||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
enum ModelType {
|
enum ModelType {
|
||||||
Masked(DistilBertForMaskedLM),
|
Masked(Box<DistilBertForMaskedLM>),
|
||||||
UnMasked(DistilBertModel),
|
UnMasked(Box<DistilBertModel>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ModelType {
|
impl ModelType {
|
||||||
@ -144,10 +144,12 @@ impl Args {
|
|||||||
|
|
||||||
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
||||||
match self.model {
|
match self.model {
|
||||||
Which::DistilbertForMaskedLM => {
|
Which::DistilbertForMaskedLM => Ok(ModelType::Masked(
|
||||||
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
|
DistilBertForMaskedLM::load(vb, config)?.into(),
|
||||||
}
|
)),
|
||||||
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
|
Which::DistilBert => Ok(ModelType::UnMasked(
|
||||||
|
DistilBertModel::load(vb, config)?.into(),
|
||||||
|
)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
OLMo is a series of Open Language Models designed to enable the science of language models.
|
OLMo is a series of Open Language Models designed to enable the science of language models.
|
||||||
|
|
||||||
- **Project Page:** https://allenai.org/olmo
|
- **Project Page:** https://allenai.org/olmo
|
||||||
- **Paper:** [Link](https://arxiv.org/abs/2402.00838)
|
- **Papers:** [OLMo](https://arxiv.org/abs/2402.00838) [OLMo 2](https://arxiv.org/abs/2501.00656)
|
||||||
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
|
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580
|
||||||
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
|
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1
|
||||||
<!-- - **Press release:** TODO -->
|
<!-- - **Press release:** TODO -->
|
||||||
|
@ -8,6 +8,7 @@ use anyhow::{Error as E, Result};
|
|||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
use candle_transformers::models::olmo::{Config, Model as OLMo};
|
use candle_transformers::models::olmo::{Config, Model as OLMo};
|
||||||
|
use candle_transformers::models::olmo2::{Config as Config2, Model as OLMo2};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
@ -18,6 +19,7 @@ use tokenizers::Tokenizer;
|
|||||||
|
|
||||||
enum Model {
|
enum Model {
|
||||||
OLMo(OLMo),
|
OLMo(OLMo),
|
||||||
|
OLMo2(OLMo2),
|
||||||
}
|
}
|
||||||
|
|
||||||
struct TextGeneration {
|
struct TextGeneration {
|
||||||
@ -82,6 +84,7 @@ impl TextGeneration {
|
|||||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||||
let logits = match &mut self.model {
|
let logits = match &mut self.model {
|
||||||
Model::OLMo(m) => m.forward(&input, start_pos)?,
|
Model::OLMo(m) => m.forward(&input, start_pos)?,
|
||||||
|
Model::OLMo2(m) => m.forward(&input, start_pos)?,
|
||||||
};
|
};
|
||||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||||
let logits = if self.repeat_penalty == 1. {
|
let logits = if self.repeat_penalty == 1. {
|
||||||
@ -129,6 +132,8 @@ enum Which {
|
|||||||
W7bTwin2T,
|
W7bTwin2T,
|
||||||
#[value(name = "1.7-7b")]
|
#[value(name = "1.7-7b")]
|
||||||
V1_7W7b,
|
V1_7W7b,
|
||||||
|
#[value(name = "2-1b")]
|
||||||
|
V2W1b,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -220,6 +225,7 @@ fn main() -> Result<()> {
|
|||||||
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
|
Which::W7b => "allenai/OLMo-7B-hf".to_string(),
|
||||||
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
|
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(),
|
||||||
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
|
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(),
|
||||||
|
Which::V2W1b => "allenai/OLMo-2-0425-1B-Instruct".to_string(),
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -238,33 +244,36 @@ fn main() -> Result<()> {
|
|||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
Which::W1b => {
|
Which::W1b | Which::V2W1b => {
|
||||||
vec![repo.get("model.safetensors")?]
|
vec![repo.get("model.safetensors")?]
|
||||||
}
|
}
|
||||||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let config_filename = repo.get("config.json")?;
|
||||||
println!("retrieved the files in {:?}", start.elapsed());
|
println!("retrieved the files in {:?}", start.elapsed());
|
||||||
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||||
|
|
||||||
let start = std::time::Instant::now();
|
let start = std::time::Instant::now();
|
||||||
let config = {
|
|
||||||
let config_filename = repo.get("config.json")?;
|
|
||||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
|
||||||
config
|
|
||||||
};
|
|
||||||
|
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let model = {
|
let dtype = if device.is_cuda() {
|
||||||
let dtype = if device.is_cuda() {
|
DType::BF16
|
||||||
DType::BF16
|
} else {
|
||||||
} else {
|
DType::F32
|
||||||
DType::F32
|
};
|
||||||
};
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let model = match args.model {
|
||||||
let model = OLMo::new(&config, vb)?;
|
Which::W1b | Which::W7b | Which::W7bTwin2T | Which::V1_7W7b => {
|
||||||
Model::OLMo(model)
|
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
|
let model = OLMo::new(&config, vb)?;
|
||||||
|
Model::OLMo(model)
|
||||||
|
}
|
||||||
|
Which::V2W1b => {
|
||||||
|
let config: Config2 = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||||
|
let model = OLMo2::new(&config, vb)?;
|
||||||
|
Model::OLMo2(model)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
|
11
candle-examples/examples/onnx-llm/README.md
Normal file
11
candle-examples/examples/onnx-llm/README.md
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
## Using ONNX models in Candle
|
||||||
|
|
||||||
|
This example demonstrates how to run [ONNX](https://github.com/onnx/onnx) based LLM models in Candle.
|
||||||
|
|
||||||
|
This script only implements SmolLM-135M right now.
|
||||||
|
|
||||||
|
You can run the examples with following commands:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example onnx-llm --features onnx
|
||||||
|
```
|
209
candle-examples/examples/onnx-llm/main.rs
Normal file
209
candle-examples/examples/onnx-llm/main.rs
Normal file
@ -0,0 +1,209 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle::{DType, Tensor};
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use std::io::Write;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
SmolLM135M,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser)]
|
||||||
|
struct Args {
|
||||||
|
/// The prompt to be used.
|
||||||
|
#[arg(long, default_value = "My favorite theorem is ")]
|
||||||
|
prompt: String,
|
||||||
|
|
||||||
|
/// The model to be used.
|
||||||
|
#[arg(value_enum, long, default_value_t = Which::SmolLM135M)]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
/// Run on CPU rather than GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// The number of tokens to generate.
|
||||||
|
#[arg(long, default_value_t = 100)]
|
||||||
|
max_tokens: usize,
|
||||||
|
|
||||||
|
/// The temperature used for sampling.
|
||||||
|
#[arg(long, default_value_t = 0.8)]
|
||||||
|
temperature: f32,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn main() -> Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let (model_id, tokenizer_id) = match args.which {
|
||||||
|
Which::SmolLM135M => ("HuggingFaceTB/SmolLM-135M", "HuggingFaceTB/SmolLM-135M"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let api = Api::new()?;
|
||||||
|
let model_repo = api.model(model_id.to_string());
|
||||||
|
let tokenizer_repo = api.model(tokenizer_id.to_string());
|
||||||
|
|
||||||
|
let model_path = model_repo.get("onnx/model.onnx")?;
|
||||||
|
let config_file = model_repo.get("config.json")?;
|
||||||
|
let config: Config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
|
||||||
|
|
||||||
|
let tokenizer_path = tokenizer_repo.get("tokenizer.json")?;
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||||
|
|
||||||
|
let tokens_u32 = tokenizer
|
||||||
|
.encode(args.prompt.as_str(), true)
|
||||||
|
.map_err(anyhow::Error::msg)?
|
||||||
|
.get_ids()
|
||||||
|
.to_vec();
|
||||||
|
|
||||||
|
let tokens: Vec<i64> = tokens_u32.iter().map(|&t| t as i64).collect();
|
||||||
|
|
||||||
|
println!("Loading ONNX model from {:?}", model_path);
|
||||||
|
let model = candle_onnx::read_file(model_path)?;
|
||||||
|
|
||||||
|
let mut generated_tokens = tokens.clone();
|
||||||
|
print!("{}", args.prompt);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
|
||||||
|
let mut logits_processor = {
|
||||||
|
let temperature = args.temperature as f64;
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (args.top_k, args.top_p) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut past_key_values: Option<Vec<(Tensor, Tensor)>> = None;
|
||||||
|
let num_layers = config.num_hidden_layers;
|
||||||
|
|
||||||
|
for _ in 0..args.max_tokens {
|
||||||
|
let mut inputs = std::collections::HashMap::new();
|
||||||
|
|
||||||
|
if let Some(past_kv) = &past_key_values {
|
||||||
|
let last_token = vec![generated_tokens[generated_tokens.len() - 1]];
|
||||||
|
let input_tensor = Tensor::new(last_token, &device)?.unsqueeze(0)?;
|
||||||
|
inputs.insert("input_ids".to_string(), input_tensor);
|
||||||
|
|
||||||
|
let seq_len = generated_tokens.len();
|
||||||
|
let attention_mask = vec![vec![1i64; seq_len]];
|
||||||
|
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
|
||||||
|
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
|
||||||
|
|
||||||
|
let position_ids = vec![vec![(seq_len - 1) as i64]];
|
||||||
|
let position_ids_tensor = Tensor::new(position_ids, &device)?;
|
||||||
|
inputs.insert("position_ids".to_string(), position_ids_tensor);
|
||||||
|
|
||||||
|
for (i, (key, value)) in past_kv.iter().enumerate() {
|
||||||
|
inputs.insert(format!("past_key_values.{}.key", i), key.clone());
|
||||||
|
inputs.insert(format!("past_key_values.{}.value", i), value.clone());
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let input_tensor = Tensor::new(generated_tokens.clone(), &device)?.unsqueeze(0)?;
|
||||||
|
inputs.insert("input_ids".to_string(), input_tensor);
|
||||||
|
|
||||||
|
let seq_len = generated_tokens.len();
|
||||||
|
let attention_mask = vec![vec![1i64; seq_len]];
|
||||||
|
let attention_mask_tensor = Tensor::new(attention_mask, &device)?;
|
||||||
|
inputs.insert("attention_mask".to_string(), attention_mask_tensor);
|
||||||
|
|
||||||
|
let position_ids: Vec<i64> = (0..seq_len as i64).collect();
|
||||||
|
let position_ids_tensor = Tensor::new(position_ids, &device)?.unsqueeze(0)?;
|
||||||
|
inputs.insert("position_ids".to_string(), position_ids_tensor);
|
||||||
|
|
||||||
|
// Create empty key and value tensors
|
||||||
|
for i in 0..num_layers {
|
||||||
|
let batch_size = 1;
|
||||||
|
let num_heads = config.num_key_value_heads;
|
||||||
|
let head_dim = config.hidden_size / config.num_attention_heads;
|
||||||
|
let seq_len = 0;
|
||||||
|
|
||||||
|
let empty_key = Tensor::zeros(
|
||||||
|
&[batch_size, num_heads, seq_len, head_dim],
|
||||||
|
DType::F32,
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
let empty_value = Tensor::zeros(
|
||||||
|
&[batch_size, num_heads, seq_len, head_dim],
|
||||||
|
DType::F32,
|
||||||
|
&device,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
inputs.insert(format!("past_key_values.{}.key", i), empty_key);
|
||||||
|
inputs.insert(format!("past_key_values.{}.value", i), empty_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let outputs = candle_onnx::simple_eval(&model, inputs)?;
|
||||||
|
|
||||||
|
let logits = outputs.get("logits").unwrap();
|
||||||
|
|
||||||
|
let mut new_past_kv = Vec::with_capacity(num_layers);
|
||||||
|
for i in 0..num_layers {
|
||||||
|
let key = outputs
|
||||||
|
.get(&format!("present.{}.key", i))
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.key", i))?;
|
||||||
|
let value = outputs
|
||||||
|
.get(&format!("present.{}.value", i))
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing present.{}.value", i))?;
|
||||||
|
new_past_kv.push((key.clone(), value.clone()));
|
||||||
|
}
|
||||||
|
past_key_values = Some(new_past_kv);
|
||||||
|
|
||||||
|
let logits_dim = logits.dims();
|
||||||
|
let seq_len = logits_dim[1];
|
||||||
|
|
||||||
|
let next_token_id = logits_processor.sample(&logits.get(0)?.get(seq_len - 1)?)?;
|
||||||
|
generated_tokens.push(next_token_id as i64);
|
||||||
|
|
||||||
|
if let Some(token_str) = tokenizer.decode(&[next_token_id], true).ok() {
|
||||||
|
print!("{}", token_str);
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(eos_id) = tokenizer.token_to_id("<|endoftext|>") {
|
||||||
|
if next_token_id == eos_id {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("\nGeneration complete!");
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -5,12 +5,14 @@ extern crate intel_mkl_src;
|
|||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
|
|
||||||
use candle::{IndexOp, D};
|
use candle::{IndexOp, D};
|
||||||
|
use candle_examples::save_image;
|
||||||
use clap::{Parser, ValueEnum};
|
use clap::{Parser, ValueEnum};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||||
enum Which {
|
enum Which {
|
||||||
SqueezeNet,
|
SqueezeNet,
|
||||||
EfficientNet,
|
EfficientNet,
|
||||||
|
EsrGan,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser)]
|
#[derive(Parser)]
|
||||||
@ -28,10 +30,21 @@ struct Args {
|
|||||||
|
|
||||||
pub fn main() -> anyhow::Result<()> {
|
pub fn main() -> anyhow::Result<()> {
|
||||||
let args = Args::parse();
|
let args = Args::parse();
|
||||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
let image = match args.which {
|
||||||
|
Which::SqueezeNet | Which::EfficientNet => {
|
||||||
|
candle_examples::imagenet::load_image224(&args.image)?
|
||||||
|
}
|
||||||
|
Which::EsrGan => candle_examples::imagenet::load_image_with_std_mean(
|
||||||
|
&args.image,
|
||||||
|
128,
|
||||||
|
&[0.0f32, 0.0, 0.0],
|
||||||
|
&[1.0f32, 1.0, 1.0],
|
||||||
|
)?,
|
||||||
|
};
|
||||||
let image = match args.which {
|
let image = match args.which {
|
||||||
Which::SqueezeNet => image,
|
Which::SqueezeNet => image,
|
||||||
Which::EfficientNet => image.permute((1, 2, 0))?,
|
Which::EfficientNet => image.permute((1, 2, 0))?,
|
||||||
|
Which::EsrGan => image,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("loaded image {image:?}");
|
println!("loaded image {image:?}");
|
||||||
@ -45,6 +58,9 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
Which::EfficientNet => hf_hub::api::sync::Api::new()?
|
Which::EfficientNet => hf_hub::api::sync::Api::new()?
|
||||||
.model("onnx/EfficientNet-Lite4".into())
|
.model("onnx/EfficientNet-Lite4".into())
|
||||||
.get("efficientnet-lite4-11.onnx")?,
|
.get("efficientnet-lite4-11.onnx")?,
|
||||||
|
Which::EsrGan => hf_hub::api::sync::Api::new()?
|
||||||
|
.model("qualcomm/Real-ESRGAN-x4plus".into())
|
||||||
|
.get("Real-ESRGAN-x4plus.onnx")?,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -57,21 +73,40 @@ pub fn main() -> anyhow::Result<()> {
|
|||||||
let prs = match args.which {
|
let prs = match args.which {
|
||||||
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
|
Which::SqueezeNet => candle_nn::ops::softmax(&output, D::Minus1)?,
|
||||||
Which::EfficientNet => output,
|
Which::EfficientNet => output,
|
||||||
|
Which::EsrGan => output,
|
||||||
};
|
};
|
||||||
let prs = prs.i(0)?.to_vec1::<f32>()?;
|
|
||||||
|
|
||||||
// Sort the predictions and take the top 5
|
match args.which {
|
||||||
let mut top: Vec<_> = prs.iter().enumerate().collect();
|
Which::EfficientNet | Which::SqueezeNet => {
|
||||||
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
let prs = prs.i(0)?.to_vec1::<f32>()?;
|
||||||
let top = top.into_iter().take(5).collect::<Vec<_>>();
|
|
||||||
|
|
||||||
// Print the top predictions
|
// Sort the predictions and take the top 5
|
||||||
for &(i, p) in &top {
|
let mut top: Vec<_> = prs.iter().enumerate().collect();
|
||||||
println!(
|
top.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
|
||||||
"{:50}: {:.2}%",
|
let top = top.into_iter().take(5).collect::<Vec<_>>();
|
||||||
candle_examples::imagenet::CLASSES[i],
|
|
||||||
p * 100.0
|
// Print the top predictions
|
||||||
);
|
for &(i, p) in &top {
|
||||||
|
println!(
|
||||||
|
"{:50}: {:.2}%",
|
||||||
|
candle_examples::imagenet::CLASSES[i],
|
||||||
|
p * 100.0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Which::EsrGan => {
|
||||||
|
let max_pixel_val = candle::Tensor::try_from(255.0f32)?
|
||||||
|
.to_device(prs.device())?
|
||||||
|
.broadcast_as(prs.shape())?;
|
||||||
|
let out = (prs * max_pixel_val)?.i(0)?.to_dtype(candle::DType::U8)?;
|
||||||
|
|
||||||
|
let pb = std::path::PathBuf::from(args.image);
|
||||||
|
let input_file_name = pb.file_name().unwrap();
|
||||||
|
let mut output_file_name = std::ffi::OsString::from("super_");
|
||||||
|
output_file_name.push(input_file_name);
|
||||||
|
|
||||||
|
save_image(&out, output_file_name)?;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -147,9 +147,9 @@ enum WhichModel {
|
|||||||
V3,
|
V3,
|
||||||
#[value(name = "3-medium")]
|
#[value(name = "3-medium")]
|
||||||
V3Medium,
|
V3Medium,
|
||||||
#[value(name = "2-old")]
|
|
||||||
V4Mini,
|
|
||||||
#[value(name = "4-mini")]
|
#[value(name = "4-mini")]
|
||||||
|
V4Mini,
|
||||||
|
#[value(name = "2-old")]
|
||||||
V2Old,
|
V2Old,
|
||||||
PuffinPhiV2,
|
PuffinPhiV2,
|
||||||
PhiHermes,
|
PhiHermes,
|
||||||
|
@ -8,4 +8,8 @@
|
|||||||
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
|
cargo run --example quantized-qwen2-instruct --release -- --prompt "Write a function to count prime numbers up to N."
|
||||||
```
|
```
|
||||||
|
|
||||||
0.5b, 1.5b, 7b and 72b models are available via `--model` argument.
|
0.5b, 1.5b, 7b and 72b models are available via `--which` argument.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --release --example quantized-qwen2-instruct -- --which 0.5b --prompt "Write a function to count prime numbers up to N."
|
||||||
|
```
|
||||||
|
17
candle-examples/examples/quantized-qwen3/README.md
Normal file
17
candle-examples/examples/quantized-qwen3/README.md
Normal file
@ -0,0 +1,17 @@
|
|||||||
|
# candle-quantized-qwen3
|
||||||
|
|
||||||
|
[Qwen3]((https://qwenlm.github.io/blog/qwen3/)) is an upgraded version of Qwen2.5, released by Alibaba Cloud.
|
||||||
|
|
||||||
|
## Running the example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example quantized-qwen3 --release -- --prompt "Write a function to count prime numbers up to N."
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
0.6b is used by default, 1.7b, 4b, 8b, 14b, and 32b models are available via `--which` argument.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example quantized-qwen3 --release -- --which 4b --prompt "A train is travelling at 120mph, how far does it travel in 3 minutes 30 seconds?"
|
||||||
|
```
|
||||||
|
|
314
candle-examples/examples/quantized-qwen3/main.rs
Normal file
314
candle-examples/examples/quantized-qwen3/main.rs
Normal file
@ -0,0 +1,314 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
use std::io::Write;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
use candle::quantized::gguf_file;
|
||||||
|
use candle::Tensor;
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_transformers::models::quantized_qwen3::ModelWeights as Qwen3;
|
||||||
|
|
||||||
|
const DEFAULT_PROMPT: &str = "Write a Rust function to calculate the factorial of a given number.";
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "0.6b")]
|
||||||
|
W3_0_6b,
|
||||||
|
#[value(name = "1.7b")]
|
||||||
|
W3_1_7b,
|
||||||
|
#[value(name = "4b")]
|
||||||
|
W3_4b,
|
||||||
|
#[value(name = "8b")]
|
||||||
|
W3_8b,
|
||||||
|
#[value(name = "14b")]
|
||||||
|
W3_14b,
|
||||||
|
#[value(name = "32b")]
|
||||||
|
W3_32b,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// GGUF file to load, typically a .gguf file generated by the quantize command from llama.cpp
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||||
|
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||||
|
/// is preserved.
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: Option<String>,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// The tokenizer config in json format.
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||||
|
#[arg(long, default_value_t = 0.8)]
|
||||||
|
temperature: f64,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// Process prompt elements separately.
|
||||||
|
#[arg(long)]
|
||||||
|
split_prompt: bool,
|
||||||
|
|
||||||
|
/// Run on CPU rather than GPU even if a GPU is available.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "0.6b")]
|
||||||
|
which: Which,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Args {
|
||||||
|
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||||
|
let tokenizer_path = match &self.tokenizer {
|
||||||
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let repo = match self.which {
|
||||||
|
Which::W3_0_6b => "Qwen/Qwen3-0.6B",
|
||||||
|
Which::W3_1_7b => "Qwen/Qwen3-1.7B",
|
||||||
|
Which::W3_4b => "Qwen/Qwen3-4B",
|
||||||
|
Which::W3_8b => "Qwen/Qwen3-8B",
|
||||||
|
Which::W3_14b => "Qwen/Qwen3-14B",
|
||||||
|
Which::W3_32b => "Qwen/Qwen3-32B",
|
||||||
|
};
|
||||||
|
let api = api.model(repo.to_string());
|
||||||
|
api.get("tokenizer.json")?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||||
|
let model_path = match &self.model {
|
||||||
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
|
None => {
|
||||||
|
let (repo, filename, revision) = match self.which {
|
||||||
|
Which::W3_0_6b => ("unsloth/Qwen3-0.6B-GGUF", "Qwen3-0.6B-Q4_K_M.gguf", "main"),
|
||||||
|
Which::W3_1_7b => ("unsloth/Qwen3-1.7B-GGUF", "Qwen3-1.7B-Q4_K_M.gguf", "main"),
|
||||||
|
Which::W3_4b => ("unsloth/Qwen3-4B-GGUF", "Qwen3-4B-Q4_K_M.gguf", "main"),
|
||||||
|
Which::W3_8b => ("unsloth/Qwen3-8B-GGUF", "Qwen3-8B-Q4_K_M.gguf", "main"),
|
||||||
|
Which::W3_14b => ("unsloth/Qwen3-14B-GGUF", "Qwen3-14B-Q4_K_M.gguf", "main"),
|
||||||
|
Which::W3_32b => ("unsloth/Qwen3-32B-GGUF", "Qwen3-32B-Q4_K_M.gguf", "main"),
|
||||||
|
};
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
api.repo(hf_hub::Repo::with_revision(
|
||||||
|
repo.to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
revision.to_string(),
|
||||||
|
))
|
||||||
|
.get(filename)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(model_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_size(size_in_bytes: usize) -> String {
|
||||||
|
if size_in_bytes < 1_000 {
|
||||||
|
format!("{}B", size_in_bytes)
|
||||||
|
} else if size_in_bytes < 1_000_000 {
|
||||||
|
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||||
|
} else if size_in_bytes < 1_000_000_000 {
|
||||||
|
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||||
|
} else {
|
||||||
|
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let model_path = args.model()?;
|
||||||
|
let mut file = std::fs::File::open(&model_path)?;
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let mut model = {
|
||||||
|
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(model_path))?;
|
||||||
|
let mut total_size_in_bytes = 0;
|
||||||
|
for (_, tensor) in model.tensor_infos.iter() {
|
||||||
|
let elem_count = tensor.shape.elem_count();
|
||||||
|
total_size_in_bytes +=
|
||||||
|
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"loaded {:?} tensors ({}) in {:.2}s",
|
||||||
|
model.tensor_infos.len(),
|
||||||
|
&format_size(total_size_in_bytes),
|
||||||
|
start.elapsed().as_secs_f32(),
|
||||||
|
);
|
||||||
|
Qwen3::from_gguf(model, &mut file, &device)?
|
||||||
|
};
|
||||||
|
println!("model built");
|
||||||
|
|
||||||
|
let tokenizer = args.tokenizer()?;
|
||||||
|
let mut tos = TokenOutputStream::new(tokenizer);
|
||||||
|
let prompt_str = args
|
||||||
|
.prompt
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| DEFAULT_PROMPT.to_string());
|
||||||
|
|
||||||
|
let prompt_str = format!("<|im_start|>user\n{prompt_str}<|im_end|>\n<|im_start|>assistant\n");
|
||||||
|
print!("formatted prompt: {}", &prompt_str);
|
||||||
|
|
||||||
|
let tokens = tos
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt_str, true)
|
||||||
|
.map_err(anyhow::Error::msg)?;
|
||||||
|
|
||||||
|
let tokens = tokens.get_ids();
|
||||||
|
|
||||||
|
let to_sample = args.sample_len.saturating_sub(1);
|
||||||
|
|
||||||
|
let mut all_tokens = vec![];
|
||||||
|
|
||||||
|
let mut logits_processor = {
|
||||||
|
let temperature = args.temperature;
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (args.top_k, args.top_p) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
let start_prompt_processing = std::time::Instant::now();
|
||||||
|
|
||||||
|
let mut next_token = if !args.split_prompt {
|
||||||
|
let input = Tensor::new(tokens, &device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, 0)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
logits_processor.sample(&logits)?
|
||||||
|
} else {
|
||||||
|
let mut next_token = 0;
|
||||||
|
for (pos, token) in tokens.iter().enumerate() {
|
||||||
|
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, pos)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
next_token = logits_processor.sample(&logits)?
|
||||||
|
}
|
||||||
|
next_token
|
||||||
|
};
|
||||||
|
|
||||||
|
let prompt_dt = start_prompt_processing.elapsed();
|
||||||
|
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
|
||||||
|
if let Some(t) = tos.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let eos_token = *tos.tokenizer().get_vocab(true).get("<|im_end|>").unwrap();
|
||||||
|
|
||||||
|
let start_post_prompt = std::time::Instant::now();
|
||||||
|
|
||||||
|
let mut sampled = 0;
|
||||||
|
for index in 0..to_sample {
|
||||||
|
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, tokens.len() + index)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
let logits = if args.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
args.repeat_penalty,
|
||||||
|
&all_tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
next_token = logits_processor.sample(&logits)?;
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if let Some(t) = tos.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
sampled += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
let dt = start_post_prompt.elapsed();
|
||||||
|
println!(
|
||||||
|
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||||
|
tokens.len(),
|
||||||
|
tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"{sampled:4} tokens generated: {:.2} token/s",
|
||||||
|
sampled as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -25,3 +25,28 @@ def print_prime(n: int): # n is the number of primes to be printed
|
|||||||
print(i)
|
print(i)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
The qwen3 MoE variant is also an option.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example qwen --features metal --release -- --prompt "Write a poem about butterflies. <think></think>." --model "3-moe-a3b"
|
||||||
|
> In morning's hush, where daisies sleep,
|
||||||
|
> A fleeting dance through sunlit deep—
|
||||||
|
> They flutter soft on gossamer thread,
|
||||||
|
> The messengers of spring’s own head.
|
||||||
|
>
|
||||||
|
> With painted sails and delicate grace,
|
||||||
|
> They drift from bloom to blossom's face.
|
||||||
|
> Each wing a tale in hues unseen,
|
||||||
|
> Of ancient dreams and secrets between.
|
||||||
|
>
|
||||||
|
> No sound they make, yet still they speak—
|
||||||
|
> Of time that flies, of life so brief.
|
||||||
|
> A fleeting kiss on summer’s breath,
|
||||||
|
> A whisper lost before death.
|
||||||
|
>
|
||||||
|
> Yet in their flight, the soul takes wing,
|
||||||
|
> And for a moment, all is spring.
|
||||||
|
> For though they fade, they never die—
|
||||||
|
> Their beauty lives where hearts can fly.
|
||||||
|
> 161 tokens generated (3.00 token/s)
|
||||||
|
```
|
||||||
|
@ -9,6 +9,8 @@ use clap::Parser;
|
|||||||
|
|
||||||
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
|
||||||
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};
|
||||||
|
use candle_transformers::models::qwen3::{Config as Config3, ModelForCausalLM as Model3};
|
||||||
|
use candle_transformers::models::qwen3_moe::{Config as ConfigMoe3, ModelForCausalLM as ModelMoe3};
|
||||||
|
|
||||||
use candle::{DType, Device, Tensor};
|
use candle::{DType, Device, Tensor};
|
||||||
use candle_examples::token_output_stream::TokenOutputStream;
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
@ -20,6 +22,8 @@ use tokenizers::Tokenizer;
|
|||||||
enum Model {
|
enum Model {
|
||||||
Base(ModelBase),
|
Base(ModelBase),
|
||||||
Moe(ModelMoe),
|
Moe(ModelMoe),
|
||||||
|
Base3(Model3),
|
||||||
|
Moe3(ModelMoe3),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Model {
|
impl Model {
|
||||||
@ -27,6 +31,8 @@ impl Model {
|
|||||||
match self {
|
match self {
|
||||||
Self::Moe(ref mut m) => m.forward(xs, s),
|
Self::Moe(ref mut m) => m.forward(xs, s),
|
||||||
Self::Base(ref mut m) => m.forward(xs, s),
|
Self::Base(ref mut m) => m.forward(xs, s),
|
||||||
|
Self::Base3(ref mut m) => m.forward(xs, s),
|
||||||
|
Self::Moe3(ref mut m) => m.forward(xs, s),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -85,6 +91,10 @@ impl TextGeneration {
|
|||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||||
};
|
};
|
||||||
|
let eos_token2 = match self.tokenizer.get_token("<|im_end|>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => anyhow::bail!("cannot find the <|im_end|> token"),
|
||||||
|
};
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
@ -107,7 +117,7 @@ impl TextGeneration {
|
|||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
generated_tokens += 1;
|
generated_tokens += 1;
|
||||||
if next_token == eos_token {
|
if next_token == eos_token || next_token == eos_token2 {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
@ -152,6 +162,16 @@ enum WhichModel {
|
|||||||
W2_7b,
|
W2_7b,
|
||||||
#[value(name = "2-72b")]
|
#[value(name = "2-72b")]
|
||||||
W2_72b,
|
W2_72b,
|
||||||
|
#[value(name = "3-0.6b")]
|
||||||
|
W3_0_6b,
|
||||||
|
#[value(name = "3-1.7b")]
|
||||||
|
W3_1_7b,
|
||||||
|
#[value(name = "3-4b")]
|
||||||
|
W3_4b,
|
||||||
|
#[value(name = "3-8b")]
|
||||||
|
W3_8b,
|
||||||
|
#[value(name = "3-moe-a3b")]
|
||||||
|
W3MoeA3b,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -254,6 +274,11 @@ fn main() -> Result<()> {
|
|||||||
WhichModel::W14b => ("1.5", "14B"),
|
WhichModel::W14b => ("1.5", "14B"),
|
||||||
WhichModel::W72b => ("1.5", "72B"),
|
WhichModel::W72b => ("1.5", "72B"),
|
||||||
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
|
WhichModel::MoeA27b => ("1.5", "MoE-A2.7B"),
|
||||||
|
WhichModel::W3_0_6b => ("3", "0.6B"),
|
||||||
|
WhichModel::W3_1_7b => ("3", "1.7B"),
|
||||||
|
WhichModel::W3_4b => ("3", "4B"),
|
||||||
|
WhichModel::W3_8b => ("3", "8B"),
|
||||||
|
WhichModel::W3MoeA3b => ("3", "30B-A3B"),
|
||||||
};
|
};
|
||||||
format!("Qwen/Qwen{version}-{size}")
|
format!("Qwen/Qwen{version}-{size}")
|
||||||
}
|
}
|
||||||
@ -273,7 +298,11 @@ fn main() -> Result<()> {
|
|||||||
.map(std::path::PathBuf::from)
|
.map(std::path::PathBuf::from)
|
||||||
.collect::<Vec<_>>(),
|
.collect::<Vec<_>>(),
|
||||||
None => match args.model {
|
None => match args.model {
|
||||||
WhichModel::W0_5b | WhichModel::W2_0_5b | WhichModel::W2_1_5b | WhichModel::W1_8b => {
|
WhichModel::W0_5b
|
||||||
|
| WhichModel::W2_0_5b
|
||||||
|
| WhichModel::W2_1_5b
|
||||||
|
| WhichModel::W1_8b
|
||||||
|
| WhichModel::W3_0_6b => {
|
||||||
vec![repo.get("model.safetensors")?]
|
vec![repo.get("model.safetensors")?]
|
||||||
}
|
}
|
||||||
WhichModel::W4b
|
WhichModel::W4b
|
||||||
@ -282,7 +311,11 @@ fn main() -> Result<()> {
|
|||||||
| WhichModel::W14b
|
| WhichModel::W14b
|
||||||
| WhichModel::W72b
|
| WhichModel::W72b
|
||||||
| WhichModel::W2_72b
|
| WhichModel::W2_72b
|
||||||
| WhichModel::MoeA27b => {
|
| WhichModel::MoeA27b
|
||||||
|
| WhichModel::W3_1_7b
|
||||||
|
| WhichModel::W3_4b
|
||||||
|
| WhichModel::W3_8b
|
||||||
|
| WhichModel::W3MoeA3b => {
|
||||||
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@ -304,6 +337,14 @@ fn main() -> Result<()> {
|
|||||||
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
let config: ConfigMoe = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
Model::Moe(ModelMoe::new(&config, vb)?)
|
Model::Moe(ModelMoe::new(&config, vb)?)
|
||||||
}
|
}
|
||||||
|
WhichModel::W3_0_6b | WhichModel::W3_1_7b | WhichModel::W3_4b | WhichModel::W3_8b => {
|
||||||
|
let config: Config3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
|
Model::Base3(Model3::new(&config, vb)?)
|
||||||
|
}
|
||||||
|
WhichModel::W3MoeA3b => {
|
||||||
|
let config: ConfigMoe3 = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
|
Model::Moe3(ModelMoe3::new(&config, vb)?)
|
||||||
|
}
|
||||||
_ => {
|
_ => {
|
||||||
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
let config: ConfigBase = serde_json::from_slice(&std::fs::read(config_file)?)?;
|
||||||
Model::Base(ModelBase::new(&config, vb)?)
|
Model::Base(ModelBase::new(&config, vb)?)
|
||||||
|
@ -28,3 +28,26 @@ Ranking Results:
|
|||||||
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
|
> Rank #1 | Score: 0.9990 | The giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China.
|
||||||
--------------------------------------------------------------------------------
|
--------------------------------------------------------------------------------
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Text-Classification:
|
||||||
|
```bash
|
||||||
|
cargo run --example xlm-roberta -- --task text-classification --model xlmr-formality-classifier
|
||||||
|
```
|
||||||
|
```markdown
|
||||||
|
Formality Scores:
|
||||||
|
Text 1: "I like you. I love you"
|
||||||
|
formal: 0.9933
|
||||||
|
informal: 0.0067
|
||||||
|
|
||||||
|
Text 2: "Hey, what's up?"
|
||||||
|
formal: 0.8812
|
||||||
|
informal: 0.1188
|
||||||
|
|
||||||
|
Text 3: "Siema, co porabiasz?"
|
||||||
|
formal: 0.9358
|
||||||
|
informal: 0.0642
|
||||||
|
|
||||||
|
Text 4: "I feel deep regret and sadness about the situation in international politics."
|
||||||
|
formal: 0.9987
|
||||||
|
informal: 0.0013
|
||||||
|
```
|
@ -2,6 +2,7 @@ use std::path::PathBuf;
|
|||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Error as E, Result};
|
||||||
use candle::{Device, Tensor};
|
use candle::{Device, Tensor};
|
||||||
|
use candle_nn::ops::softmax;
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use candle_transformers::models::xlm_roberta::{
|
use candle_transformers::models::xlm_roberta::{
|
||||||
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
|
Config, XLMRobertaForMaskedLM, XLMRobertaForSequenceClassification,
|
||||||
@ -17,12 +18,14 @@ enum Model {
|
|||||||
BgeRerankerBaseV2,
|
BgeRerankerBaseV2,
|
||||||
XLMRobertaBase,
|
XLMRobertaBase,
|
||||||
XLMRobertaLarge,
|
XLMRobertaLarge,
|
||||||
|
XLMRFormalityClassifier,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, ValueEnum)]
|
#[derive(Debug, Clone, ValueEnum)]
|
||||||
enum Task {
|
enum Task {
|
||||||
FillMask,
|
FillMask,
|
||||||
Reranker,
|
Reranker,
|
||||||
|
TextClassification,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
@ -83,6 +86,12 @@ fn main() -> Result<()> {
|
|||||||
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
|
Model::BgeRerankerBaseV2 => "BAAI/bge-reranker-base-v2-m3".to_string(),
|
||||||
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
|
_ => anyhow::bail!("XLM-RoBERTa models are not supported for reranker task"),
|
||||||
},
|
},
|
||||||
|
Task::TextClassification => match args.model {
|
||||||
|
Model::XLMRFormalityClassifier => "s-nlp/xlmr_formality_classifier".to_string(),
|
||||||
|
_ => anyhow::bail!(
|
||||||
|
"XLM-RoBERTa models are not supported for text classification task"
|
||||||
|
),
|
||||||
|
},
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
let repo = api.repo(Repo::with_revision(
|
let repo = api.repo(Repo::with_revision(
|
||||||
@ -217,6 +226,36 @@ fn main() -> Result<()> {
|
|||||||
});
|
});
|
||||||
println!("{:-<80}", "");
|
println!("{:-<80}", "");
|
||||||
}
|
}
|
||||||
|
Task::TextClassification => {
|
||||||
|
let sentences = vec![
|
||||||
|
"I like you. I love you".to_string(),
|
||||||
|
"Hey, what's up?".to_string(),
|
||||||
|
"Siema, co porabiasz?".to_string(),
|
||||||
|
"I feel deep regret and sadness about the situation in international politics."
|
||||||
|
.to_string(),
|
||||||
|
];
|
||||||
|
let model = XLMRobertaForSequenceClassification::new(2, &config, vb)?;
|
||||||
|
let input_ids = tokenize_batch(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
|
||||||
|
|
||||||
|
let attention_mask =
|
||||||
|
get_attention_mask(&tokenizer, TokenizeInput::Single(&sentences), &device)?;
|
||||||
|
let token_type_ids = Tensor::zeros(input_ids.dims(), input_ids.dtype(), &device)?;
|
||||||
|
|
||||||
|
let logits = model
|
||||||
|
.forward(&input_ids, &attention_mask, &token_type_ids)?
|
||||||
|
.to_dtype(candle::DType::F32)?;
|
||||||
|
|
||||||
|
let probabilities = softmax(&logits, 1)?;
|
||||||
|
let probs_vec = probabilities.to_vec2::<f32>()?;
|
||||||
|
|
||||||
|
println!("Formality Scores:");
|
||||||
|
for (i, (text, probs)) in sentences.iter().zip(probs_vec.iter()).enumerate() {
|
||||||
|
println!("Text {}: \"{}\"", i + 1, text);
|
||||||
|
println!(" formal: {:.4}", probs[0]);
|
||||||
|
println!(" informal: {:.4}", probs[1]);
|
||||||
|
println!();
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -3,6 +3,28 @@
|
|||||||
#include "cuda_utils.cuh"
|
#include "cuda_utils.cuh"
|
||||||
#include<stdint.h>
|
#include<stdint.h>
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr T max_value();
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr int64_t max_value<int64_t>() {
|
||||||
|
return 0x7FFFFFFFFFFFFFFFLL;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr uint32_t max_value<uint32_t>() {
|
||||||
|
return 0xFFFFFFFFu;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
__host__ __device__
|
||||||
|
constexpr uint8_t max_value<uint8_t>() {
|
||||||
|
return 0xFFu;
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T, typename I>
|
template<typename T, typename I>
|
||||||
__device__ void index_select(
|
__device__ void index_select(
|
||||||
const size_t numel,
|
const size_t numel,
|
||||||
@ -23,10 +45,14 @@ __device__ void index_select(
|
|||||||
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
||||||
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
||||||
unsigned int right_i = dst_i % right_size;
|
unsigned int right_i = dst_i % right_size;
|
||||||
assert(ids[id_i] < src_dim_size);
|
if (ids[id_i] == max_value<I>()) {
|
||||||
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
out[dst_i] = static_cast<T>(0);
|
||||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
} else {
|
||||||
out[dst_i] = inp[strided_i];
|
assert(ids[id_i] < src_dim_size);
|
||||||
|
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||||
|
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||||
|
out[dst_i] = inp[strided_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -57,11 +83,15 @@ __device__ void gather(
|
|||||||
) {
|
) {
|
||||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||||
size_t post = i % right_size;
|
size_t post = i % right_size;
|
||||||
size_t idx = ids[i];
|
const I idx = ids[i];
|
||||||
assert(idx < src_dim_size);
|
if (ids[i] == max_value<I>()) {
|
||||||
size_t pre = i / (right_size * ids_dim_size);
|
out[i] = static_cast<T>(0);
|
||||||
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
} else {
|
||||||
out[i] = inp[src_i];
|
assert(idx < src_dim_size);
|
||||||
|
size_t pre = i / (right_size * ids_dim_size);
|
||||||
|
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
||||||
|
out[i] = inp[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -93,11 +123,13 @@ __device__ void index_add(
|
|||||||
const size_t pre = i / right_size;
|
const size_t pre = i / right_size;
|
||||||
const size_t post = i % right_size;
|
const size_t post = i % right_size;
|
||||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
const size_t idx = ids[j];
|
const I idx = ids[j];
|
||||||
assert(idx < dst_dim_size);
|
|
||||||
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
if (idx < max_value<I>()) {
|
||||||
out[dst_i] += inp[src_i];
|
assert(idx < dst_dim_size);
|
||||||
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
|
out[dst_i] += inp[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -130,10 +162,12 @@ __device__ void scatter(
|
|||||||
const size_t post = i % right_size;
|
const size_t post = i % right_size;
|
||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||||
const size_t idx = ids[src_i];
|
const I idx = ids[src_i];
|
||||||
assert(idx < dst_dim_size);
|
if (idx < max_value<I>()) {
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
assert(idx < dst_dim_size);
|
||||||
out[dst_i] = inp[src_i];
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
|
out[dst_i] = inp[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -154,10 +188,12 @@ __device__ void scatter_add(
|
|||||||
const size_t post = i % right_size;
|
const size_t post = i % right_size;
|
||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||||
const size_t idx = ids[src_i];
|
const I idx = ids[src_i];
|
||||||
assert(idx < dst_dim_size);
|
if (idx < max_value<I>()) {
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
assert(idx < dst_dim_size);
|
||||||
out[dst_i] += inp[src_i];
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
|
out[dst_i] += inp[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,24 @@
|
|||||||
#include <metal_stdlib>
|
#include <metal_stdlib>
|
||||||
using namespace metal;
|
using namespace metal;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T max_value();
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline int64_t max_value<int64_t>() {
|
||||||
|
return 0x7FFFFFFFFFFFFFFF;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint32_t max_value<uint32_t>() {
|
||||||
|
return 0xFFFFFFFFu;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
inline uint8_t max_value<uint8_t>() {
|
||||||
|
return 0xFF;
|
||||||
|
}
|
||||||
|
|
||||||
METAL_FUNC uint get_strided_index(
|
METAL_FUNC uint get_strided_index(
|
||||||
uint idx,
|
uint idx,
|
||||||
constant size_t &num_dims,
|
constant size_t &num_dims,
|
||||||
@ -35,17 +53,21 @@ METAL_FUNC void index(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const size_t id_i = (tid / right_size) % ids_size;
|
const size_t id_i = (tid / right_size) % ids_size;
|
||||||
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
if (input_ids[id_i] == max_value<INDEX_TYPENAME>()) {
|
||||||
const size_t right_rank_i = tid % right_size;
|
output[tid] = static_cast<TYPENAME>(0);
|
||||||
const size_t left_rank_i = tid / right_size / ids_size;
|
} else {
|
||||||
/*
|
const INDEX_TYPENAME input_i = min(input_ids[id_i], (INDEX_TYPENAME)(src_dim_size - 1));
|
||||||
// Force prevent out of bounds indexing
|
const size_t right_rank_i = tid % right_size;
|
||||||
// since there doesn't seem to be a good way to force crash
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
// No need to check for zero we're only allowing unsized.
|
/*
|
||||||
*/
|
// Force prevent out of bounds indexing
|
||||||
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
// since there doesn't seem to be a good way to force crash
|
||||||
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
// No need to check for zero we're only allowing unsized.
|
||||||
output[tid] = input[strided_src_i];
|
*/
|
||||||
|
const size_t src_i = left_rank_i * src_dim_size * right_size + input_i * right_size + right_rank_i;
|
||||||
|
const size_t strided_src_i = contiguous ? src_i : get_strided_index(src_i, src_dim_size, src_dims, src_strides);
|
||||||
|
output[tid] = input[strided_src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define INDEX_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
@ -83,10 +105,14 @@ METAL_FUNC void gather(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const INDEX_TYPENAME input_i = input_ids[tid];
|
const INDEX_TYPENAME input_i = input_ids[tid];
|
||||||
const size_t right_rank_i = tid % right_size;
|
if (input_i == max_value<INDEX_TYPENAME>()) {
|
||||||
const size_t left_rank_i = tid / right_size / ids_size;
|
output[tid] = static_cast<TYPENAME>(0);
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
} else {
|
||||||
output[tid] = input[src_i];
|
const size_t right_rank_i = tid % right_size;
|
||||||
|
const size_t left_rank_i = tid / right_size / ids_size;
|
||||||
|
const size_t src_i = (left_rank_i * src_dim_size + input_i) * right_size + right_rank_i;
|
||||||
|
output[tid] = input[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define GATHER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
@ -124,8 +150,10 @@ METAL_FUNC void scatter(
|
|||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||||
output[dst_i] = input[src_i];
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
|
output[dst_i] = input[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -149,8 +177,10 @@ METAL_FUNC void scatter_add(
|
|||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
const INDEX_TYPENAME idx = input_ids[src_i];
|
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||||
output[dst_i] += input[src_i];
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
|
output[dst_i] += input[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -204,9 +234,11 @@ METAL_FUNC void index_add(
|
|||||||
const size_t left_rank_i = tid / right_size;
|
const size_t left_rank_i = tid / right_size;
|
||||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
const INDEX_TYPENAME idx = input_ids[j];
|
const INDEX_TYPENAME idx = input_ids[j];
|
||||||
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
if (idx < max_value<INDEX_TYPENAME>()) {
|
||||||
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
output[dst_i] += input[src_i];
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
|
output[dst_i] += input[src_i];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -583,7 +583,13 @@ fn simple_eval_(
|
|||||||
&Device::Cpu,
|
&Device::Cpu,
|
||||||
)?);
|
)?);
|
||||||
|
|
||||||
let xs = Tensor::ones(input.shape(), value.dtype(), input.device())?
|
let shape_vec: Vec<usize> = input
|
||||||
|
.to_vec1::<i64>()?
|
||||||
|
.iter()
|
||||||
|
.map(|&x| x as usize)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let xs = Tensor::ones(shape_vec, value.dtype(), input.device())?
|
||||||
.broadcast_mul(&value)?;
|
.broadcast_mul(&value)?;
|
||||||
values.insert(node.output[0].clone(), xs);
|
values.insert(node.output[0].clone(), xs);
|
||||||
}
|
}
|
||||||
@ -1238,7 +1244,7 @@ fn simple_eval_(
|
|||||||
}
|
}
|
||||||
|
|
||||||
let indexes = Tensor::arange_step(s, e, p, data.device())?;
|
let indexes = Tensor::arange_step(s, e, p, data.device())?;
|
||||||
out = out.index_select(&indexes, axis)?
|
out = out.contiguous()?.index_select(&indexes, axis)?
|
||||||
}
|
}
|
||||||
values.insert(node.output[0].clone(), out);
|
values.insert(node.output[0].clone(), out);
|
||||||
}
|
}
|
||||||
@ -1960,6 +1966,273 @@ fn simple_eval_(
|
|||||||
let output = input.sign()?;
|
let output = input.sign()?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"Resize" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
|
||||||
|
if input.rank() != 4 {
|
||||||
|
bail!("Unsupported rank for nearest resize: {}", input.rank());
|
||||||
|
}
|
||||||
|
|
||||||
|
let scales = if node.input.len() > 2 && !node.input[2].is_empty() {
|
||||||
|
Some(get(&node.input[2])?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let sizes = if node.input.len() > 3 && !node.input[3].is_empty() {
|
||||||
|
Some(get(&node.input[3])?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let output_dims = match (scales, sizes) {
|
||||||
|
(Some(_), Some(_)) => {
|
||||||
|
bail!("Scales and sizes cannot both be set for Resize operation")
|
||||||
|
}
|
||||||
|
(Some(scales_tensor), None) => {
|
||||||
|
let scale_values = scales_tensor.to_vec1::<f32>()?;
|
||||||
|
input
|
||||||
|
.dims()
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, &d)| (d as f32 * scale_values[i]) as usize)
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
(None, Some(sizes_tensor)) => sizes_tensor
|
||||||
|
.to_vec1::<i64>()?
|
||||||
|
.iter()
|
||||||
|
.map(|&d| d as usize)
|
||||||
|
.collect::<Vec<_>>(),
|
||||||
|
(None, None) => bail!("Either scales or sizes should be present"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let coordinate_transformation_mode =
|
||||||
|
get_attr_opt::<str>(node, "coordinate_transformation_mode")?
|
||||||
|
.unwrap_or("half_pixel");
|
||||||
|
// Interpolation mode: nearest, linear, or cubic.
|
||||||
|
let mode = get_attr_opt::<str>(node, "mode")?.unwrap_or("nearest");
|
||||||
|
// How to determine the "nearest" pixel in nearest interpolation mode.
|
||||||
|
let nearest_mode =
|
||||||
|
get_attr_opt::<str>(node, "nearest_mode")?.unwrap_or("round_prefer_floor");
|
||||||
|
|
||||||
|
if mode != "nearest" {
|
||||||
|
bail!("Unsupported resize mode: {}", mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
if nearest_mode != "floor" {
|
||||||
|
bail!("Unsupported nearest_mode for resize: {}", nearest_mode);
|
||||||
|
}
|
||||||
|
|
||||||
|
if coordinate_transformation_mode != "asymmetric" {
|
||||||
|
bail!(
|
||||||
|
"Unsupported coordinate_transformation_mode for resize: {}",
|
||||||
|
coordinate_transformation_mode
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let h = output_dims[2];
|
||||||
|
let w = output_dims[3];
|
||||||
|
let output = input.upsample_nearest2d(h, w)?;
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
"Trilu" => {
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
|
||||||
|
// Get the diagonal offset 'k' from the second input if provided
|
||||||
|
let k = if node.input.len() > 1 && !node.input[1].is_empty() {
|
||||||
|
get(&node.input[1])?.to_vec0::<i64>()?
|
||||||
|
} else {
|
||||||
|
0
|
||||||
|
};
|
||||||
|
|
||||||
|
// Get the 'upper' attribute
|
||||||
|
let upper = get_attr_opt::<i64>(node, "upper")?.copied().unwrap_or(1);
|
||||||
|
|
||||||
|
// For batched inputs, we need to handle each matrix separately
|
||||||
|
let dims = input.dims();
|
||||||
|
if dims.len() < 2 {
|
||||||
|
bail!("Trilu expects input with at least 2 dimensions: {:?}", dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the last two dimensions which represent the matrix
|
||||||
|
let n = dims[dims.len() - 2];
|
||||||
|
let m = dims[dims.len() - 1];
|
||||||
|
let max_dim = std::cmp::max(n, m);
|
||||||
|
|
||||||
|
// Handle the diagonal offset k
|
||||||
|
let mask = if k != 0 {
|
||||||
|
let mut data = vec![0u32; n * m];
|
||||||
|
for i in 0..n {
|
||||||
|
for j in 0..m {
|
||||||
|
if (upper != 0 && (j as i64) >= (i as i64) + k)
|
||||||
|
|| (upper == 0 && (j as i64) <= (i as i64) + k)
|
||||||
|
{
|
||||||
|
data[i * m + j] = 1u32;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Tensor::from_vec(data, (n, m), input.device())?.to_dtype(input.dtype())?
|
||||||
|
} else if upper == 0 {
|
||||||
|
Tensor::tril2(max_dim, input.dtype(), input.device())?
|
||||||
|
} else {
|
||||||
|
Tensor::triu2(max_dim, input.dtype(), input.device())?
|
||||||
|
};
|
||||||
|
|
||||||
|
let final_mask = if n != m {
|
||||||
|
mask.narrow(0, 0, n)?.narrow(1, 0, m)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
|
||||||
|
let output = (input * &final_mask)?;
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
|
"ScatterND" => {
|
||||||
|
let data = get(&node.input[0])?;
|
||||||
|
|
||||||
|
let indices = get(&node.input[1])?;
|
||||||
|
let indices = indices.to_dtype(DType::I64)?;
|
||||||
|
|
||||||
|
let updates = get(&node.input[2])?;
|
||||||
|
|
||||||
|
let reduction = get_attr_opt::<str>(node, "reduction")?.unwrap_or("none");
|
||||||
|
|
||||||
|
let indices_shape = indices.dims();
|
||||||
|
let data_shape = data.dims();
|
||||||
|
let updates_shape = updates.dims();
|
||||||
|
|
||||||
|
// Last dimension of indices represents the depth of indexing
|
||||||
|
let k = indices_shape.last().unwrap().clone();
|
||||||
|
|
||||||
|
if k > data.rank() {
|
||||||
|
bail!("ScatterND expects k (indices.shape[-1]) to be at most the rank of data");
|
||||||
|
}
|
||||||
|
|
||||||
|
let num_updates = indices_shape[..indices_shape.len() - 1]
|
||||||
|
.iter()
|
||||||
|
.product::<usize>();
|
||||||
|
|
||||||
|
let flat_indices = if indices.rank() == 1 && k == 1 {
|
||||||
|
indices.unsqueeze(0)?
|
||||||
|
} else {
|
||||||
|
indices.reshape((num_updates, k))?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculate the shape of each update element
|
||||||
|
let update_element_shape = if k < data_shape.len() {
|
||||||
|
data_shape[k..].to_vec()
|
||||||
|
} else {
|
||||||
|
vec![]
|
||||||
|
};
|
||||||
|
|
||||||
|
// Expected shape for updates based on indices and target tensor
|
||||||
|
let expected_updates_shape = {
|
||||||
|
let mut shape = indices_shape[..indices_shape.len() - 1].to_vec();
|
||||||
|
shape.extend(&update_element_shape);
|
||||||
|
shape
|
||||||
|
};
|
||||||
|
|
||||||
|
// Validate or reshape updates to expected shape
|
||||||
|
let updates = if updates.dims() != expected_updates_shape {
|
||||||
|
if updates.rank() == 0 {
|
||||||
|
// Handle scalar updates
|
||||||
|
let mut target_shape = vec![num_updates];
|
||||||
|
target_shape.extend(&update_element_shape);
|
||||||
|
updates.broadcast_as(target_shape)?
|
||||||
|
} else {
|
||||||
|
// Try to broadcast or reshape updates to expected shape
|
||||||
|
let flat_shape =
|
||||||
|
vec![num_updates, update_element_shape.iter().product::<usize>()];
|
||||||
|
let flattened = updates.reshape(flat_shape)?;
|
||||||
|
flattened.reshape(expected_updates_shape)?
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
updates.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut output = data.clone();
|
||||||
|
|
||||||
|
// convert indices to flat indices
|
||||||
|
let mut flat_output = output.flatten_all()?;
|
||||||
|
let flat_updates = if update_element_shape.is_empty() {
|
||||||
|
updates.reshape(num_updates)?
|
||||||
|
} else {
|
||||||
|
let product = update_element_shape.iter().product::<usize>();
|
||||||
|
updates.reshape((num_updates, product))?
|
||||||
|
};
|
||||||
|
|
||||||
|
// Calculate strides for the output tensor
|
||||||
|
let mut strides: Vec<usize> = vec![1];
|
||||||
|
for i in (0..data_shape.len() - 1).rev() {
|
||||||
|
strides.push(strides.last().unwrap() * data_shape[i + 1]);
|
||||||
|
}
|
||||||
|
strides.reverse();
|
||||||
|
|
||||||
|
// Process each update
|
||||||
|
for i in 0..num_updates {
|
||||||
|
let index_slice = flat_indices.narrow(0, i, 1)?;
|
||||||
|
let indices_vec = index_slice.squeeze(0)?.to_vec1::<i64>()?;
|
||||||
|
|
||||||
|
// Convert multi-dimensional indices to flat index
|
||||||
|
let mut flat_idx: usize = 0;
|
||||||
|
for (dim, &idx) in indices_vec.iter().enumerate() {
|
||||||
|
let dim_size = data_shape[dim] as i64;
|
||||||
|
let norm_idx = if idx < 0 { dim_size + idx } else { idx };
|
||||||
|
|
||||||
|
if norm_idx < 0 || norm_idx >= dim_size {
|
||||||
|
bail!(
|
||||||
|
"Index {} out of bounds for dimension {} with size {}",
|
||||||
|
idx,
|
||||||
|
dim,
|
||||||
|
dim_size
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
flat_idx += (norm_idx as usize) * strides[dim];
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract current update
|
||||||
|
let update_slice = if update_element_shape.is_empty() {
|
||||||
|
flat_updates.narrow(0, i, 1)?.squeeze(0)?
|
||||||
|
} else {
|
||||||
|
flat_updates.narrow(0, i, 1)?
|
||||||
|
};
|
||||||
|
|
||||||
|
match reduction {
|
||||||
|
"add" => {
|
||||||
|
if update_element_shape.is_empty() {
|
||||||
|
let existing = flat_output.narrow(0, flat_idx, 1)?;
|
||||||
|
let new_value = existing.add(&update_slice.unsqueeze(0)?)?;
|
||||||
|
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
|
||||||
|
} else {
|
||||||
|
let slice_size = update_element_shape.iter().product::<usize>();
|
||||||
|
let existing = flat_output.narrow(0, flat_idx, slice_size)?;
|
||||||
|
let new_value = existing.add(&update_slice)?;
|
||||||
|
flat_output = flat_output.slice_scatter(&new_value, 0, flat_idx)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"none" | _ => {
|
||||||
|
if update_element_shape.is_empty() {
|
||||||
|
flat_output = flat_output.slice_scatter(
|
||||||
|
&update_slice.unsqueeze(0)?,
|
||||||
|
0,
|
||||||
|
flat_idx,
|
||||||
|
)?;
|
||||||
|
} else {
|
||||||
|
flat_output =
|
||||||
|
flat_output.slice_scatter(&update_slice, 0, flat_idx)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reshape flat output back to original shape
|
||||||
|
output = flat_output.reshape(data_shape.to_vec())?;
|
||||||
|
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
op_type => bail!("unsupported op_type {op_type} for op {node:?}"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -842,13 +842,22 @@ fn test_flatten_operation() -> Result<()> {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_constant_of_shape() -> Result<()> {
|
fn test_constant_of_shape() -> Result<()> {
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||||
test(&[4i64, 3, 2], Some(1.), &[1., 1., 1.])?;
|
test(
|
||||||
|
&[4i64, 3, 2],
|
||||||
|
Some(1.),
|
||||||
|
&[
|
||||||
|
[[1., 1.], [1., 1.], [1., 1.]],
|
||||||
|
[[1., 1.], [1., 1.], [1., 1.]],
|
||||||
|
[[1., 1.], [1., 1.], [1., 1.]],
|
||||||
|
[[1., 1.], [1., 1.], [1., 1.]],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
// https://github.com/onnx/onnx/blob/main/docs/Operators.md#examples-31
|
||||||
test(&[0.], Some(0i64), &[0i64])?;
|
test(&[1i64], Some(0i64), &[0i64])?;
|
||||||
|
|
||||||
// "value" defaults to 0 f32
|
// "value" defaults to 0 f32
|
||||||
test(&[1i64, 2, 3, 4], None as Option<i64>, &[0., 0., 0., 0.])?;
|
test(&[4i64], None as Option<i64>, &[0., 0., 0., 0.])?;
|
||||||
|
|
||||||
fn test(
|
fn test(
|
||||||
input: impl NdArray,
|
input: impl NdArray,
|
||||||
@ -5968,3 +5977,512 @@ fn test_sign_operation() -> Result<()> {
|
|||||||
);
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_scatternd_operation() -> Result<()> {
|
||||||
|
// Example 1 based on ONNX documentation
|
||||||
|
test(
|
||||||
|
&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
|
||||||
|
&[[4i64], [3], [1], [7]],
|
||||||
|
&[9.0f32, 10.0, 11.0, 12.0],
|
||||||
|
&[1.0f32, 11.0, 3.0, 10.0, 9.0, 6.0, 7.0, 12.0],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// A more complex example with 2D data
|
||||||
|
test(
|
||||||
|
&[[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]],
|
||||||
|
&[[0i64, 1], [1, 0]],
|
||||||
|
&[10.0f32, 20.0],
|
||||||
|
&[[1.0f32, 10.0], [20.0, 4.0], [5.0, 6.0]],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// 3D example with indices pointing to specific locations
|
||||||
|
test(
|
||||||
|
&[
|
||||||
|
[[1.0f32, 2.0], [3.0, 4.0]],
|
||||||
|
[[5.0, 6.0], [7.0, 8.0]],
|
||||||
|
[[9.0, 10.0], [11.0, 12.0]],
|
||||||
|
],
|
||||||
|
&[[0i64, 0, 1], [1, 1, 0]],
|
||||||
|
&[100.0f32, 200.0],
|
||||||
|
&[
|
||||||
|
[[1.0f32, 100.0], [3.0, 4.0]],
|
||||||
|
[[5.0, 6.0], [200.0, 8.0]],
|
||||||
|
[[9.0, 10.0], [11.0, 12.0]],
|
||||||
|
],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
fn test(
|
||||||
|
data: impl NdArray,
|
||||||
|
indices: impl NdArray,
|
||||||
|
updates: impl NdArray,
|
||||||
|
expected: impl NdArray,
|
||||||
|
) -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "ScatterND".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![
|
||||||
|
INPUT_X.to_string(),
|
||||||
|
INPUT_Y.to_string(),
|
||||||
|
INPUT_A.to_string(),
|
||||||
|
],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), Tensor::new(data, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), Tensor::new(indices, &Device::Cpu)?);
|
||||||
|
inputs.insert(INPUT_A.to_string(), Tensor::new(updates, &Device::Cpu)?);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let expected = Tensor::new(expected, &Device::Cpu)?;
|
||||||
|
|
||||||
|
match expected.dims().len() {
|
||||||
|
1 => assert_eq!(z.to_vec1::<f32>()?, expected.to_vec1::<f32>()?),
|
||||||
|
2 => assert_eq!(z.to_vec2::<f32>()?, expected.to_vec2::<f32>()?),
|
||||||
|
3 => assert_eq!(z.to_vec3::<f32>()?, expected.to_vec3::<f32>()?),
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_trilu_operation() -> Result<()> {
|
||||||
|
// Test 1: Upper triangular matrix (default behavior with upper=true)
|
||||||
|
{
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Trilu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![], // empty attribute means default upper=true
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
|
||||||
|
],
|
||||||
|
&[4, 5],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<i64>()?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
vec![4, 7, 3, 7, 9],
|
||||||
|
vec![0, 2, 8, 6, 9],
|
||||||
|
vec![0, 0, 0, 8, 7],
|
||||||
|
vec![0, 0, 0, 2, 4]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 2: Upper triangular with positive k=1 (diagonal above main)
|
||||||
|
{
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Trilu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_Y.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
vec![1i64, 4, 9, 7, 1, 9, 2, 8, 8, 4, 3, 9, 7, 4, 2],
|
||||||
|
&[3, 5],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let k = Tensor::from_vec(vec![1i64], (), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), k);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<i64>()?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
vec![0, 4, 9, 7, 1],
|
||||||
|
vec![0, 0, 8, 8, 4],
|
||||||
|
vec![0, 0, 0, 4, 2]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 3: Upper triangular with negative k=-1 (one diagonal below main)
|
||||||
|
{
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Trilu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 0, 8, 7, 4, 3, 4, 2, 4,
|
||||||
|
],
|
||||||
|
&[4, 5],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), k);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<i64>()?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
vec![4, 7, 3, 7, 9],
|
||||||
|
vec![1, 2, 8, 6, 9],
|
||||||
|
vec![0, 4, 0, 8, 7],
|
||||||
|
vec![0, 0, 4, 2, 4]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 4: Lower triangular matrix (upper=0)
|
||||||
|
{
|
||||||
|
let att_upper = AttributeProto {
|
||||||
|
name: "upper".to_string(),
|
||||||
|
ref_attr_name: "upper".to_string(),
|
||||||
|
i: 0, // 0 means false, use lower triangular
|
||||||
|
doc_string: "upper".to_string(),
|
||||||
|
r#type: 2,
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Trilu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![att_upper],
|
||||||
|
input: vec![INPUT_X.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||||
|
],
|
||||||
|
&[4, 5],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<i64>()?;
|
||||||
|
|
||||||
|
// Lower triangular matrix (default k=0)
|
||||||
|
assert_eq!(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
vec![4, 0, 0, 0, 0],
|
||||||
|
vec![1, 2, 0, 0, 0],
|
||||||
|
vec![9, 4, 1, 0, 0],
|
||||||
|
vec![4, 3, 4, 2, 0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 5: Lower triangular with negative k=-1
|
||||||
|
{
|
||||||
|
let att_upper = AttributeProto {
|
||||||
|
name: "upper".to_string(),
|
||||||
|
ref_attr_name: "upper".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "upper".to_string(),
|
||||||
|
r#type: 2,
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Trilu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![att_upper],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||||
|
],
|
||||||
|
&[4, 5],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let k = Tensor::from_vec(vec![-1i64], (), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), k);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<i64>()?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
vec![0, 0, 0, 0, 0],
|
||||||
|
vec![1, 0, 0, 0, 0],
|
||||||
|
vec![9, 4, 0, 0, 0],
|
||||||
|
vec![4, 3, 4, 0, 0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test 6: Lower triangular with positive k=2
|
||||||
|
{
|
||||||
|
let att_upper = AttributeProto {
|
||||||
|
name: "upper".to_string(),
|
||||||
|
ref_attr_name: "upper".to_string(),
|
||||||
|
i: 0,
|
||||||
|
doc_string: "upper".to_string(),
|
||||||
|
r#type: 2,
|
||||||
|
f: 0.0,
|
||||||
|
s: vec![],
|
||||||
|
t: None,
|
||||||
|
g: None,
|
||||||
|
sparse_tensor: None,
|
||||||
|
tp: None,
|
||||||
|
floats: vec![],
|
||||||
|
ints: vec![],
|
||||||
|
strings: vec![],
|
||||||
|
tensors: vec![],
|
||||||
|
graphs: vec![],
|
||||||
|
sparse_tensors: vec![],
|
||||||
|
type_protos: vec![],
|
||||||
|
};
|
||||||
|
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "Trilu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![att_upper],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
|
||||||
|
let x = Tensor::from_vec(
|
||||||
|
vec![
|
||||||
|
4i64, 7, 3, 7, 9, 1, 2, 8, 6, 9, 9, 4, 1, 8, 7, 4, 3, 4, 2, 4,
|
||||||
|
],
|
||||||
|
&[4, 5],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let k = Tensor::from_vec(vec![2i64], (), &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), k);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<i64>()?;
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
results,
|
||||||
|
vec![
|
||||||
|
vec![4, 7, 3, 0, 0],
|
||||||
|
vec![1, 2, 8, 6, 0],
|
||||||
|
vec![9, 4, 1, 8, 7],
|
||||||
|
vec![4, 3, 4, 2, 4]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
@ -869,8 +869,8 @@ impl Moe {
|
|||||||
}
|
}
|
||||||
|
|
||||||
enum MoeOrMlp {
|
enum MoeOrMlp {
|
||||||
Moe(Moe),
|
Moe(Box<Moe>),
|
||||||
Mlp(Mlp),
|
Mlp(Box<Mlp>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl MoeOrMlp {
|
impl MoeOrMlp {
|
||||||
@ -908,14 +908,17 @@ impl DecoderLayer {
|
|||||||
&& layer_idx >= cfg.first_k_dense_replace
|
&& layer_idx >= cfg.first_k_dense_replace
|
||||||
&& layer_idx % cfg.moe_layer_freq == 0
|
&& layer_idx % cfg.moe_layer_freq == 0
|
||||||
{
|
{
|
||||||
MoeOrMlp::Moe(Moe::new(
|
MoeOrMlp::Moe(
|
||||||
cfg,
|
Moe::new(
|
||||||
vb.pp("mlp"),
|
cfg,
|
||||||
cfg.n_shared_experts,
|
vb.pp("mlp"),
|
||||||
cfg.n_routed_experts.unwrap(),
|
cfg.n_shared_experts,
|
||||||
)?)
|
cfg.n_routed_experts.unwrap(),
|
||||||
|
)?
|
||||||
|
.into(),
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?)
|
MoeOrMlp::Mlp(Mlp::new(cfg, vb.pp("mlp"), None, None)?.into())
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
@ -70,6 +70,7 @@ pub mod moondream;
|
|||||||
pub mod mpt;
|
pub mod mpt;
|
||||||
pub mod nvembed_v2;
|
pub mod nvembed_v2;
|
||||||
pub mod olmo;
|
pub mod olmo;
|
||||||
|
pub mod olmo2;
|
||||||
pub mod openclip;
|
pub mod openclip;
|
||||||
pub mod paligemma;
|
pub mod paligemma;
|
||||||
pub mod parler_tts;
|
pub mod parler_tts;
|
||||||
@ -90,6 +91,7 @@ pub mod quantized_mpt;
|
|||||||
pub mod quantized_phi;
|
pub mod quantized_phi;
|
||||||
pub mod quantized_phi3;
|
pub mod quantized_phi3;
|
||||||
pub mod quantized_qwen2;
|
pub mod quantized_qwen2;
|
||||||
|
pub mod quantized_qwen3;
|
||||||
pub mod quantized_recurrent_gemma;
|
pub mod quantized_recurrent_gemma;
|
||||||
pub mod quantized_rwkv_v5;
|
pub mod quantized_rwkv_v5;
|
||||||
pub mod quantized_rwkv_v6;
|
pub mod quantized_rwkv_v6;
|
||||||
@ -97,6 +99,8 @@ pub mod quantized_stable_lm;
|
|||||||
pub mod quantized_t5;
|
pub mod quantized_t5;
|
||||||
pub mod qwen2;
|
pub mod qwen2;
|
||||||
pub mod qwen2_moe;
|
pub mod qwen2_moe;
|
||||||
|
pub mod qwen3;
|
||||||
|
pub mod qwen3_moe;
|
||||||
pub mod recurrent_gemma;
|
pub mod recurrent_gemma;
|
||||||
pub mod repvgg;
|
pub mod repvgg;
|
||||||
pub mod resnet;
|
pub mod resnet;
|
||||||
|
348
candle-transformers/src/models/olmo2.rs
Normal file
348
candle-transformers/src/models/olmo2.rs
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
//! OLMo 2 (Open Language Model) implementation
|
||||||
|
//!
|
||||||
|
//! See OLMo 2 model details at:
|
||||||
|
//! - [Hugging Face Collection](https://huggingface.co/collections/allenai/olmo-2-674117b93ab84e98afc72edc)
|
||||||
|
//! - [OLMo 2 Paper](https://arxiv.org/abs/2501.00656)
|
||||||
|
//!
|
||||||
|
//!
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{linear_b, linear_no_bias, rms_norm, Activation, Linear, RmsNorm, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub attention_bias: bool,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
pub hidden_act: candle_nn::Activation,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
pub clip_qkv: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.hidden_size / cfg.num_attention_heads;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?,
|
||||||
|
cos: freqs.cos()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
#[allow(clippy::upper_case_acronyms)]
|
||||||
|
struct MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MLP {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let intermediate_sz = cfg.intermediate_size;
|
||||||
|
let gate_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("gate_proj"))?;
|
||||||
|
let up_proj = linear_no_bias(hidden_sz, intermediate_sz, vb.pp("up_proj"))?;
|
||||||
|
let down_proj = linear_no_bias(intermediate_sz, hidden_sz, vb.pp("down_proj"))?;
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MLP {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = xs.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = xs.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Attention {
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
q_norm: RmsNorm,
|
||||||
|
k_norm: RmsNorm,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Attention {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let hidden_sz = cfg.hidden_size;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
let head_dim = hidden_sz / num_heads;
|
||||||
|
let b = cfg.attention_bias;
|
||||||
|
let q_proj = linear_b(hidden_sz, num_heads * head_dim, b, vb.pp("q_proj"))?;
|
||||||
|
let k_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("k_proj"))?;
|
||||||
|
let v_proj = linear_b(hidden_sz, num_kv_heads * head_dim, b, vb.pp("v_proj"))?;
|
||||||
|
let o_proj = linear_b(num_heads * head_dim, hidden_sz, b, vb.pp("o_proj"))?;
|
||||||
|
let q_norm = rms_norm(hidden_sz, cfg.rms_norm_eps, vb.pp("q_norm"))?;
|
||||||
|
let k_norm = rms_norm(num_kv_heads * head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
hidden_size: hidden_sz,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b_sz, q_len, _) = xs.dims3()?;
|
||||||
|
|
||||||
|
let query_states = self.q_proj.forward(xs)?;
|
||||||
|
let key_states = self.k_proj.forward(xs)?;
|
||||||
|
let value_states = self.v_proj.forward(xs)?;
|
||||||
|
|
||||||
|
let query_states = self.q_norm.forward(&query_states)?;
|
||||||
|
let key_states = self.k_norm.forward(&key_states)?;
|
||||||
|
|
||||||
|
let query_states = query_states
|
||||||
|
.reshape((b_sz, q_len, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let key_states = key_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let value_states = value_states
|
||||||
|
.reshape((b_sz, q_len, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let (query_states, key_states) =
|
||||||
|
self.rotary_emb
|
||||||
|
.apply_rotary_emb_qkv(&query_states, &key_states, seqlen_offset)?;
|
||||||
|
|
||||||
|
let (key_states, value_states) = match &self.kv_cache {
|
||||||
|
None => (key_states, value_states),
|
||||||
|
Some((prev_k, prev_v)) => {
|
||||||
|
let key_states = Tensor::cat(&[prev_k, &key_states], 2)?;
|
||||||
|
let value_states = Tensor::cat(&[prev_v, &value_states], 2)?;
|
||||||
|
(key_states, value_states)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((key_states.clone(), value_states.clone()));
|
||||||
|
|
||||||
|
let key_states = crate::utils::repeat_kv(key_states, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let value_states =
|
||||||
|
crate::utils::repeat_kv(value_states, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
let attn_output = {
|
||||||
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
|
let attn_weights = (query_states.matmul(&key_states.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
|
let attn_weights = match attention_mask {
|
||||||
|
None => attn_weights,
|
||||||
|
Some(mask) => attn_weights.broadcast_add(mask)?,
|
||||||
|
};
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights.matmul(&value_states)?
|
||||||
|
};
|
||||||
|
attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, q_len, self.hidden_size))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache = None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Attention,
|
||||||
|
mlp: MLP,
|
||||||
|
post_attention_layernorm: RmsNorm,
|
||||||
|
post_feedforward_layernorm: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(rotary_emb: Arc<RotaryEmbedding>, cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let self_attn = Attention::new(rotary_emb, cfg, vb.pp("self_attn"))?;
|
||||||
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
let post_feedforward_layernorm = rms_norm(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_feedforward_layernorm"),
|
||||||
|
)?;
|
||||||
|
let post_attention_layernorm = rms_norm(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_attention_layernorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
post_attention_layernorm,
|
||||||
|
post_feedforward_layernorm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(
|
||||||
|
&mut self,
|
||||||
|
xs: &Tensor,
|
||||||
|
attention_mask: Option<&Tensor>,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let residual = xs;
|
||||||
|
let xs = self.self_attn.forward(xs, attention_mask, seqlen_offset)?;
|
||||||
|
let xs = self.post_attention_layernorm.forward(&xs)?;
|
||||||
|
let xs = (xs + residual)?;
|
||||||
|
let residual = &xs;
|
||||||
|
let xs = self.mlp.forward(&xs)?;
|
||||||
|
let xs = self.post_feedforward_layernorm.forward(&xs)?;
|
||||||
|
residual + xs
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: candle_nn::Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
lm_head: Linear,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let vb_m = vb.pp("model");
|
||||||
|
let embed_tokens =
|
||||||
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
|
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_l = vb_m.pp("layers");
|
||||||
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
|
let layer = DecoderLayer::new(rotary_emb.clone(), cfg, vb_l.pp(layer_idx))?;
|
||||||
|
layers.push(layer)
|
||||||
|
}
|
||||||
|
let norm = rms_norm(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::new(embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
lm_head,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prepare_decoder_attention_mask(
|
||||||
|
&self,
|
||||||
|
b_size: usize,
|
||||||
|
tgt_len: usize,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
// Sliding window mask?
|
||||||
|
let mask: Vec<_> = (0..tgt_len)
|
||||||
|
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
||||||
|
.collect();
|
||||||
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
||||||
|
let mask = if seqlen_offset > 0 {
|
||||||
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
|
let (b_size, seq_len) = input_ids.dims2()?;
|
||||||
|
let attention_mask = if seq_len <= 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
||||||
|
Some(mask)
|
||||||
|
};
|
||||||
|
let mut xs = self.embed_tokens.forward(input_ids)?;
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
||||||
|
}
|
||||||
|
xs.narrow(1, seq_len - 1, 1)?
|
||||||
|
.apply(&self.norm)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
layer.clear_kv_cache()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
@ -20,10 +20,24 @@
|
|||||||
// This implementation is based on:
|
// This implementation is based on:
|
||||||
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
|
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/modeling_phi3.py
|
||||||
use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
|
use crate::models::with_tracing::{linear_no_bias as linear, Linear, RmsNorm};
|
||||||
use candle::{DType, Device, Module, Result, Tensor, D};
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
|
pub enum RopeScalingType {
|
||||||
|
#[serde(rename = "longrope")]
|
||||||
|
LongRope,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
|
pub struct RopeScaling {
|
||||||
|
pub short_factor: Vec<f32>,
|
||||||
|
pub long_factor: Vec<f32>,
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub type_: RopeScalingType,
|
||||||
|
}
|
||||||
|
|
||||||
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
|
// https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json
|
||||||
#[derive(Debug, Clone, serde::Deserialize)]
|
#[derive(Debug, Clone, serde::Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
@ -38,8 +52,12 @@ pub struct Config {
|
|||||||
pub rope_theta: f64,
|
pub rope_theta: f64,
|
||||||
pub bos_token_id: Option<u32>,
|
pub bos_token_id: Option<u32>,
|
||||||
pub eos_token_id: Option<u32>,
|
pub eos_token_id: Option<u32>,
|
||||||
pub rope_scaling: Option<String>,
|
pub rope_scaling: Option<RopeScaling>,
|
||||||
pub max_position_embeddings: usize,
|
pub max_position_embeddings: usize,
|
||||||
|
pub original_max_position_embeddings: Option<usize>,
|
||||||
|
pub partial_rotary_factor: Option<f64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
@ -50,30 +68,88 @@ impl Config {
|
|||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct RotaryEmbedding {
|
pub struct RotaryEmbedding {
|
||||||
|
partial_dim: Option<usize>,
|
||||||
sin: Tensor,
|
sin: Tensor,
|
||||||
cos: Tensor,
|
cos: Tensor,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
pub fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
let dim = cfg.head_dim();
|
let partial_dim = cfg
|
||||||
let max_seq_len = cfg.max_position_embeddings;
|
.partial_rotary_factor
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
.as_ref()
|
||||||
.step_by(2)
|
.map(|v| (v * cfg.head_dim() as f64) as usize);
|
||||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
let dim = partial_dim.unwrap_or(cfg.head_dim());
|
||||||
.collect();
|
let freqs = match cfg.rope_scaling.as_ref() {
|
||||||
let inv_freq_len = inv_freq.len();
|
None => {
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
.to_dtype(dtype)?
|
.step_by(2)
|
||||||
.reshape((max_seq_len, 1))?;
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
let freqs = t.matmul(&inv_freq)?;
|
.collect();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, ()), dev)?.to_dtype(dtype)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
t.matmul(&inv_freq)?
|
||||||
|
}
|
||||||
|
Some(rope_scaling) => {
|
||||||
|
let inv_freq_s: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.zip(rope_scaling.short_factor.iter())
|
||||||
|
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_s = Tensor::from_vec(inv_freq_s, (1, ()), dev)?.to_dtype(dtype)?;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
match cfg.original_max_position_embeddings {
|
||||||
|
None => {
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
t.matmul(&inv_freq_s)?
|
||||||
|
}
|
||||||
|
Some(original_max_seq_len) => {
|
||||||
|
let t_s = Tensor::arange(0u32, original_max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((original_max_seq_len, 1))?;
|
||||||
|
let freq_s = t_s.matmul(&inv_freq_s)?;
|
||||||
|
let inv_freq_l: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.zip(rope_scaling.long_factor.iter())
|
||||||
|
.map(|(i, &f)| f / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_l =
|
||||||
|
Tensor::from_vec(inv_freq_l, (1, ()), dev)?.to_dtype(dtype)?;
|
||||||
|
let t_l =
|
||||||
|
Tensor::arange(original_max_seq_len as u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape(((), 1))?;
|
||||||
|
let freq_l = t_l.matmul(&inv_freq_l)?;
|
||||||
|
Tensor::cat(&[&freq_s, &freq_l], 0)?
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
|
partial_dim,
|
||||||
sin: freqs.sin()?,
|
sin: freqs.sin()?,
|
||||||
cos: freqs.cos()?,
|
cos: freqs.cos()?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn rope(&self, xs: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
|
||||||
|
let x = match self.partial_dim {
|
||||||
|
None => candle_nn::rotary_emb::rope(&xs.contiguous()?, cos, sin)?,
|
||||||
|
Some(dim) => {
|
||||||
|
let xs_rot = xs.i((.., .., .., ..dim))?.contiguous()?;
|
||||||
|
let xs_pass = xs.i((.., .., .., dim..))?;
|
||||||
|
let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, cos, sin)?;
|
||||||
|
Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)?.contiguous()?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(x)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn apply_rotary_emb_qkv(
|
pub fn apply_rotary_emb_qkv(
|
||||||
&self,
|
&self,
|
||||||
q: &Tensor,
|
q: &Tensor,
|
||||||
@ -83,8 +159,8 @@ impl RotaryEmbedding {
|
|||||||
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
let cos = self.cos.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
let sin = self.sin.narrow(0, seqlen_offset, seq_len)?;
|
||||||
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
let q_embed = self.rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
let k_embed = self.rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
Ok((q_embed, k_embed))
|
Ok((q_embed, k_embed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -292,7 +368,11 @@ impl Model {
|
|||||||
layers.push(layer)
|
layers.push(layer)
|
||||||
}
|
}
|
||||||
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
|
||||||
let lm_head = linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::from_weights(embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
linear(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
embed_tokens,
|
embed_tokens,
|
||||||
layers,
|
layers,
|
||||||
|
429
candle-transformers/src/models/quantized_qwen3.rs
Normal file
429
candle-transformers/src/models/quantized_qwen3.rs
Normal file
@ -0,0 +1,429 @@
|
|||||||
|
//! Qwen3 implementation with quantization support.
|
||||||
|
//!
|
||||||
|
//! Based on the Qwen3 architecture and implemented with quantized weights
|
||||||
|
//! for reduced memory usage and faster inference on compatible hardware.
|
||||||
|
//!
|
||||||
|
//! References:
|
||||||
|
//! - [Qwen3 Models](https://huggingface.co/Qwen/Qwen3-0.6B) (architecture based on official implementations)
|
||||||
|
//!
|
||||||
|
use super::with_tracing::QMatMul;
|
||||||
|
use crate::{quantized_nn::RmsNorm, utils::repeat_kv};
|
||||||
|
use candle::quantized::{gguf_file, QTensor};
|
||||||
|
use candle::{DType, Device, Result, Tensor};
|
||||||
|
use candle_nn::{kv_cache::KvCache, Activation, Embedding, Module};
|
||||||
|
use std::io::{Read, Seek};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
struct Gguf<R: Read + Seek> {
|
||||||
|
ct: gguf_file::Content,
|
||||||
|
reader: R,
|
||||||
|
device: Device,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<R: Read + Seek> Gguf<R> {
|
||||||
|
fn new(ct: gguf_file::Content, reader: R, device: Device) -> Self {
|
||||||
|
Self { ct, reader, device }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn qmatmul(&mut self, name: &str) -> Result<QMatMul> {
|
||||||
|
let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
|
||||||
|
QMatMul::from_weights(ws.into())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn rms_norm(&mut self, name: &str, eps: f64) -> Result<RmsNorm> {
|
||||||
|
let ws = self.ct.tensor(&mut self.reader, name, &self.device)?;
|
||||||
|
RmsNorm::from_qtensor(ws, eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn metadata(&self) -> &std::collections::HashMap<String, gguf_file::Value> {
|
||||||
|
&self.ct.metadata
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tensor(&mut self, name: &str) -> Result<QTensor> {
|
||||||
|
self.ct.tensor(&mut self.reader, name, &self.device)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct MlpWeights {
|
||||||
|
gate_proj: QMatMul,
|
||||||
|
up_proj: QMatMul,
|
||||||
|
down_proj: QMatMul,
|
||||||
|
act_fn: Activation,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MlpWeights {
|
||||||
|
fn new<R: Read + Seek>(gg: &mut Gguf<R>, prefix: &str) -> Result<Self> {
|
||||||
|
let gate_proj = gg.qmatmul(&format!("{prefix}.ffn_gate.weight"))?;
|
||||||
|
let up_proj = gg.qmatmul(&format!("{prefix}.ffn_up.weight"))?;
|
||||||
|
let down_proj = gg.qmatmul(&format!("{prefix}.ffn_down.weight"))?;
|
||||||
|
let act_fn = Activation::Silu;
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "mlp");
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj,
|
||||||
|
up_proj,
|
||||||
|
down_proj,
|
||||||
|
act_fn,
|
||||||
|
span,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for MlpWeights {
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let gate = self.gate_proj.forward(x)?.apply(&self.act_fn)?;
|
||||||
|
let up = self.up_proj.forward(x)?;
|
||||||
|
let gated = (gate * up)?;
|
||||||
|
self.down_proj.forward(&gated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(
|
||||||
|
dtype: DType,
|
||||||
|
head_dim: usize,
|
||||||
|
max_position_embeddings: usize,
|
||||||
|
rope_theta: f64,
|
||||||
|
dev: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let dim = head_dim;
|
||||||
|
let max_seq_len = max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?,
|
||||||
|
cos: freqs.cos()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply RoPE (q, k shape: B x H x L x D)
|
||||||
|
fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, _) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?.to_dtype(q.dtype())?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct AttentionWeights {
|
||||||
|
q_proj: QMatMul,
|
||||||
|
k_proj: QMatMul,
|
||||||
|
v_proj: QMatMul,
|
||||||
|
o_proj: QMatMul,
|
||||||
|
q_norm: RmsNorm,
|
||||||
|
k_norm: RmsNorm,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
kv_cache: KvCache,
|
||||||
|
span_attn: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AttentionWeights {
|
||||||
|
fn new<R: Read + Seek>(
|
||||||
|
gg: &mut Gguf<R>,
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
rms_norm_eps: f64,
|
||||||
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
|
prefix: &str,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
|
||||||
|
let q_proj = gg.qmatmul(&format!("{prefix}.attn_q.weight"))?;
|
||||||
|
let k_proj = gg.qmatmul(&format!("{prefix}.attn_k.weight"))?;
|
||||||
|
let v_proj = gg.qmatmul(&format!("{prefix}.attn_v.weight"))?;
|
||||||
|
let o_proj = gg.qmatmul(&format!("{prefix}.attn_output.weight"))?;
|
||||||
|
|
||||||
|
let q_norm = gg.rms_norm(&format!("{prefix}.attn_q_norm.weight"), rms_norm_eps)?;
|
||||||
|
let k_norm = gg.rms_norm(&format!("{prefix}.attn_k_norm.weight"), rms_norm_eps)?;
|
||||||
|
|
||||||
|
// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
|
||||||
|
// The cache will grow in chunks of 512 tokens when needed.
|
||||||
|
let kv_cache = KvCache::new(2, 512);
|
||||||
|
|
||||||
|
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache,
|
||||||
|
span_attn,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, x: &Tensor, attn_mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
|
||||||
|
let _enter = self.span_attn.enter();
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
let q = self.q_proj.forward(x)?;
|
||||||
|
let k = self.k_proj.forward(x)?;
|
||||||
|
let v = self.v_proj.forward(x)?;
|
||||||
|
|
||||||
|
let q = q
|
||||||
|
.reshape((b, l, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let k = k
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let v = v
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let q_flat = q.flatten(0, 2)?;
|
||||||
|
let k_flat = k.flatten(0, 2)?;
|
||||||
|
|
||||||
|
let q_flat = self.q_norm.forward(&q_flat)?;
|
||||||
|
let k_flat = self.k_norm.forward(&k_flat)?;
|
||||||
|
let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
|
||||||
|
let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;
|
||||||
|
|
||||||
|
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// Reset KV cache if we're at the first position
|
||||||
|
if offset == 0 {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
|
||||||
|
|
||||||
|
// Make tensor contiguous to avoid some strided copies
|
||||||
|
let k = k.contiguous()?;
|
||||||
|
let v = v.contiguous()?;
|
||||||
|
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?.contiguous()?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?.contiguous()?;
|
||||||
|
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
let m_dtype = m.dtype();
|
||||||
|
let scores_dtype = scores.dtype();
|
||||||
|
let mask = if m_dtype != scores_dtype {
|
||||||
|
m.to_dtype(scores_dtype)?
|
||||||
|
} else {
|
||||||
|
m.clone()
|
||||||
|
};
|
||||||
|
scores = scores.broadcast_add(&mask)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?; // (B, H, L, D)
|
||||||
|
let reshaped_ctx = ctx
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b, l, self.num_heads * self.head_dim))?;
|
||||||
|
self.o_proj.forward(&reshaped_ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct LayerWeights {
|
||||||
|
self_attn: AttentionWeights,
|
||||||
|
mlp: MlpWeights,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LayerWeights {
|
||||||
|
fn new<R: Read + Seek>(
|
||||||
|
gg: &mut Gguf<R>,
|
||||||
|
num_attention_heads: usize,
|
||||||
|
num_key_value_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
rms_norm_eps: f64,
|
||||||
|
rotary: Arc<RotaryEmbedding>,
|
||||||
|
layer_idx: usize,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let prefix = format!("blk.{layer_idx}");
|
||||||
|
|
||||||
|
let ln1 = gg.rms_norm(&format!("{prefix}.attn_norm.weight"), rms_norm_eps)?;
|
||||||
|
let ln2 = gg.rms_norm(&format!("{prefix}.ffn_norm.weight"), rms_norm_eps)?;
|
||||||
|
let self_attn = AttentionWeights::new(
|
||||||
|
gg,
|
||||||
|
num_attention_heads,
|
||||||
|
num_key_value_heads,
|
||||||
|
head_dim,
|
||||||
|
rms_norm_eps,
|
||||||
|
rotary,
|
||||||
|
&prefix,
|
||||||
|
)?;
|
||||||
|
let mlp = MlpWeights::new(gg, &prefix)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
|
||||||
|
let h = self.ln1.forward(x)?;
|
||||||
|
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||||
|
let x = (x + h)?;
|
||||||
|
let h2 = self.ln2.forward(&x)?;
|
||||||
|
let h2 = h2.apply(&self.mlp)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelWeights {
|
||||||
|
embed_tokens: Embedding,
|
||||||
|
layers: Vec<LayerWeights>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
lm_head: QMatMul,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
span: tracing::Span,
|
||||||
|
span_output: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelWeights {
|
||||||
|
pub fn from_gguf<R: Read + Seek>(
|
||||||
|
ct: gguf_file::Content,
|
||||||
|
reader: &mut R,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let mut gg = Gguf::new(ct, reader, device.clone());
|
||||||
|
let md_get = |s: &str| match gg.metadata().get(s) {
|
||||||
|
None => candle::bail!("cannot find {s} in metadata"),
|
||||||
|
Some(v) => Ok(v),
|
||||||
|
};
|
||||||
|
|
||||||
|
let num_attention_heads = md_get("qwen3.attention.head_count")?.to_u32()? as usize;
|
||||||
|
let num_kv_heads = md_get("qwen3.attention.head_count_kv")?.to_u32()? as usize;
|
||||||
|
let head_dim = md_get("qwen3.attention.key_length")?.to_u32()? as usize;
|
||||||
|
let num_layers = md_get("qwen3.block_count")?.to_u32()? as usize;
|
||||||
|
let hidden_size = md_get("qwen3.embedding_length")?.to_u32()? as usize;
|
||||||
|
let max_position_embeddings = md_get("qwen3.context_length")?.to_u32()? as usize;
|
||||||
|
let rms_norm_eps = md_get("qwen3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||||
|
let rope_freq_base = md_get("qwen3.rope.freq_base")?.to_f32()? as f64;
|
||||||
|
|
||||||
|
let dtype = match gg.metadata().get("general.dtype") {
|
||||||
|
Some(v) => match v.to_u32() {
|
||||||
|
Ok(0) => DType::F32,
|
||||||
|
Ok(1) => DType::F16,
|
||||||
|
_ => DType::F16,
|
||||||
|
},
|
||||||
|
None => DType::F16,
|
||||||
|
};
|
||||||
|
|
||||||
|
let embed_tensor = gg.tensor("token_embd.weight")?;
|
||||||
|
let embed_tokens = Embedding::new(embed_tensor.dequantize(device)?, hidden_size);
|
||||||
|
|
||||||
|
let rotary = Arc::new(RotaryEmbedding::new(
|
||||||
|
dtype,
|
||||||
|
head_dim,
|
||||||
|
max_position_embeddings,
|
||||||
|
rope_freq_base,
|
||||||
|
device,
|
||||||
|
)?);
|
||||||
|
|
||||||
|
let mut layers = Vec::with_capacity(num_layers);
|
||||||
|
for i in 0..num_layers {
|
||||||
|
layers.push(LayerWeights::new(
|
||||||
|
&mut gg,
|
||||||
|
num_attention_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
head_dim,
|
||||||
|
rms_norm_eps,
|
||||||
|
rotary.clone(),
|
||||||
|
i,
|
||||||
|
)?);
|
||||||
|
}
|
||||||
|
|
||||||
|
let norm = gg.rms_norm("output_norm.weight", rms_norm_eps)?;
|
||||||
|
// Load output projection tensor, falling back to tied embeddings like gemma3
|
||||||
|
let lm_head_tensor = match gg.tensor("output.weight") {
|
||||||
|
Ok(tensor) => tensor,
|
||||||
|
Err(_) => gg.tensor("token_embd.weight")?,
|
||||||
|
};
|
||||||
|
let lm_head = QMatMul::from_weights(lm_head_tensor.into())?;
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||||
|
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
lm_head,
|
||||||
|
device: device.clone(),
|
||||||
|
dtype,
|
||||||
|
span,
|
||||||
|
span_output,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(
|
||||||
|
&self,
|
||||||
|
b: usize,
|
||||||
|
tgt: usize,
|
||||||
|
offset: usize,
|
||||||
|
sw: Option<usize>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..(tgt + offset)).map(move |j| {
|
||||||
|
let past_ok = j <= i + offset;
|
||||||
|
let sw_ok = match sw {
|
||||||
|
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
|
||||||
|
None => true,
|
||||||
|
};
|
||||||
|
if past_ok && sw_ok {
|
||||||
|
0.
|
||||||
|
} else {
|
||||||
|
minf
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
let causal_mask = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset, None)?)
|
||||||
|
};
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal_mask.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
let h = self.norm.forward(&h)?;
|
||||||
|
let _enter = self.span_output.enter();
|
||||||
|
let last_hidden = h.narrow(1, l - 1, 1)?;
|
||||||
|
self.lm_head.forward(&last_hidden)?.squeeze(1)
|
||||||
|
}
|
||||||
|
}
|
389
candle-transformers/src/models/qwen3.rs
Normal file
389
candle-transformers/src/models/qwen3.rs
Normal file
@ -0,0 +1,389 @@
|
|||||||
|
use crate::{
|
||||||
|
models::with_tracing::{linear_b, linear_no_bias, Linear, RmsNorm},
|
||||||
|
utils::repeat_kv,
|
||||||
|
};
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor};
|
||||||
|
use candle_nn::{kv_cache::KvCache, Activation, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub head_dim: usize,
|
||||||
|
pub attention_bias: bool,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub sliding_window: Option<usize>,
|
||||||
|
pub max_window_layers: usize,
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
pub use_sliding_window: bool,
|
||||||
|
pub hidden_act: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct Qwen3RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3RotaryEmbedding {
|
||||||
|
pub(crate) fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
||||||
|
let dim = cfg.head_dim;
|
||||||
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
||||||
|
.collect();
|
||||||
|
let inv_freq_len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
|
let t = Tensor::arange(0u32, max_seq_len as u32, dev)?
|
||||||
|
.to_dtype(dtype)?
|
||||||
|
.reshape((max_seq_len, 1))?;
|
||||||
|
let freqs = t.matmul(&inv_freq)?;
|
||||||
|
Ok(Self {
|
||||||
|
sin: freqs.sin()?,
|
||||||
|
cos: freqs.cos()?,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Apply RoPE (q, k shape: B x H x L x D)
|
||||||
|
pub(crate) fn apply(&self, q: &Tensor, k: &Tensor, offset: usize) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_, _, seq_len, _) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, offset, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, offset, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct Qwen3MLP {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3MLP {
|
||||||
|
pub(crate) fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("gate_proj"))?,
|
||||||
|
up_proj: linear_no_bias(cfg.hidden_size, cfg.intermediate_size, vb.pp("up_proj"))?,
|
||||||
|
down_proj: linear_no_bias(cfg.intermediate_size, cfg.hidden_size, vb.pp("down_proj"))?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3MLP {
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = x.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub(crate) struct Qwen3Attention {
|
||||||
|
// projections
|
||||||
|
q_proj: Linear,
|
||||||
|
k_proj: Linear,
|
||||||
|
v_proj: Linear,
|
||||||
|
o_proj: Linear,
|
||||||
|
// norms
|
||||||
|
q_norm: RmsNorm,
|
||||||
|
k_norm: RmsNorm,
|
||||||
|
// hyper params
|
||||||
|
num_heads: usize,
|
||||||
|
num_kv_heads: usize,
|
||||||
|
num_kv_groups: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
hidden_size: usize,
|
||||||
|
// utils
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
kv_cache: KvCache,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3Attention {
|
||||||
|
pub(crate) fn new(
|
||||||
|
cfg: &Config,
|
||||||
|
rotary_emb: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
if cfg.use_sliding_window {
|
||||||
|
candle::bail!("sliding window is not suppored")
|
||||||
|
}
|
||||||
|
|
||||||
|
let head_dim = cfg.head_dim;
|
||||||
|
let num_heads = cfg.num_attention_heads;
|
||||||
|
let num_kv_heads = cfg.num_key_value_heads;
|
||||||
|
let num_kv_groups = num_heads / num_kv_heads;
|
||||||
|
|
||||||
|
let q_proj = linear_b(
|
||||||
|
cfg.hidden_size,
|
||||||
|
num_heads * head_dim,
|
||||||
|
cfg.attention_bias,
|
||||||
|
vb.pp("q_proj"),
|
||||||
|
)?;
|
||||||
|
let k_proj = linear_b(
|
||||||
|
cfg.hidden_size,
|
||||||
|
num_kv_heads * head_dim,
|
||||||
|
cfg.attention_bias,
|
||||||
|
vb.pp("k_proj"),
|
||||||
|
)?;
|
||||||
|
let v_proj = linear_b(
|
||||||
|
cfg.hidden_size,
|
||||||
|
num_kv_heads * head_dim,
|
||||||
|
cfg.attention_bias,
|
||||||
|
vb.pp("v_proj"),
|
||||||
|
)?;
|
||||||
|
let o_proj = linear_b(
|
||||||
|
num_heads * head_dim,
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.attention_bias,
|
||||||
|
vb.pp("o_proj"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
|
||||||
|
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
|
||||||
|
|
||||||
|
// Necessary because the hidden_size in the config isn't always accurate
|
||||||
|
let hidden_size = head_dim * cfg.num_attention_heads;
|
||||||
|
|
||||||
|
// Initialize KV cache with 512 tokens capacity to reduce initial memory allocation.
|
||||||
|
// The cache will grow in chunks of 512 tokens when needed.
|
||||||
|
let kv_cache = KvCache::new(2, 512);
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
q_proj,
|
||||||
|
k_proj,
|
||||||
|
v_proj,
|
||||||
|
o_proj,
|
||||||
|
q_norm,
|
||||||
|
k_norm,
|
||||||
|
num_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
num_kv_groups,
|
||||||
|
head_dim,
|
||||||
|
hidden_size,
|
||||||
|
rotary_emb,
|
||||||
|
kv_cache,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn forward(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
attn_mask: Option<&Tensor>,
|
||||||
|
offset: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let (b, l, _) = x.dims3()?;
|
||||||
|
|
||||||
|
// 1. Proj
|
||||||
|
let q = self.q_proj.forward(x)?;
|
||||||
|
let k = self.k_proj.forward(x)?;
|
||||||
|
let v = self.v_proj.forward(x)?;
|
||||||
|
|
||||||
|
// 2. Reshape: (B, L, H, D) -> (B, H, L, D)
|
||||||
|
let q = q
|
||||||
|
.reshape((b, l, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let k = k
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let v = v
|
||||||
|
.reshape((b, l, self.num_kv_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
// 3. Per‑head RMSNorm
|
||||||
|
let q_flat = q.flatten(0, 2)?; // (B*H, L, D) -> (BHL, D) after transpose later
|
||||||
|
let k_flat = k.flatten(0, 2)?;
|
||||||
|
let q_flat = self.q_norm.forward(&q_flat)?;
|
||||||
|
let k_flat = self.k_norm.forward(&k_flat)?;
|
||||||
|
let q = q_flat.reshape((b, self.num_heads, l, self.head_dim))?;
|
||||||
|
let k = k_flat.reshape((b, self.num_kv_heads, l, self.head_dim))?;
|
||||||
|
|
||||||
|
// 4. RoPE
|
||||||
|
let (q, k) = self.rotary_emb.apply(&q, &k, offset)?;
|
||||||
|
|
||||||
|
// 5. Accumulate KV cache
|
||||||
|
let (k, v) = self.kv_cache.append(&k.contiguous()?, &v.contiguous()?)?;
|
||||||
|
|
||||||
|
// 6. GQA repeat_kv
|
||||||
|
let k = repeat_kv(k, self.num_kv_groups)?;
|
||||||
|
let v = repeat_kv(v, self.num_kv_groups)?;
|
||||||
|
|
||||||
|
// 7. Attention score
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut scores = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
if let Some(m) = attn_mask {
|
||||||
|
scores = scores.broadcast_add(m)?;
|
||||||
|
}
|
||||||
|
let probs = candle_nn::ops::softmax_last_dim(&scores)?;
|
||||||
|
let ctx = probs.matmul(&v)?; // (B, H, L, D)
|
||||||
|
|
||||||
|
// 8. Output proj
|
||||||
|
ctx.transpose(1, 2)?
|
||||||
|
.reshape((b, l, self.hidden_size))?
|
||||||
|
.apply(&self.o_proj)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn clear_kv_cache(&mut self) {
|
||||||
|
self.kv_cache.reset();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Qwen3Attention,
|
||||||
|
mlp: Qwen3MLP,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(cfg: &Config, rotary: Arc<Qwen3RotaryEmbedding>, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let self_attn = Qwen3Attention::new(cfg, rotary, vb.pp("self_attn"))?;
|
||||||
|
let mlp = Qwen3MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
|
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||||
|
let ln2 = RmsNorm::new(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_attention_layernorm"),
|
||||||
|
)?;
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
mlp,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
|
||||||
|
let h = self.ln1.forward(x)?;
|
||||||
|
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||||
|
let x = (x + h)?;
|
||||||
|
let h2 = self.ln2.forward(&x)?;
|
||||||
|
let h2 = h2.apply(&self.mlp)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: candle_nn::Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embed_tokens =
|
||||||
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(vb.dtype(), cfg, vb.device())?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(DecoderLayer::new(cfg, rotary.clone(), vb_l.pp(i))?);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(
|
||||||
|
&self,
|
||||||
|
b: usize,
|
||||||
|
tgt: usize,
|
||||||
|
offset: usize,
|
||||||
|
sw: Option<usize>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..(tgt + offset)).map(move |j| {
|
||||||
|
let past_ok = j <= i + offset;
|
||||||
|
let sw_ok = match sw {
|
||||||
|
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
|
||||||
|
None => true,
|
||||||
|
};
|
||||||
|
if past_ok && sw_ok {
|
||||||
|
0.
|
||||||
|
} else {
|
||||||
|
minf
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset, None)?)
|
||||||
|
};
|
||||||
|
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelForCausalLM {
|
||||||
|
base: Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelForCausalLM {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let base = Model::new(cfg, vb.clone())?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
};
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
self.base
|
||||||
|
.forward(input, offset)?
|
||||||
|
.narrow(1, l - 1, 1)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
355
candle-transformers/src/models/qwen3_moe.rs
Normal file
355
candle-transformers/src/models/qwen3_moe.rs
Normal file
@ -0,0 +1,355 @@
|
|||||||
|
use crate::models::{
|
||||||
|
qwen3::{Config as Qwen3Config, Qwen3Attention, Qwen3MLP, Qwen3RotaryEmbedding},
|
||||||
|
with_tracing::{linear_no_bias, Linear, RmsNorm},
|
||||||
|
};
|
||||||
|
use candle::{DType, Device, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{Activation, VarBuilder};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, serde::Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub vocab_size: usize,
|
||||||
|
pub hidden_size: usize,
|
||||||
|
pub intermediate_size: usize,
|
||||||
|
pub num_hidden_layers: usize,
|
||||||
|
pub num_attention_heads: usize,
|
||||||
|
pub head_dim: usize,
|
||||||
|
pub attention_bias: bool,
|
||||||
|
pub num_key_value_heads: usize,
|
||||||
|
pub max_position_embeddings: usize,
|
||||||
|
pub sliding_window: Option<usize>,
|
||||||
|
pub max_window_layers: usize,
|
||||||
|
pub tie_word_embeddings: bool,
|
||||||
|
pub rope_theta: f64,
|
||||||
|
pub rms_norm_eps: f64,
|
||||||
|
pub use_sliding_window: bool,
|
||||||
|
pub hidden_act: Activation,
|
||||||
|
// MoE specific configuration
|
||||||
|
pub decoder_sparse_step: usize,
|
||||||
|
pub moe_intermediate_size: usize,
|
||||||
|
pub num_experts_per_tok: usize,
|
||||||
|
pub num_experts: usize,
|
||||||
|
pub norm_topk_prob: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&Config> for Qwen3Config {
|
||||||
|
fn from(val: &Config) -> Self {
|
||||||
|
Qwen3Config {
|
||||||
|
vocab_size: val.vocab_size,
|
||||||
|
hidden_size: val.hidden_size,
|
||||||
|
intermediate_size: val.intermediate_size,
|
||||||
|
num_hidden_layers: val.num_hidden_layers,
|
||||||
|
num_attention_heads: val.num_attention_heads,
|
||||||
|
head_dim: val.head_dim,
|
||||||
|
attention_bias: val.attention_bias,
|
||||||
|
num_key_value_heads: val.num_key_value_heads,
|
||||||
|
max_position_embeddings: val.max_position_embeddings,
|
||||||
|
sliding_window: val.sliding_window,
|
||||||
|
max_window_layers: val.max_window_layers,
|
||||||
|
tie_word_embeddings: val.tie_word_embeddings,
|
||||||
|
rope_theta: val.rope_theta,
|
||||||
|
rms_norm_eps: val.rms_norm_eps,
|
||||||
|
use_sliding_window: val.use_sliding_window,
|
||||||
|
hidden_act: val.hidden_act,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Qwen3MLPExpert {
|
||||||
|
gate_proj: Linear,
|
||||||
|
up_proj: Linear,
|
||||||
|
down_proj: Linear,
|
||||||
|
act_fn: Activation,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3MLPExpert {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
Ok(Self {
|
||||||
|
gate_proj: linear_no_bias(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.moe_intermediate_size,
|
||||||
|
vb.pp("gate_proj"),
|
||||||
|
)?,
|
||||||
|
up_proj: linear_no_bias(cfg.hidden_size, cfg.moe_intermediate_size, vb.pp("up_proj"))?,
|
||||||
|
down_proj: linear_no_bias(
|
||||||
|
cfg.moe_intermediate_size,
|
||||||
|
cfg.hidden_size,
|
||||||
|
vb.pp("down_proj"),
|
||||||
|
)?,
|
||||||
|
act_fn: cfg.hidden_act,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3MLPExpert {
|
||||||
|
fn forward(&self, x: &Tensor) -> Result<Tensor> {
|
||||||
|
let lhs = x.apply(&self.gate_proj)?.apply(&self.act_fn)?;
|
||||||
|
let rhs = x.apply(&self.up_proj)?;
|
||||||
|
(lhs * rhs)?.apply(&self.down_proj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen3 Sparse MoE Block implementation
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Qwen3SparseMoeBlock {
|
||||||
|
gate: Linear,
|
||||||
|
experts: Vec<Qwen3MLPExpert>,
|
||||||
|
norm_topk_prob: bool,
|
||||||
|
num_experts_per_tok: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Qwen3SparseMoeBlock {
|
||||||
|
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let gate = linear_no_bias(cfg.hidden_size, cfg.num_experts, vb.pp("gate"))?;
|
||||||
|
let mut experts = Vec::with_capacity(cfg.num_experts);
|
||||||
|
let vb_e = vb.pp("experts");
|
||||||
|
for idx in 0..cfg.num_experts {
|
||||||
|
let expert = Qwen3MLPExpert::new(cfg, vb_e.pp(idx))?;
|
||||||
|
experts.push(expert)
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
gate,
|
||||||
|
experts,
|
||||||
|
norm_topk_prob: cfg.norm_topk_prob,
|
||||||
|
num_experts_per_tok: cfg.num_experts_per_tok,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3SparseMoeBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b_size, seq_len, hidden_dim) = xs.dims3()?;
|
||||||
|
let xs = xs.reshape(((), hidden_dim))?;
|
||||||
|
let router_logits = xs.apply(&self.gate)?;
|
||||||
|
let routing_weights = candle_nn::ops::softmax_last_dim(&router_logits)?;
|
||||||
|
|
||||||
|
// Extract topk experts per token
|
||||||
|
let experts_per_tok = routing_weights
|
||||||
|
.arg_sort_last_dim(false)?
|
||||||
|
.narrow(D::Minus1, 0, self.num_experts_per_tok)?
|
||||||
|
.contiguous()?;
|
||||||
|
let routing_weights = routing_weights.gather(&experts_per_tok, D::Minus1)?;
|
||||||
|
|
||||||
|
// Extract needed data
|
||||||
|
let routing_weights = routing_weights.to_dtype(DType::F32)?.to_vec2::<f32>()?;
|
||||||
|
let experts_per_tok = experts_per_tok.to_vec2::<u32>()?;
|
||||||
|
let mut top_x = vec![vec![]; self.experts.len()];
|
||||||
|
let mut selected_experts = vec![vec![]; self.experts.len()];
|
||||||
|
for (row_idx, (rw, expert_idxs)) in routing_weights
|
||||||
|
.iter()
|
||||||
|
.zip(experts_per_tok.iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
let sum_rw = rw.iter().sum::<f32>();
|
||||||
|
for (&rw, &expert_idx) in rw.iter().zip(expert_idxs.iter()) {
|
||||||
|
top_x[expert_idx as usize].push(row_idx as u32);
|
||||||
|
let rw = if self.norm_topk_prob { rw / sum_rw } else { rw };
|
||||||
|
selected_experts[expert_idx as usize].push(rw)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process through experts
|
||||||
|
let mut ys = xs.zeros_like()?;
|
||||||
|
for (expert_idx, expert_layer) in self.experts.iter().enumerate() {
|
||||||
|
let top_x = &top_x[expert_idx];
|
||||||
|
if top_x.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let top_x = Tensor::new(top_x.as_slice(), xs.device())?;
|
||||||
|
let selected_experts =
|
||||||
|
Tensor::new(selected_experts[expert_idx].as_slice(), xs.device())?
|
||||||
|
.reshape(((), 1))?
|
||||||
|
.to_dtype(xs.dtype())?;
|
||||||
|
|
||||||
|
let current_state = xs.index_select(&top_x, 0)?.reshape(((), hidden_dim))?;
|
||||||
|
let current_hidden_states = expert_layer.forward(¤t_state)?;
|
||||||
|
let current_hidden_states = current_hidden_states.broadcast_mul(&selected_experts)?;
|
||||||
|
ys = ys.index_add(&top_x, ¤t_hidden_states, 0)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
ys.reshape((b_size, seq_len, hidden_dim))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MLP or MoE decision enum
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum Qwen3FeedForward {
|
||||||
|
Mlp(Qwen3MLP),
|
||||||
|
MoE(Qwen3SparseMoeBlock),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Qwen3FeedForward {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
Self::Mlp(m) => m.forward(xs),
|
||||||
|
Self::MoE(m) => m.forward(xs),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderLayer {
|
||||||
|
self_attn: Qwen3Attention,
|
||||||
|
feed_forward: Qwen3FeedForward,
|
||||||
|
ln1: RmsNorm,
|
||||||
|
ln2: RmsNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderLayer {
|
||||||
|
fn new(
|
||||||
|
layer_idx: usize,
|
||||||
|
cfg: &Config,
|
||||||
|
rotary: Arc<Qwen3RotaryEmbedding>,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let self_attn = Qwen3Attention::new(&cfg.into(), rotary, vb.pp("self_attn"))?;
|
||||||
|
|
||||||
|
// Decide whether to use MoE or regular MLP based on layer_idx and decoder_sparse_step
|
||||||
|
let feed_forward = if cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0
|
||||||
|
{
|
||||||
|
Qwen3FeedForward::MoE(Qwen3SparseMoeBlock::new(cfg, vb.pp("mlp"))?)
|
||||||
|
} else {
|
||||||
|
Qwen3FeedForward::Mlp(Qwen3MLP::new(&cfg.into(), vb.pp("mlp"))?)
|
||||||
|
};
|
||||||
|
|
||||||
|
let ln1 = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("input_layernorm"))?;
|
||||||
|
let ln2 = RmsNorm::new(
|
||||||
|
cfg.hidden_size,
|
||||||
|
cfg.rms_norm_eps,
|
||||||
|
vb.pp("post_attention_layernorm"),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
self_attn,
|
||||||
|
feed_forward,
|
||||||
|
ln1,
|
||||||
|
ln2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&mut self, x: &Tensor, mask: Option<&Tensor>, offset: usize) -> Result<Tensor> {
|
||||||
|
let h = self.ln1.forward(x)?;
|
||||||
|
let h = self.self_attn.forward(&h, mask, offset)?;
|
||||||
|
let x = (x + h)?;
|
||||||
|
let h2 = self.ln2.forward(&x)?;
|
||||||
|
let h2 = h2.apply(&self.feed_forward)?;
|
||||||
|
x + h2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
self.self_attn.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
embed_tokens: candle_nn::Embedding,
|
||||||
|
layers: Vec<DecoderLayer>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
device: Device,
|
||||||
|
dtype: DType,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let embed_tokens =
|
||||||
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb.pp("model.embed_tokens"))?;
|
||||||
|
let rotary = Arc::new(Qwen3RotaryEmbedding::new(
|
||||||
|
vb.dtype(),
|
||||||
|
&cfg.into(),
|
||||||
|
vb.device(),
|
||||||
|
)?);
|
||||||
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
|
let vb_l = vb.pp("model.layers");
|
||||||
|
for i in 0..cfg.num_hidden_layers {
|
||||||
|
layers.push(DecoderLayer::new(i, cfg, rotary.clone(), vb_l.pp(i))?);
|
||||||
|
}
|
||||||
|
Ok(Self {
|
||||||
|
embed_tokens,
|
||||||
|
layers,
|
||||||
|
norm: RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb.pp("model.norm"))?,
|
||||||
|
device: vb.device().clone(),
|
||||||
|
dtype: vb.dtype(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear_kv_cache(&mut self) {
|
||||||
|
for l in &mut self.layers {
|
||||||
|
l.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn causal_mask(
|
||||||
|
&self,
|
||||||
|
b: usize,
|
||||||
|
tgt: usize,
|
||||||
|
offset: usize,
|
||||||
|
sw: Option<usize>,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let minf = f32::NEG_INFINITY;
|
||||||
|
let mask: Vec<_> = (0..tgt)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..(tgt + offset)).map(move |j| {
|
||||||
|
let past_ok = j <= i + offset;
|
||||||
|
let sw_ok = match sw {
|
||||||
|
Some(w) => (i + offset) as i64 - j as i64 <= w as i64,
|
||||||
|
None => true,
|
||||||
|
};
|
||||||
|
if past_ok && sw_ok {
|
||||||
|
0.
|
||||||
|
} else {
|
||||||
|
minf
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
Tensor::from_slice(&mask, (b, 1, tgt, tgt + offset), &self.device)?.to_dtype(self.dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
|
let (b, l) = input.dims2()?;
|
||||||
|
let mut h = self.embed_tokens.forward(input)?;
|
||||||
|
|
||||||
|
let causal = if l == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(self.causal_mask(b, l, offset, None)?)
|
||||||
|
};
|
||||||
|
|
||||||
|
for layer in &mut self.layers {
|
||||||
|
h = layer.forward(&h, causal.as_ref(), offset)?;
|
||||||
|
}
|
||||||
|
self.norm.forward(&h)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelForCausalLM {
|
||||||
|
base: Model,
|
||||||
|
lm_head: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelForCausalLM {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let base = Model::new(cfg, vb.clone())?;
|
||||||
|
let lm_head = if cfg.tie_word_embeddings {
|
||||||
|
Linear::from_weights(base.embed_tokens.embeddings().clone(), None)
|
||||||
|
} else {
|
||||||
|
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
|
||||||
|
};
|
||||||
|
Ok(Self { base, lm_head })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, input: &Tensor, offset: usize) -> Result<Tensor> {
|
||||||
|
let (_, l) = input.dims2()?;
|
||||||
|
self.base
|
||||||
|
.forward(input, offset)?
|
||||||
|
.narrow(1, l - 1, 1)?
|
||||||
|
.apply(&self.lm_head)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn clear_kv_cache(&mut self) {
|
||||||
|
self.base.clear_kv_cache();
|
||||||
|
}
|
||||||
|
}
|
@ -17,8 +17,8 @@ const CROP_NMS_THRESH: f32 = 0.7;
|
|||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
enum ImageEncoder {
|
enum ImageEncoder {
|
||||||
Original(ImageEncoderViT),
|
Original(Box<ImageEncoderViT>),
|
||||||
TinyViT(TinyViT),
|
TinyViT(Box<TinyViT>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Module for ImageEncoder {
|
impl Module for ImageEncoder {
|
||||||
@ -83,7 +83,7 @@ impl Sam {
|
|||||||
let pixel_std =
|
let pixel_std =
|
||||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
image_encoder: ImageEncoder::Original(image_encoder),
|
image_encoder: ImageEncoder::Original(image_encoder.into()),
|
||||||
prompt_encoder,
|
prompt_encoder,
|
||||||
mask_decoder,
|
mask_decoder,
|
||||||
pixel_std,
|
pixel_std,
|
||||||
@ -114,7 +114,7 @@ impl Sam {
|
|||||||
let pixel_std =
|
let pixel_std =
|
||||||
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
Tensor::new(&[58.395f32, 57.12, 57.375], vb.device())?.reshape((3, 1, 1))?;
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
image_encoder: ImageEncoder::TinyViT(image_encoder),
|
image_encoder: ImageEncoder::TinyViT(image_encoder.into()),
|
||||||
prompt_encoder,
|
prompt_encoder,
|
||||||
mask_decoder,
|
mask_decoder,
|
||||||
pixel_std,
|
pixel_std,
|
||||||
|
@ -134,12 +134,7 @@ impl Scheduler for DDIMScheduler {
|
|||||||
timestep
|
timestep
|
||||||
};
|
};
|
||||||
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
|
// https://github.com/huggingface/diffusers/blob/6e099e2c8ce4c4f5c7318e970a8c093dc5c7046e/src/diffusers/schedulers/scheduling_ddim.py#L195
|
||||||
let prev_timestep = if timestep > self.step_ratio {
|
let prev_timestep = timestep.saturating_sub(self.step_ratio);
|
||||||
timestep - self.step_ratio
|
|
||||||
} else {
|
|
||||||
0
|
|
||||||
};
|
|
||||||
|
|
||||||
let alpha_prod_t = self.alphas_cumprod[timestep];
|
let alpha_prod_t = self.alphas_cumprod[timestep];
|
||||||
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
|
let alpha_prod_t_prev = self.alphas_cumprod[prev_timestep];
|
||||||
let beta_prod_t = 1. - alpha_prod_t;
|
let beta_prod_t = 1. - alpha_prod_t;
|
||||||
|
@ -482,8 +482,10 @@ impl XLMRobertaClassificationHead {
|
|||||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
|
let cls_states = hidden_states.get_on_dim(1, 0)?.contiguous()?;
|
||||||
let hidden_states = self.dense.forward(&cls_states)?;
|
let hidden_states = self.dense.forward(&cls_states)?;
|
||||||
let hidden_states = candle_nn::Activation::GeluPytorchTanh.forward(&hidden_states)?;
|
// The activation used in the classification head is tanh, as per the original
|
||||||
let hidden_states = self.out_proj.forward(&hidden_states)?;
|
// implementation.
|
||||||
|
// https://github.com/huggingface/transformers/blob/6e3063422c4b1c014aa60c32b9254fd2902f0f28/src/transformers/models/xlm_roberta/modeling_xlm_roberta.py#L1454
|
||||||
|
let hidden_states = self.out_proj.forward(&hidden_states.tanh()?)?;
|
||||||
Ok(hidden_states)
|
Ok(hidden_states)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user