mirror of
https://github.com/huggingface/candle.git
synced 2025-06-19 03:54:56 +00:00
Compare commits
13 Commits
0.9.0-alph
...
0.9.0
Author | SHA1 | Date | |
---|---|---|---|
fbaf0b0e32 | |||
a2e925462c | |||
3827685524 | |||
3aeb9575c7 | |||
6ff0a6999c | |||
82def7ae38 | |||
99bd69f383 | |||
a4c56a958e | |||
b2904a830b | |||
21055b5697 | |||
9dbaf958dc | |||
ce5f8dd129 | |||
9954981327 |
20
Cargo.toml
20
Cargo.toml
@ -20,7 +20,7 @@ exclude = [
|
|||||||
resolver = "2"
|
resolver = "2"
|
||||||
|
|
||||||
[workspace.package]
|
[workspace.package]
|
||||||
version = "0.9.0-alpha.4"
|
version = "0.9.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
description = "Minimalist ML framework."
|
description = "Minimalist ML framework."
|
||||||
repository = "https://github.com/huggingface/candle"
|
repository = "https://github.com/huggingface/candle"
|
||||||
@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
|
|||||||
accelerate-src = { version = "0.3.2" }
|
accelerate-src = { version = "0.3.2" }
|
||||||
anyhow = { version = "1", features = ["backtrace"] }
|
anyhow = { version = "1", features = ["backtrace"] }
|
||||||
byteorder = "1.4.3"
|
byteorder = "1.4.3"
|
||||||
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
|
candle = { path = "./candle-core", package = "candle-core", version = "0.9.0" }
|
||||||
candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
|
candle-datasets = { path = "./candle-datasets", version = "0.9.0" }
|
||||||
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
|
candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0" }
|
||||||
candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
|
candle-kernels = { path = "./candle-kernels", version = "0.9.0" }
|
||||||
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
|
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0" }
|
||||||
candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
|
candle-nn = { path = "./candle-nn", version = "0.9.0" }
|
||||||
candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
|
candle-onnx = { path = "./candle-onnx", version = "0.9.0" }
|
||||||
candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
|
candle-transformers = { path = "./candle-transformers", version = "0.9.0" }
|
||||||
clap = { version = "4.2.4", features = ["derive"] }
|
clap = { version = "4.2.4", features = ["derive"] }
|
||||||
criterion = { version = "0.5.1", default-features=false }
|
criterion = { version = "0.5.1", default-features=false }
|
||||||
cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
|
||||||
fancy-regex = "0.13.0"
|
fancy-regex = "0.13.0"
|
||||||
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
|
||||||
hf-hub = "0.4.1"
|
hf-hub = "0.4.1"
|
||||||
|
@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
|
|||||||
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;
|
||||||
|
|
||||||
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
|
||||||
fn scatter_add(
|
|
||||||
&self,
|
fn scatter_set(
|
||||||
|
&mut self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: &Self,
|
_: &Self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: &Self,
|
_: &Self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: usize,
|
_: usize,
|
||||||
) -> Result<Self>;
|
) -> Result<()>;
|
||||||
|
|
||||||
|
fn scatter_add_set(
|
||||||
|
&mut self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<()>;
|
||||||
|
|
||||||
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
|
||||||
fn index_add(
|
fn index_add(
|
||||||
&self,
|
&self,
|
||||||
@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
|
|||||||
_src_offset: usize,
|
_src_offset: usize,
|
||||||
_dst_offset: usize,
|
_dst_offset: usize,
|
||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
|
|
||||||
|
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
||||||
@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
|
|||||||
|
|
||||||
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
||||||
|
|
||||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
|
|
||||||
|
|
||||||
/// # Safety
|
/// # Safety
|
||||||
/// This function is unsafe as it doesn't initialize the underlying data store.
|
/// This function is unsafe as it doesn't initialize the underlying data store.
|
||||||
/// The caller should ensure that the data is properly initialized as early as possible
|
/// The caller should ensure that the data is properly initialized as early as possible
|
||||||
|
@ -53,6 +53,7 @@ impl Tensor {
|
|||||||
} else if let Some(op) = node.op() {
|
} else if let Some(op) = node.op() {
|
||||||
match op {
|
match op {
|
||||||
Op::IndexAdd(t1, t2, t3, _)
|
Op::IndexAdd(t1, t2, t3, _)
|
||||||
|
| Op::Scatter(t1, t2, t3, _)
|
||||||
| Op::ScatterAdd(t1, t2, t3, _)
|
| Op::ScatterAdd(t1, t2, t3, _)
|
||||||
| Op::CustomOp3(t1, t2, t3, _)
|
| Op::CustomOp3(t1, t2, t3, _)
|
||||||
| Op::WhereCond(t1, t2, t3) => {
|
| Op::WhereCond(t1, t2, t3) => {
|
||||||
@ -419,7 +420,7 @@ impl Tensor {
|
|||||||
let sum_grad = grads.or_insert(arg)?;
|
let sum_grad = grads.or_insert(arg)?;
|
||||||
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
|
*sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
|
||||||
}
|
}
|
||||||
Op::ScatterAdd(init, indexes, src, dim) => {
|
Op::Scatter(init, indexes, src, dim) => {
|
||||||
let init_sum_grad = grads.or_insert(init)?;
|
let init_sum_grad = grads.or_insert(init)?;
|
||||||
*init_sum_grad = init_sum_grad.add(&grad)?;
|
*init_sum_grad = init_sum_grad.add(&grad)?;
|
||||||
|
|
||||||
@ -427,6 +428,16 @@ impl Tensor {
|
|||||||
let src_sum_grad = grads.or_insert(src)?;
|
let src_sum_grad = grads.or_insert(src)?;
|
||||||
*src_sum_grad = src_sum_grad.add(&src_grad)?;
|
*src_sum_grad = src_sum_grad.add(&src_grad)?;
|
||||||
}
|
}
|
||||||
|
Op::ScatterAdd(init, indexes, src, dim) => {
|
||||||
|
let init_sum_grad = grads.or_insert(init)?;
|
||||||
|
let mask = init.ones_like()?;
|
||||||
|
let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
|
||||||
|
*init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
|
||||||
|
|
||||||
|
let src_grad = grad.gather(indexes, *dim)?;
|
||||||
|
let src_sum_grad = grads.or_insert(src)?;
|
||||||
|
*src_sum_grad = src_sum_grad.add(&src_grad)?;
|
||||||
|
}
|
||||||
Op::IndexAdd(init, indexes, src, dim) => {
|
Op::IndexAdd(init, indexes, src, dim) => {
|
||||||
let init_sum_grad = grads.or_insert(init)?;
|
let init_sum_grad = grads.or_insert(init)?;
|
||||||
*init_sum_grad = init_sum_grad.add(&grad)?;
|
*init_sum_grad = init_sum_grad.add(&grad)?;
|
||||||
|
@ -7,7 +7,7 @@ use rayon::prelude::*;
|
|||||||
|
|
||||||
mod utils;
|
mod utils;
|
||||||
pub use utils::{
|
pub use utils::{
|
||||||
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
|
binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
|
||||||
};
|
};
|
||||||
|
|
||||||
const USE_IM2COL_CONV1D: bool = true;
|
const USE_IM2COL_CONV1D: bool = true;
|
||||||
@ -554,26 +554,65 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ScatterAdd<'a, I: IntDType> {
|
trait ElemUpdate {
|
||||||
|
fn f<T: WithDType>(dst: &mut T, src: T);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Set;
|
||||||
|
struct Add;
|
||||||
|
|
||||||
|
impl ElemUpdate for Set {
|
||||||
|
fn f<T: WithDType>(dst: &mut T, src: T) {
|
||||||
|
*dst = src
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ElemUpdate for Add {
|
||||||
|
fn f<T: WithDType>(dst: &mut T, src: T) {
|
||||||
|
*dst += src
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Scatter<'a, I: IntDType, M: ElemUpdate> {
|
||||||
ids: &'a [I],
|
ids: &'a [I],
|
||||||
ids_l: &'a Layout,
|
ids_l: &'a Layout,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
_phantom: std::marker::PhantomData<M>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
|
||||||
const OP: &'static str = "scatter-add";
|
fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
|
||||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
|
Self {
|
||||||
let dst_len = l1.shape().elem_count();
|
ids,
|
||||||
let mut dst = vec![T::zero(); dst_len];
|
ids_l,
|
||||||
copy_strided_src_(v1, &mut dst, 0, l1);
|
dim,
|
||||||
|
_phantom: Default::default(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
|
||||||
|
const OP: &'static str = "scatter";
|
||||||
|
fn f<T: WithDType>(
|
||||||
|
&self,
|
||||||
|
dst: &mut [T],
|
||||||
|
dst_l: &Layout,
|
||||||
|
src: &[T],
|
||||||
|
src_l: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
let dst = match dst_l.contiguous_offsets() {
|
||||||
|
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||||
|
Some((o1, o2)) => &mut dst[o1..o2],
|
||||||
|
};
|
||||||
|
|
||||||
let src = match src_l.contiguous_offsets() {
|
let src = match src_l.contiguous_offsets() {
|
||||||
None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||||
Some((o1, o2)) => &src[o1..o2],
|
Some((o1, o2)) => &src[o1..o2],
|
||||||
};
|
};
|
||||||
|
|
||||||
let dim = self.dim;
|
let dim = self.dim;
|
||||||
let ids_dims = self.ids_l.dims();
|
let ids_dims = self.ids_l.dims();
|
||||||
let dst_dims = l1.dims();
|
let dst_dims = dst_l.dims();
|
||||||
let dst_dim_len = dst_dims[dim];
|
let dst_dim_len = dst_dims[dim];
|
||||||
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
|
let dst_right_len: usize = dst_dims[dim + 1..].iter().product();
|
||||||
|
|
||||||
@ -602,12 +641,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
|
|||||||
.bt())?
|
.bt())?
|
||||||
}
|
}
|
||||||
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
|
let dst_idx = start_dst_idx + index * dst_right_len + right_i;
|
||||||
dst[dst_idx] += src[ids_idx]
|
M::f(&mut dst[dst_idx], src[ids_idx])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(dst)
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -2381,19 +2420,36 @@ impl BackendStorage for CpuStorage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(
|
fn scatter_set(
|
||||||
&self,
|
&mut self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
ids: &Self,
|
ids: &Self,
|
||||||
ids_l: &Layout,
|
ids_l: &Layout,
|
||||||
src: &Self,
|
src: &Self,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<()> {
|
||||||
match ids {
|
match ids {
|
||||||
Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||||
Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||||
Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
|
Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||||
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add_set(
|
||||||
|
&mut self,
|
||||||
|
l: &Layout,
|
||||||
|
ids: &Self,
|
||||||
|
ids_l: &Layout,
|
||||||
|
src: &Self,
|
||||||
|
src_l: &Layout,
|
||||||
|
dim: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
match ids {
|
||||||
|
Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||||
|
Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||||
|
Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
|
||||||
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
|
_ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -2454,6 +2510,48 @@ impl BackendStorage for CpuStorage {
|
|||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
Ok(self.clone())
|
Ok(self.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
|
||||||
|
use crate::scalar::Scalar;
|
||||||
|
fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
|
||||||
|
match l.strided_blocks() {
|
||||||
|
crate::StridedBlocks::SingleBlock { start_offset, len } => {
|
||||||
|
src[start_offset..start_offset + len].fill(s)
|
||||||
|
}
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len: 1,
|
||||||
|
} => {
|
||||||
|
for src_index in block_start_index {
|
||||||
|
src[src_index] = s
|
||||||
|
}
|
||||||
|
}
|
||||||
|
crate::StridedBlocks::MultipleBlocks {
|
||||||
|
block_start_index,
|
||||||
|
block_len,
|
||||||
|
} => {
|
||||||
|
for src_index in block_start_index {
|
||||||
|
src[src_index..src_index + block_len].fill(s)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
match (self, s) {
|
||||||
|
(Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
|
||||||
|
(Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
|
||||||
|
(Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
|
||||||
|
(Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
|
||||||
|
(Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
|
||||||
|
(Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
|
||||||
|
(Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
|
||||||
|
(st, s) => crate::bail!(
|
||||||
|
"const_set dtype mismatch, expected {:?} but got {:?}",
|
||||||
|
st.dtype(),
|
||||||
|
s
|
||||||
|
),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl BackendDevice for CpuDevice {
|
impl BackendDevice for CpuDevice {
|
||||||
@ -2628,20 +2726,6 @@ impl BackendDevice for CpuDevice {
|
|||||||
Ok(storage)
|
Ok(storage)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let storage = match dtype {
|
|
||||||
DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
|
|
||||||
DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
|
|
||||||
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
|
|
||||||
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
|
|
||||||
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
|
|
||||||
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
|
|
||||||
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
|
|
||||||
};
|
|
||||||
Ok(storage)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let storage = match dtype {
|
let storage = match dtype {
|
||||||
|
@ -58,6 +58,30 @@ pub trait Map2 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub trait Map2InPlace {
|
||||||
|
const OP: &'static str;
|
||||||
|
fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
|
||||||
|
|
||||||
|
fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
|
||||||
|
match (v1, v2) {
|
||||||
|
(C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
|
||||||
|
(C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||||
|
(C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||||
|
(C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||||
|
(C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
|
||||||
|
(C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
|
||||||
|
(C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
|
||||||
|
(v1, v2) => Err(Error::DTypeMismatchBinaryOp {
|
||||||
|
lhs: v1.dtype(),
|
||||||
|
rhs: v2.dtype(),
|
||||||
|
op: Self::OP,
|
||||||
|
}
|
||||||
|
.bt())?,
|
||||||
|
};
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub trait Map2U8 {
|
pub trait Map2U8 {
|
||||||
const OP: &'static str;
|
const OP: &'static str;
|
||||||
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
|
||||||
|
@ -2,7 +2,7 @@ use crate::backend::BackendDevice;
|
|||||||
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
|
||||||
pub use candle_kernels as kernels;
|
pub use candle_kernels as kernels;
|
||||||
pub use cudarc;
|
pub use cudarc;
|
||||||
use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
|
use cudarc::driver::CudaFunction;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::{Arc, Mutex};
|
use std::sync::{Arc, Mutex};
|
||||||
@ -188,100 +188,6 @@ impl CudaDevice {
|
|||||||
self.id
|
self.id
|
||||||
}
|
}
|
||||||
|
|
||||||
fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
let elem_count = shape.elem_count();
|
|
||||||
let cfg = LaunchConfig::for_num_elems(elem_count as u32);
|
|
||||||
let slice = match dtype {
|
|
||||||
DType::U8 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u8>(elem_count)? };
|
|
||||||
let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
|
|
||||||
let mut builder = self.stream.launch_builder(&func);
|
|
||||||
let v = v as u8;
|
|
||||||
builder.arg(&data);
|
|
||||||
builder.arg(&v);
|
|
||||||
builder.arg(&elem_count);
|
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
|
||||||
CudaStorageSlice::U8(data)
|
|
||||||
}
|
|
||||||
DType::U32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<u32>(elem_count)? };
|
|
||||||
let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
|
|
||||||
let mut builder = self.stream.launch_builder(&func);
|
|
||||||
let v = v as u32;
|
|
||||||
builder.arg(&data);
|
|
||||||
builder.arg(&v);
|
|
||||||
builder.arg(&elem_count);
|
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
|
||||||
CudaStorageSlice::U32(data)
|
|
||||||
}
|
|
||||||
DType::I64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<i64>(elem_count)? };
|
|
||||||
let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
|
|
||||||
let mut builder = self.stream.launch_builder(&func);
|
|
||||||
let v = v as i64;
|
|
||||||
builder.arg(&data);
|
|
||||||
builder.arg(&v);
|
|
||||||
builder.arg(&elem_count);
|
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
|
||||||
CudaStorageSlice::I64(data)
|
|
||||||
}
|
|
||||||
DType::BF16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<bf16>(elem_count)? };
|
|
||||||
let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
|
|
||||||
let mut builder = self.stream.launch_builder(&func);
|
|
||||||
let v = bf16::from_f64(v);
|
|
||||||
builder.arg(&data);
|
|
||||||
builder.arg(&v);
|
|
||||||
builder.arg(&elem_count);
|
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
|
||||||
CudaStorageSlice::BF16(data)
|
|
||||||
}
|
|
||||||
DType::F16 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f16>(elem_count)? };
|
|
||||||
let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
|
|
||||||
let mut builder = self.stream.launch_builder(&func);
|
|
||||||
let v = f16::from_f64(v);
|
|
||||||
builder.arg(&data);
|
|
||||||
builder.arg(&v);
|
|
||||||
builder.arg(&elem_count);
|
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
|
||||||
CudaStorageSlice::F16(data)
|
|
||||||
}
|
|
||||||
DType::F32 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f32>(elem_count)? };
|
|
||||||
let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
|
|
||||||
let mut builder = self.stream.launch_builder(&func);
|
|
||||||
let v = v as f32;
|
|
||||||
builder.arg(&data);
|
|
||||||
builder.arg(&v);
|
|
||||||
builder.arg(&elem_count);
|
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
|
||||||
CudaStorageSlice::F32(data)
|
|
||||||
}
|
|
||||||
DType::F64 => {
|
|
||||||
// SAFETY: Set later by running the fill kernel.
|
|
||||||
let data = unsafe { self.alloc::<f64>(elem_count) }?;
|
|
||||||
let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
|
|
||||||
let mut builder = self.stream.launch_builder(&func);
|
|
||||||
builder.arg(&data);
|
|
||||||
builder.arg(&v);
|
|
||||||
builder.arg(&elem_count);
|
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
|
||||||
CudaStorageSlice::F64(data)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
Ok(CudaStorage {
|
|
||||||
slice,
|
|
||||||
device: self.clone(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_or_load_custom_func(
|
pub fn get_or_load_custom_func(
|
||||||
&self,
|
&self,
|
||||||
fn_name: &str,
|
fn_name: &str,
|
||||||
@ -504,10 +410,6 @@ impl BackendDevice for CudaDevice {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
|
|
||||||
self.const_impl(1., shape, dtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||||
let elem_count = shape.elem_count();
|
let elem_count = shape.elem_count();
|
||||||
let slice = match dtype {
|
let slice = match dtype {
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
//!
|
//!
|
||||||
use crate::backend::{BackendDevice, BackendStorage};
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
|
||||||
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
|
use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
|
||||||
pub use candle_kernels as kernels;
|
pub use candle_kernels as kernels;
|
||||||
pub use cudarc;
|
pub use cudarc;
|
||||||
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
|
||||||
@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl crate::scalar::Scalar {
|
||||||
|
pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
|
||||||
|
use crate::scalar::Scalar;
|
||||||
|
match self {
|
||||||
|
Scalar::U8(v) => builder.arg(v),
|
||||||
|
Scalar::U32(v) => builder.arg(v),
|
||||||
|
Scalar::I64(v) => builder.arg(v),
|
||||||
|
Scalar::F32(v) => builder.arg(v),
|
||||||
|
Scalar::F64(v) => builder.arg(v),
|
||||||
|
Scalar::F16(v) => builder.arg(v),
|
||||||
|
Scalar::BF16(v) => builder.arg(v),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl SlicePtrOrNull<usize> {
|
impl SlicePtrOrNull<usize> {
|
||||||
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
|
||||||
let ds = if l.is_contiguous() {
|
let ds = if l.is_contiguous() {
|
||||||
@ -395,7 +410,7 @@ impl Map1 for IndexSelect<'_> {
|
|||||||
CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
|
CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
|
||||||
CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
|
CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
|
||||||
_ => Err(CudaError::UnexpectedDType {
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
msg: "index_select ids should be u8 or u32",
|
msg: "index_select ids should be u8, u32, or i64",
|
||||||
expected: DType::U32,
|
expected: DType::U32,
|
||||||
got: self.0.dtype(),
|
got: self.0.dtype(),
|
||||||
})
|
})
|
||||||
@ -492,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
|||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
&self,
|
&self,
|
||||||
dst: &mut CudaSlice<T>,
|
dst: &mut CudaSlice<T>,
|
||||||
dst_shape: &Shape,
|
dst_l: &Layout,
|
||||||
src: &CudaSlice<T>,
|
src: &CudaSlice<T>,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
@ -514,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
|
|||||||
got: ids.dtype(),
|
got: ids.dtype(),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
|
let dst = match dst_l.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => dst.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||||
|
};
|
||||||
let src = match src_l.contiguous_offsets() {
|
let src = match src_l.contiguous_offsets() {
|
||||||
Some((o1, o2)) => src.slice(o1..o2),
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
|
||||||
@ -521,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
|
|||||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
let src_dim_sz = src_l.dims()[dim];
|
let src_dim_sz = src_l.dims()[dim];
|
||||||
let dst_dim_sz = dst_shape.dims()[dim];
|
let dst_dim_sz = dst_l.dims()[dim];
|
||||||
let ids_dim_sz = ids_l.dims()[0];
|
let ids_dim_sz = ids_l.dims()[0];
|
||||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||||
@ -529,7 +548,59 @@ impl Map2InPlace for IndexAdd<'_> {
|
|||||||
barg!(builder, ids);
|
barg!(builder, ids);
|
||||||
barg!(builder, ids_dim_sz);
|
barg!(builder, ids_dim_sz);
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(dst);
|
builder.arg(&dst);
|
||||||
|
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
|
||||||
|
impl Map2InPlace for Scatter<'_> {
|
||||||
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
|
&self,
|
||||||
|
dst: &mut CudaSlice<T>,
|
||||||
|
dst_l: &Layout,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
src_l: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<()> {
|
||||||
|
let ids = &self.0;
|
||||||
|
let ids_l = &self.1;
|
||||||
|
let dim = self.2;
|
||||||
|
let (ids_o1, _) = match ids_l.contiguous_offsets() {
|
||||||
|
Some(o12) => o12,
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||||
|
};
|
||||||
|
let (name, (ids, _guard)) = match &ids.slice {
|
||||||
|
CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
|
||||||
|
CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
|
||||||
|
CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
|
||||||
|
_ => Err(CudaError::UnexpectedDType {
|
||||||
|
msg: "scatter ids should be u8/u32/i64",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: ids.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let dst = match dst_l.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => dst.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||||
|
};
|
||||||
|
let src = match src_l.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
|
||||||
|
};
|
||||||
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||||
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
|
let src_dim_sz = src_l.dims()[dim];
|
||||||
|
let dst_dim_sz = dst_l.dims()[dim];
|
||||||
|
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||||
|
let mut builder = func.builder();
|
||||||
|
barg!(builder, ids);
|
||||||
|
builder.arg(&src);
|
||||||
|
builder.arg(&dst);
|
||||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
@ -542,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
|
|||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
&self,
|
&self,
|
||||||
dst: &mut CudaSlice<T>,
|
dst: &mut CudaSlice<T>,
|
||||||
dst_shape: &Shape,
|
dst_l: &Layout,
|
||||||
src: &CudaSlice<T>,
|
src: &CudaSlice<T>,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
@ -564,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
|
|||||||
got: ids.dtype(),
|
got: ids.dtype(),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
|
let dst = match dst_l.contiguous_offsets() {
|
||||||
|
Some((o1, o2)) => dst.slice(o1..o2),
|
||||||
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||||
|
};
|
||||||
let src = match src_l.contiguous_offsets() {
|
let src = match src_l.contiguous_offsets() {
|
||||||
Some((o1, o2)) => src.slice(o1..o2),
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||||
@ -571,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
|
|||||||
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
let left_sz: usize = src_l.dims()[..dim].iter().product();
|
||||||
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
|
||||||
let src_dim_sz = src_l.dims()[dim];
|
let src_dim_sz = src_l.dims()[dim];
|
||||||
let dst_dim_sz = dst_shape.dims()[dim];
|
let dst_dim_sz = dst_l.dims()[dim];
|
||||||
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
|
||||||
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
|
||||||
let mut builder = func.builder();
|
let mut builder = func.builder();
|
||||||
barg!(builder, ids);
|
barg!(builder, ids);
|
||||||
builder.arg(&src);
|
builder.arg(&src);
|
||||||
builder.arg(dst);
|
builder.arg(&dst);
|
||||||
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
|
||||||
// SAFETY: ffi.
|
// SAFETY: ffi.
|
||||||
unsafe { builder.launch(cfg) }.w()?;
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
@ -1235,6 +1310,36 @@ impl BackendStorage for CudaStorage {
|
|||||||
&self.device
|
&self.device
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
|
||||||
|
let dev = &self.device;
|
||||||
|
let shape = layout.shape();
|
||||||
|
let dims = shape.dims();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let cfg = LaunchConfig::for_num_elems(el_count as u32);
|
||||||
|
let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
|
||||||
|
let src_o = layout.start_offset();
|
||||||
|
let ((src, _guard_src), kernel_name) = match &mut self.slice {
|
||||||
|
S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
|
||||||
|
S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
|
||||||
|
S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
|
||||||
|
S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
|
||||||
|
S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
|
||||||
|
S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
|
||||||
|
S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
|
||||||
|
};
|
||||||
|
|
||||||
|
let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
|
||||||
|
let mut builder = func.builder();
|
||||||
|
barg!(builder, el_count);
|
||||||
|
barg!(builder, dims.len());
|
||||||
|
ds.builder_arg(&mut builder);
|
||||||
|
s.builder_arg(&mut builder);
|
||||||
|
barg!(builder, src);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { builder.launch(cfg) }.w()?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
let dims = shape.dims();
|
let dims = shape.dims();
|
||||||
@ -1793,20 +1898,29 @@ impl BackendStorage for CudaStorage {
|
|||||||
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
|
||||||
Ok(Self { slice, device })
|
Ok(Self { slice, device })
|
||||||
}
|
}
|
||||||
fn scatter_add(
|
fn scatter_set(
|
||||||
&self,
|
&mut self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
ids: &Self,
|
ids: &Self,
|
||||||
ids_l: &Layout,
|
ids_l: &Layout,
|
||||||
src: &Self,
|
src: &Self,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<()> {
|
||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||||
self.copy_strided_src(&mut acc, 0, l)?;
|
}
|
||||||
ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
fn scatter_add_set(
|
||||||
Ok(acc)
|
&mut self,
|
||||||
|
l: &Layout,
|
||||||
|
ids: &Self,
|
||||||
|
ids_l: &Layout,
|
||||||
|
src: &Self,
|
||||||
|
src_l: &Layout,
|
||||||
|
dim: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
let device = self.device().clone();
|
||||||
|
ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
|
||||||
}
|
}
|
||||||
fn index_add(
|
fn index_add(
|
||||||
&self,
|
&self,
|
||||||
@ -1820,7 +1934,7 @@ impl BackendStorage for CudaStorage {
|
|||||||
let device = self.device().clone();
|
let device = self.device().clone();
|
||||||
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
|
||||||
self.copy_strided_src(&mut acc, 0, l)?;
|
self.copy_strided_src(&mut acc, 0, l)?;
|
||||||
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
|
IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
|
||||||
Ok(acc)
|
Ok(acc)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
/// Helper functions to plug cuda kernels in candle.
|
/// Helper functions to plug cuda kernels in candle.
|
||||||
use crate::{Layout, Result, Shape, WithDType};
|
use crate::{Layout, Result, WithDType};
|
||||||
pub use cudarc;
|
pub use cudarc;
|
||||||
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};
|
||||||
|
|
||||||
@ -96,7 +96,7 @@ pub trait Map2InPlace {
|
|||||||
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
|
||||||
&self,
|
&self,
|
||||||
dst: &mut CudaSlice<T>,
|
dst: &mut CudaSlice<T>,
|
||||||
dst_shape: &Shape,
|
dst_l: &Layout,
|
||||||
src: &CudaSlice<T>,
|
src: &CudaSlice<T>,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
@ -105,19 +105,19 @@ pub trait Map2InPlace {
|
|||||||
fn map(
|
fn map(
|
||||||
&self,
|
&self,
|
||||||
dst: &mut S,
|
dst: &mut S,
|
||||||
dst_s: &Shape,
|
dst_l: &Layout,
|
||||||
src: &S,
|
src: &S,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
d: &CudaDevice,
|
d: &CudaDevice,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
match (dst, src) {
|
match (dst, src) {
|
||||||
(S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
|
(S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||||
(S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
|
(S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||||
(S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
|
(S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||||
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
|
(S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||||
(S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
|
(S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||||
(S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
|
(S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||||
(S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
|
(S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
|
||||||
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
_ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -292,23 +292,6 @@ impl Device {
|
|||||||
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
|
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
|
||||||
match self {
|
|
||||||
Device::Cpu => {
|
|
||||||
let storage = CpuDevice.ones_impl(shape, dtype)?;
|
|
||||||
Ok(Storage::Cpu(storage))
|
|
||||||
}
|
|
||||||
Device::Cuda(device) => {
|
|
||||||
let storage = device.ones_impl(shape, dtype)?;
|
|
||||||
Ok(Storage::Cuda(storage))
|
|
||||||
}
|
|
||||||
Device::Metal(device) => {
|
|
||||||
let storage = device.ones_impl(shape, dtype)?;
|
|
||||||
Ok(Storage::Metal(storage))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
|
||||||
match self {
|
match self {
|
||||||
Device::Cpu => {
|
Device::Cpu => {
|
||||||
|
@ -107,6 +107,7 @@ pub trait WithDType:
|
|||||||
|
|
||||||
fn from_f64(v: f64) -> Self;
|
fn from_f64(v: f64) -> Self;
|
||||||
fn to_f64(self) -> f64;
|
fn to_f64(self) -> f64;
|
||||||
|
fn to_scalar(self) -> crate::scalar::Scalar;
|
||||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
|
||||||
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;
|
||||||
|
|
||||||
@ -131,6 +132,10 @@ macro_rules! with_dtype {
|
|||||||
$to_f64(self)
|
$to_f64(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn to_scalar(self) -> crate::scalar::Scalar {
|
||||||
|
crate::scalar::Scalar::$dtype(self)
|
||||||
|
}
|
||||||
|
|
||||||
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
|
||||||
CpuStorageRef::$dtype(data)
|
CpuStorageRef::$dtype(data)
|
||||||
}
|
}
|
||||||
|
@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
|
|||||||
fail!()
|
fail!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(
|
fn scatter_set(
|
||||||
&self,
|
&mut self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: &Self,
|
_: &Self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: &Self,
|
_: &Self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: usize,
|
_: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add_set(
|
||||||
|
&mut self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<()> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
|
|||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithCudaSupport)
|
Err(Error::NotCompiledWithCudaSupport)
|
||||||
}
|
}
|
||||||
|
@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
|
|||||||
fail!()
|
fail!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
fn to_cpu_storage(&self) -> Result<CpuStorage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(
|
fn scatter_set(
|
||||||
&self,
|
&mut self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: &Self,
|
_: &Self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: &Self,
|
_: &Self,
|
||||||
_: &Layout,
|
_: &Layout,
|
||||||
_: usize,
|
_: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<()> {
|
||||||
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add_set(
|
||||||
|
&mut self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: &Self,
|
||||||
|
_: &Layout,
|
||||||
|
_: usize,
|
||||||
|
) -> Result<()> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
|
|||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
|
||||||
}
|
|
||||||
|
|
||||||
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
|
||||||
Err(Error::NotCompiledWithMetalSupport)
|
Err(Error::NotCompiledWithMetalSupport)
|
||||||
}
|
}
|
||||||
|
@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
|
|||||||
self.binary(name, rhs, lhs_l, rhs_l)
|
self.binary(name, rhs, lhs_l, rhs_l)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
|
||||||
|
use crate::scalar::Scalar;
|
||||||
|
fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
|
||||||
|
self_: &mut MetalStorage,
|
||||||
|
s: S,
|
||||||
|
l: &Layout,
|
||||||
|
) -> Result<()> {
|
||||||
|
let device = self_.device();
|
||||||
|
let dtype = self_.dtype;
|
||||||
|
let shape = l.shape();
|
||||||
|
let el_count = shape.elem_count();
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
command_buffer.set_label("const-set");
|
||||||
|
let dst = buffer_o(&self_.buffer, l, self_.dtype);
|
||||||
|
|
||||||
|
match (el_count % 2, dtype, l.is_contiguous()) {
|
||||||
|
(0, DType::BF16 | DType::F16, true) => {
|
||||||
|
use candle_metal_kernels::unary::contiguous_tiled;
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
DType::F16 => contiguous_tiled::const_set::HALF,
|
||||||
|
DType::BF16 => contiguous_tiled::const_set::BFLOAT,
|
||||||
|
_ => crate::bail!("internal bug in const_set"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_const_set_contiguous_tiled(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
s,
|
||||||
|
dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
(_, _, true) => {
|
||||||
|
use candle_metal_kernels::unary::contiguous;
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
DType::F16 => contiguous::const_set::HALF,
|
||||||
|
DType::BF16 => contiguous::const_set::BFLOAT,
|
||||||
|
DType::F32 => contiguous::const_set::FLOAT,
|
||||||
|
DType::I64 => contiguous::const_set::I64,
|
||||||
|
DType::U32 => contiguous::const_set::U32,
|
||||||
|
DType::U8 => contiguous::const_set::U8,
|
||||||
|
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_const_set_contiguous(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
el_count,
|
||||||
|
s,
|
||||||
|
dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
(_, _, false) => {
|
||||||
|
use candle_metal_kernels::unary::strided;
|
||||||
|
let kernel_name = match dtype {
|
||||||
|
DType::F16 => strided::const_set::HALF,
|
||||||
|
DType::BF16 => strided::const_set::BFLOAT,
|
||||||
|
DType::F32 => strided::const_set::FLOAT,
|
||||||
|
DType::I64 => strided::const_set::I64,
|
||||||
|
DType::U32 => strided::const_set::U32,
|
||||||
|
DType::U8 => strided::const_set::U8,
|
||||||
|
DType::F64 => crate::bail!("unsupported const-set f64"),
|
||||||
|
};
|
||||||
|
candle_metal_kernels::call_const_set_strided(
|
||||||
|
&device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&device.kernels,
|
||||||
|
kernel_name,
|
||||||
|
l.dims(),
|
||||||
|
s,
|
||||||
|
l.stride(),
|
||||||
|
dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
match (self.dtype, s) {
|
||||||
|
(DType::U8, Scalar::U8(s)) => set(self, s, l),
|
||||||
|
(DType::U32, Scalar::U32(s)) => set(self, s, l),
|
||||||
|
(DType::I64, Scalar::I64(s)) => set(self, s, l),
|
||||||
|
(DType::F16, Scalar::F16(s)) => set(self, s, l),
|
||||||
|
(DType::BF16, Scalar::BF16(s)) => set(self, s, l),
|
||||||
|
(DType::F32, Scalar::F32(s)) => set(self, s, l),
|
||||||
|
(DType::F64, Scalar::F64(s)) => set(self, s, l),
|
||||||
|
_ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
|
||||||
let device = self.device();
|
let device = self.device();
|
||||||
let shape = layout.shape();
|
let shape = layout.shape();
|
||||||
@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
|
|||||||
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(
|
fn scatter_set(
|
||||||
&self,
|
&mut self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
ids: &Self,
|
ids: &Self,
|
||||||
ids_l: &Layout,
|
ids_l: &Layout,
|
||||||
src: &Self,
|
src: &Self,
|
||||||
src_l: &Layout,
|
src_l: &Layout,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<()> {
|
||||||
let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
|
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||||
self.copy_strided_src(&mut acc, 0, l)?;
|
return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
|
||||||
if !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
};
|
||||||
|
let name = match (ids.dtype, self.dtype) {
|
||||||
|
(DType::U8, DType::F32) => "s_u8_f32",
|
||||||
|
(DType::U8, DType::F16) => "s_u8_f16",
|
||||||
|
(DType::U8, DType::BF16) => "s_u8_bf16",
|
||||||
|
(DType::U32, DType::U32) => "s_u32_u32",
|
||||||
|
(DType::U32, DType::F32) => "s_u32_f32",
|
||||||
|
(DType::U32, DType::F16) => "s_u32_f16",
|
||||||
|
(DType::U32, DType::BF16) => "s_u32_bf16",
|
||||||
|
(DType::I64, DType::F32) => "s_i64_f32",
|
||||||
|
(DType::I64, DType::F16) => "s_i64_f16",
|
||||||
|
(DType::I64, DType::BF16) => "s_i64_bf16",
|
||||||
|
_ => Err(MetalError::UnexpectedDType {
|
||||||
|
msg: "scatter ids should be u8/u32/i64",
|
||||||
|
expected: DType::U32,
|
||||||
|
got: ids.dtype(),
|
||||||
|
})?,
|
||||||
|
};
|
||||||
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||||
|
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||||
|
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||||
|
candle_metal_kernels::call_scatter(
|
||||||
|
&self.device.device,
|
||||||
|
&command_buffer,
|
||||||
|
&self.device.kernels,
|
||||||
|
name,
|
||||||
|
src_l.dims(),
|
||||||
|
l.dims(),
|
||||||
|
dim,
|
||||||
|
src,
|
||||||
|
ids,
|
||||||
|
dst,
|
||||||
|
)
|
||||||
|
.map_err(MetalError::from)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn scatter_add_set(
|
||||||
|
&mut self,
|
||||||
|
l: &Layout,
|
||||||
|
ids: &Self,
|
||||||
|
ids_l: &Layout,
|
||||||
|
src: &Self,
|
||||||
|
src_l: &Layout,
|
||||||
|
dim: usize,
|
||||||
|
) -> Result<()> {
|
||||||
|
if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
|
||||||
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
|
||||||
};
|
};
|
||||||
let name = match (ids.dtype, self.dtype) {
|
let name = match (ids.dtype, self.dtype) {
|
||||||
@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
let command_buffer = self.device.command_buffer()?;
|
let command_buffer = self.device.command_buffer()?;
|
||||||
|
let dst = buffer_o(&self.buffer, l, self.dtype);
|
||||||
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
let src = buffer_o(&src.buffer, src_l, src.dtype);
|
||||||
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
|
||||||
candle_metal_kernels::call_scatter_add(
|
candle_metal_kernels::call_scatter(
|
||||||
&self.device.device,
|
&self.device.device,
|
||||||
&command_buffer,
|
&command_buffer,
|
||||||
&self.device.kernels,
|
&self.device.kernels,
|
||||||
@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
|
|||||||
dim,
|
dim,
|
||||||
src,
|
src,
|
||||||
ids,
|
ids,
|
||||||
&acc.buffer,
|
dst,
|
||||||
)
|
)
|
||||||
.map_err(MetalError::from)?;
|
.map_err(MetalError::from)?;
|
||||||
Ok(acc)
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||||
@ -1965,40 +2107,6 @@ impl BackendDevice for MetalDevice {
|
|||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
|
||||||
let name = match dtype {
|
|
||||||
DType::U8 => "fill_u8",
|
|
||||||
DType::U32 => "fill_u32",
|
|
||||||
DType::I64 => "fill_i64",
|
|
||||||
DType::F16 => "fill_f16",
|
|
||||||
DType::BF16 => "fill_bf16",
|
|
||||||
DType::F32 => "fill_f32",
|
|
||||||
DType::F64 => {
|
|
||||||
let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
|
|
||||||
return self.storage_from_cpu_storage(&cpu_storage);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
|
|
||||||
let command_buffer = self.command_buffer()?;
|
|
||||||
candle_metal_kernels::call_const_fill(
|
|
||||||
&self.device,
|
|
||||||
&command_buffer,
|
|
||||||
&self.kernels,
|
|
||||||
name,
|
|
||||||
shape.elem_count(),
|
|
||||||
&buffer,
|
|
||||||
1.,
|
|
||||||
)
|
|
||||||
.map_err(MetalError::from)?;
|
|
||||||
|
|
||||||
Ok(MetalStorage::new(
|
|
||||||
buffer,
|
|
||||||
self.clone(),
|
|
||||||
shape.elem_count(),
|
|
||||||
dtype,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
|
||||||
let (count, buffer) = match T::cpu_storage_ref(s) {
|
let (count, buffer) = match T::cpu_storage_ref(s) {
|
||||||
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||||
|
@ -80,6 +80,7 @@ pub enum Op {
|
|||||||
Reduce(Tensor, ReduceOp, Vec<usize>),
|
Reduce(Tensor, ReduceOp, Vec<usize>),
|
||||||
Matmul(Tensor, Tensor),
|
Matmul(Tensor, Tensor),
|
||||||
Gather(Tensor, Tensor, usize),
|
Gather(Tensor, Tensor, usize),
|
||||||
|
Scatter(Tensor, Tensor, Tensor, usize),
|
||||||
ScatterAdd(Tensor, Tensor, Tensor, usize),
|
ScatterAdd(Tensor, Tensor, Tensor, usize),
|
||||||
IndexSelect(Tensor, Tensor, usize),
|
IndexSelect(Tensor, Tensor, usize),
|
||||||
IndexAdd(Tensor, Tensor, Tensor, usize),
|
IndexAdd(Tensor, Tensor, Tensor, usize),
|
||||||
|
@ -1,6 +1,74 @@
|
|||||||
//! TensorScalar Enum and Trait
|
//! TensorScalar Enum and Trait
|
||||||
//!
|
//!
|
||||||
use crate::{Result, Tensor, WithDType};
|
use crate::{DType, Result, Tensor, WithDType};
|
||||||
|
use half::{bf16, f16};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq)]
|
||||||
|
pub enum Scalar {
|
||||||
|
U8(u8),
|
||||||
|
U32(u32),
|
||||||
|
I64(i64),
|
||||||
|
BF16(bf16),
|
||||||
|
F16(f16),
|
||||||
|
F32(f32),
|
||||||
|
F64(f64),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: WithDType> From<T> for Scalar {
|
||||||
|
fn from(value: T) -> Self {
|
||||||
|
value.to_scalar()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Scalar {
|
||||||
|
pub fn zero(dtype: DType) -> Self {
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => Scalar::U8(0),
|
||||||
|
DType::U32 => Scalar::U32(0),
|
||||||
|
DType::I64 => Scalar::I64(0),
|
||||||
|
DType::BF16 => Scalar::BF16(bf16::ZERO),
|
||||||
|
DType::F16 => Scalar::F16(f16::ZERO),
|
||||||
|
DType::F32 => Scalar::F32(0.0),
|
||||||
|
DType::F64 => Scalar::F64(0.0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn one(dtype: DType) -> Self {
|
||||||
|
match dtype {
|
||||||
|
DType::U8 => Scalar::U8(1),
|
||||||
|
DType::U32 => Scalar::U32(1),
|
||||||
|
DType::I64 => Scalar::I64(1),
|
||||||
|
DType::BF16 => Scalar::BF16(bf16::ONE),
|
||||||
|
DType::F16 => Scalar::F16(f16::ONE),
|
||||||
|
DType::F32 => Scalar::F32(1.0),
|
||||||
|
DType::F64 => Scalar::F64(1.0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dtype(&self) -> DType {
|
||||||
|
match self {
|
||||||
|
Scalar::U8(_) => DType::U8,
|
||||||
|
Scalar::U32(_) => DType::U32,
|
||||||
|
Scalar::I64(_) => DType::I64,
|
||||||
|
Scalar::BF16(_) => DType::BF16,
|
||||||
|
Scalar::F16(_) => DType::F16,
|
||||||
|
Scalar::F32(_) => DType::F32,
|
||||||
|
Scalar::F64(_) => DType::F64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn to_f64(&self) -> f64 {
|
||||||
|
match self {
|
||||||
|
Scalar::U8(v) => *v as f64,
|
||||||
|
Scalar::U32(v) => *v as f64,
|
||||||
|
Scalar::I64(v) => *v as f64,
|
||||||
|
Scalar::BF16(v) => v.to_f64(),
|
||||||
|
Scalar::F16(v) => v.to_f64(),
|
||||||
|
Scalar::F32(v) => *v as f64,
|
||||||
|
Scalar::F64(v) => *v,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub enum TensorScalar {
|
pub enum TensorScalar {
|
||||||
Tensor(Tensor),
|
Tensor(Tensor),
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
use crate::backend::BackendStorage;
|
use crate::backend::BackendStorage;
|
||||||
use crate::op::{self, CmpOp, ReduceOp};
|
use crate::op::{self, CmpOp, ReduceOp};
|
||||||
|
use crate::scalar::Scalar;
|
||||||
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
|
||||||
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
|
||||||
|
|
||||||
@ -73,6 +74,14 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
|
||||||
|
match self {
|
||||||
|
Storage::Cpu(storage) => storage.const_set(v, l),
|
||||||
|
Storage::Cuda(storage) => storage.const_set(v, l),
|
||||||
|
Storage::Metal(storage) => storage.const_set(v, l),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
|
||||||
match self {
|
match self {
|
||||||
Storage::Cpu(storage) => {
|
Storage::Cpu(storage) => {
|
||||||
@ -619,32 +628,56 @@ impl Storage {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn scatter_add(
|
pub(crate) fn scatter_set(
|
||||||
&self,
|
&mut self,
|
||||||
l: &Layout,
|
l: &Layout,
|
||||||
indexes: &Self,
|
indexes: &Self,
|
||||||
indexes_l: &Layout,
|
indexes_l: &Layout,
|
||||||
source: &Self,
|
source: &Self,
|
||||||
source_l: &Layout,
|
source_l: &Layout,
|
||||||
d: usize,
|
d: usize,
|
||||||
) -> Result<Self> {
|
) -> Result<()> {
|
||||||
|
self.same_device(indexes, "scatter-set")?;
|
||||||
|
self.same_device(source, "scatter-set")?;
|
||||||
|
match (self, indexes, source) {
|
||||||
|
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||||
|
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
|
}
|
||||||
|
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||||
|
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
|
}
|
||||||
|
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||||
|
s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
|
}
|
||||||
|
_ => unreachable!(),
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn scatter_add(
|
||||||
|
&mut self,
|
||||||
|
l: &Layout,
|
||||||
|
indexes: &Self,
|
||||||
|
indexes_l: &Layout,
|
||||||
|
source: &Self,
|
||||||
|
source_l: &Layout,
|
||||||
|
d: usize,
|
||||||
|
) -> Result<()> {
|
||||||
self.same_device(indexes, "scatter-add")?;
|
self.same_device(indexes, "scatter-add")?;
|
||||||
self.same_device(source, "scatter-add")?;
|
self.same_device(source, "scatter-add")?;
|
||||||
match (self, indexes, source) {
|
match (self, indexes, source) {
|
||||||
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
(Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
|
||||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
Ok(Self::Cpu(storage))
|
|
||||||
}
|
}
|
||||||
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
(Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
|
||||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
Ok(Self::Cuda(storage))
|
|
||||||
}
|
}
|
||||||
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
(Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
|
||||||
let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
|
s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
|
||||||
Ok(Self::Metal(storage))
|
|
||||||
}
|
}
|
||||||
_ => unreachable!(),
|
_ => unreachable!(),
|
||||||
}
|
}
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn index_add(
|
pub(crate) fn index_add(
|
||||||
|
@ -3,7 +3,7 @@
|
|||||||
use crate::backend::{BackendDevice, BackendStorage};
|
use crate::backend::{BackendDevice, BackendStorage};
|
||||||
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
|
||||||
use crate::scalar::TensorOrScalar;
|
use crate::scalar::TensorOrScalar;
|
||||||
use crate::shape::{Dim, Dims};
|
use crate::shape::{Dim, Dims, ShapeWithOneHole};
|
||||||
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
|
|
||||||
@ -185,7 +185,9 @@ impl Tensor {
|
|||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let none = BackpropOp::none();
|
let none = BackpropOp::none();
|
||||||
let shape = shape.into();
|
let shape = shape.into();
|
||||||
let storage = device.ones(&shape, dtype)?;
|
let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
|
||||||
|
let layout = Layout::contiguous(shape.clone());
|
||||||
|
storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
|
||||||
Ok(from_storage(storage, shape, none, is_variable))
|
Ok(from_storage(storage, shape, none, is_variable))
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -202,6 +204,18 @@ impl Tensor {
|
|||||||
Self::ones_impl(shape, dtype, device, false)
|
Self::ones_impl(shape, dtype, device, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
|
||||||
|
self.storage_mut().const_set(value, self.layout())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn zero_set(&self) -> Result<()> {
|
||||||
|
self.const_set(crate::scalar::Scalar::zero(self.dtype()))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn one_set(&self) -> Result<()> {
|
||||||
|
self.const_set(crate::scalar::Scalar::one(self.dtype()))
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
|
/// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
|
||||||
///
|
///
|
||||||
/// ```rust
|
/// ```rust
|
||||||
@ -452,17 +466,13 @@ impl Tensor {
|
|||||||
Self::from_vec_impl(data, len, device, false)
|
Self::from_vec_impl(data, len, device, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
|
pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||||
data: Vec<D>,
|
data: Vec<D>,
|
||||||
shape: S,
|
shape: S,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
is_variable: bool,
|
is_variable: bool,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let shape = shape.into();
|
let shape = shape.into_shape(data.len())?;
|
||||||
let buffer_size = data.len();
|
|
||||||
if buffer_size != shape.elem_count() {
|
|
||||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
|
||||||
}
|
|
||||||
let storage = device.storage_owned(data)?;
|
let storage = device.storage_owned(data)?;
|
||||||
let none = BackpropOp::none();
|
let none = BackpropOp::none();
|
||||||
Ok(from_storage(storage, shape, none, is_variable))
|
Ok(from_storage(storage, shape, none, is_variable))
|
||||||
@ -481,7 +491,7 @@ impl Tensor {
|
|||||||
/// ]);
|
/// ]);
|
||||||
/// # Ok::<(), candle_core::Error>(())
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
/// ```
|
/// ```
|
||||||
pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
|
pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||||
data: Vec<D>,
|
data: Vec<D>,
|
||||||
shape: S,
|
shape: S,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -502,17 +512,12 @@ impl Tensor {
|
|||||||
/// ]);
|
/// ]);
|
||||||
/// # Ok::<(), candle_core::Error>(())
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
/// ```
|
/// ```
|
||||||
pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
|
pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
|
||||||
array: &[D],
|
array: &[D],
|
||||||
shape: S,
|
shape: S,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let shape = shape.into();
|
let shape = shape.into_shape(array.len())?;
|
||||||
let n: usize = shape.elem_count();
|
|
||||||
let buffer_size: usize = array.len();
|
|
||||||
if buffer_size != n {
|
|
||||||
return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
|
|
||||||
}
|
|
||||||
let storage = device.storage_from_slice(array)?;
|
let storage = device.storage_from_slice(array)?;
|
||||||
let none = BackpropOp::none();
|
let none = BackpropOp::none();
|
||||||
Ok(from_storage(storage, shape, none, false))
|
Ok(from_storage(storage, shape, none, false))
|
||||||
@ -1349,8 +1354,7 @@ impl Tensor {
|
|||||||
self.index_select(ids, 0)
|
self.index_select(ids, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
|
||||||
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
|
||||||
let source_dims = source.dims();
|
let source_dims = source.dims();
|
||||||
let self_dims = self.dims();
|
let self_dims = self.dims();
|
||||||
let mismatch = if source_dims.len() != self_dims.len() {
|
let mismatch = if source_dims.len() != self_dims.len() {
|
||||||
@ -1367,7 +1371,7 @@ impl Tensor {
|
|||||||
};
|
};
|
||||||
if mismatch {
|
if mismatch {
|
||||||
Err(Error::ShapeMismatchBinaryOp {
|
Err(Error::ShapeMismatchBinaryOp {
|
||||||
op: "scatter-add (self, src)",
|
op: "scatter (self, src)",
|
||||||
lhs: self.shape().clone(),
|
lhs: self.shape().clone(),
|
||||||
rhs: source.shape().clone(),
|
rhs: source.shape().clone(),
|
||||||
}
|
}
|
||||||
@ -1375,13 +1379,44 @@ impl Tensor {
|
|||||||
}
|
}
|
||||||
if indexes.dims() != source.dims() {
|
if indexes.dims() != source.dims() {
|
||||||
Err(Error::ShapeMismatchBinaryOp {
|
Err(Error::ShapeMismatchBinaryOp {
|
||||||
op: "scatter-add (indexes, src)",
|
op: "scatter (indexes, src)",
|
||||||
lhs: indexes.shape().clone(),
|
lhs: indexes.shape().clone(),
|
||||||
rhs: source.shape().clone(),
|
rhs: source.shape().clone(),
|
||||||
}
|
}
|
||||||
.bt())?
|
.bt())?
|
||||||
}
|
}
|
||||||
let storage = self.storage().scatter_add(
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "scatter")?;
|
||||||
|
self.scatter_checks(indexes, source, dim)?;
|
||||||
|
let shape = self.shape();
|
||||||
|
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||||
|
self.storage()
|
||||||
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
|
let layout = Layout::contiguous(shape);
|
||||||
|
storage.scatter_set(
|
||||||
|
&layout,
|
||||||
|
&indexes.storage(),
|
||||||
|
indexes.layout(),
|
||||||
|
&source.storage(),
|
||||||
|
source.layout(),
|
||||||
|
dim,
|
||||||
|
)?;
|
||||||
|
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||||
|
Op::Scatter(t1, t2, t3, dim)
|
||||||
|
});
|
||||||
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||||
|
if self.same_storage(source) {
|
||||||
|
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||||
|
}
|
||||||
|
let dim = dim.to_index(self.shape(), "scatter-set")?;
|
||||||
|
self.scatter_checks(indexes, source, dim)?;
|
||||||
|
self.storage_mut().scatter_set(
|
||||||
self.layout(),
|
self.layout(),
|
||||||
&indexes.storage(),
|
&indexes.storage(),
|
||||||
indexes.layout(),
|
indexes.layout(),
|
||||||
@ -1389,12 +1424,48 @@ impl Tensor {
|
|||||||
source.layout(),
|
source.layout(),
|
||||||
dim,
|
dim,
|
||||||
)?;
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
|
||||||
|
let dim = dim.to_index(self.shape(), "scatter-add")?;
|
||||||
|
self.scatter_checks(indexes, source, dim)?;
|
||||||
|
let shape = self.shape();
|
||||||
|
let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
|
||||||
|
self.storage()
|
||||||
|
.copy_strided_src(&mut storage, 0, self.layout())?;
|
||||||
|
let layout = Layout::contiguous(shape);
|
||||||
|
storage.scatter_add(
|
||||||
|
&layout,
|
||||||
|
&indexes.storage(),
|
||||||
|
indexes.layout(),
|
||||||
|
&source.storage(),
|
||||||
|
source.layout(),
|
||||||
|
dim,
|
||||||
|
)?;
|
||||||
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
|
||||||
Op::ScatterAdd(t1, t2, t3, dim)
|
Op::ScatterAdd(t1, t2, t3, dim)
|
||||||
});
|
});
|
||||||
Ok(from_storage(storage, self.shape(), op, false))
|
Ok(from_storage(storage, self.shape(), op, false))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
|
||||||
|
if self.same_storage(source) {
|
||||||
|
crate::bail!("cannot use slice_set when self and src share their storage")
|
||||||
|
}
|
||||||
|
let dim = dim.to_index(self.shape(), "scatter-add-set")?;
|
||||||
|
self.scatter_checks(indexes, source, dim)?;
|
||||||
|
self.storage_mut().scatter_add(
|
||||||
|
self.layout(),
|
||||||
|
&indexes.storage(),
|
||||||
|
indexes.layout(),
|
||||||
|
&source.storage(),
|
||||||
|
source.layout(),
|
||||||
|
dim,
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
|
/// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
|
||||||
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
|
pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
|
||||||
let dim = dim.to_index(self.shape(), "slice-scatter")?;
|
let dim = dim.to_index(self.shape(), "slice-scatter")?;
|
||||||
@ -2197,7 +2268,7 @@ impl Tensor {
|
|||||||
///
|
///
|
||||||
/// # Ok::<(), candle_core::Error>(())
|
/// # Ok::<(), candle_core::Error>(())
|
||||||
/// ```
|
/// ```
|
||||||
pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
|
pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
|
||||||
let shape = s.into_shape(self.elem_count())?;
|
let shape = s.into_shape(self.elem_count())?;
|
||||||
if shape.elem_count() != self.elem_count() {
|
if shape.elem_count() != self.elem_count() {
|
||||||
return Err(Error::ShapeMismatchBinaryOp {
|
return Err(Error::ShapeMismatchBinaryOp {
|
||||||
|
@ -241,7 +241,7 @@ impl Tensor {
|
|||||||
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
/// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
|
||||||
/// has to be greater than or equal to `offset` plus the `src` size.
|
/// has to be greater than or equal to `offset` plus the `src` size.
|
||||||
///
|
///
|
||||||
/// Note that this modifies `self` in place and as such is not compatibel with
|
/// Note that this modifies `self` in place and as such is not compatible with
|
||||||
/// back-propagation.
|
/// back-propagation.
|
||||||
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
|
||||||
let dim = dim.to_index(self.shape(), "slice-set")?;
|
let dim = dim.to_index(self.shape(), "slice-set")?;
|
||||||
|
@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
|
|||||||
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
|
||||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||||
);
|
);
|
||||||
assert_eq!(
|
if !device.is_metal() {
|
||||||
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
assert_eq!(
|
||||||
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
|
||||||
);
|
[[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
|
||||||
|
);
|
||||||
|
}
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
|
Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
|
||||||
[
|
[
|
||||||
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn full(device: &Device) -> Result<()> {
|
fn full(device: &Device) -> Result<()> {
|
||||||
|
let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
|
||||||
|
tensor.const_set(42u32.into())?;
|
||||||
|
assert_eq!(
|
||||||
|
tensor.to_vec2::<u32>()?,
|
||||||
|
[[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
|
||||||
|
);
|
||||||
|
tensor.i((.., 2))?.const_set(1337u32.into())?;
|
||||||
|
assert_eq!(
|
||||||
|
tensor.to_vec2::<u32>()?,
|
||||||
|
[[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
|
||||||
|
);
|
||||||
|
tensor.i((2, ..))?.const_set(1u32.into())?;
|
||||||
|
assert_eq!(
|
||||||
|
tensor.to_vec2::<u32>()?,
|
||||||
|
[[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn const_set(device: &Device) -> Result<()> {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
|
Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
|
||||||
[[42, 42, 42], [42, 42, 42]],
|
[[42, 42, 42], [42, 42, 42]],
|
||||||
@ -826,6 +848,31 @@ fn embeddings(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn index_select_fail() -> Result<()> {
|
||||||
|
// Check that an error is properly reported on out of bounds.
|
||||||
|
let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
|
||||||
|
let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
|
||||||
|
let hs = t.index_select(&ids, 0);
|
||||||
|
assert!(hs.is_err());
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// The test below triggers an unwinding panic as there is a panic within the
|
||||||
|
// #[cfg(feature = "cuda")]
|
||||||
|
// #[test]
|
||||||
|
// #[should_panic]
|
||||||
|
// fn index_select_fail_gpu() {
|
||||||
|
// // Check that a panic happens for out of bounds in cuda
|
||||||
|
// if let Ok(device) = Device::new_cuda(0) {
|
||||||
|
// if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
|
||||||
|
// if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
|
||||||
|
// let _ = t.index_select(&ids, 0);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
fn cmp(device: &Device) -> Result<()> {
|
fn cmp(device: &Device) -> Result<()> {
|
||||||
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
|
||||||
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
|
let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
|
||||||
@ -980,7 +1027,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scatter_add(device: &Device) -> Result<()> {
|
fn scatter(device: &Device) -> Result<()> {
|
||||||
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
t.to_vec2::<f32>()?,
|
t.to_vec2::<f32>()?,
|
||||||
@ -1004,6 +1051,17 @@ fn scatter_add(device: &Device) -> Result<()> {
|
|||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
|
||||||
|
let hs = init.scatter(&ids, &t, 1)?;
|
||||||
|
assert_eq!(
|
||||||
|
hs.to_vec2::<f32>()?,
|
||||||
|
&[
|
||||||
|
[0.0, 1.0, 2.0, 1.0, 1.0],
|
||||||
|
[5.0, 1.0, 1.0, 3.0, 4.0],
|
||||||
|
[1.0, 8.0, 1.0, 7.0, 1.0],
|
||||||
|
[10.0, 1.0, 9.0, 1.0, 11.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
let init = Tensor::ones((6, 3), DType::F32, device)?;
|
let init = Tensor::ones((6, 3), DType::F32, device)?;
|
||||||
let hs = init.scatter_add(&ids, &t, 0)?;
|
let hs = init.scatter_add(&ids, &t, 0)?;
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
@ -1017,6 +1075,30 @@ fn scatter_add(device: &Device) -> Result<()> {
|
|||||||
[1.0, 1.0, 1.0]
|
[1.0, 1.0, 1.0]
|
||||||
]
|
]
|
||||||
);
|
);
|
||||||
|
let hs = init.scatter(&ids, &t, 0)?;
|
||||||
|
assert_eq!(
|
||||||
|
hs.to_vec2::<f32>()?,
|
||||||
|
&[
|
||||||
|
[0.0, 10.0, 5.0],
|
||||||
|
[1.0, 1.0, 8.0],
|
||||||
|
[9.0, 1.0, 2.0],
|
||||||
|
[6.0, 7.0, 1.0],
|
||||||
|
[1.0, 4.0, 11.0],
|
||||||
|
[1.0, 1.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
|
init.scatter_set(&ids, &t, 0)?;
|
||||||
|
assert_eq!(
|
||||||
|
init.to_vec2::<f32>()?,
|
||||||
|
&[
|
||||||
|
[0.0, 10.0, 5.0],
|
||||||
|
[1.0, 1.0, 8.0],
|
||||||
|
[9.0, 1.0, 2.0],
|
||||||
|
[6.0, 7.0, 1.0],
|
||||||
|
[1.0, 4.0, 11.0],
|
||||||
|
[1.0, 1.0, 1.0]
|
||||||
|
]
|
||||||
|
);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1484,6 +1566,7 @@ fn zero_dim(device: &Device) -> Result<()> {
|
|||||||
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
|
||||||
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
|
||||||
test_device!(full, full_cpu, full_gpu, full_metal);
|
test_device!(full, full_cpu, full_gpu, full_metal);
|
||||||
|
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
|
||||||
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
|
||||||
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
|
||||||
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
|
||||||
@ -1515,12 +1598,7 @@ test_device!(
|
|||||||
);
|
);
|
||||||
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
|
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
|
||||||
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
|
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
|
||||||
test_device!(
|
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
|
||||||
scatter_add,
|
|
||||||
scatter_add_cpu,
|
|
||||||
scatter_add_gpu,
|
|
||||||
scatter_add_metal
|
|
||||||
);
|
|
||||||
test_device!(
|
test_device!(
|
||||||
slice_scatter,
|
slice_scatter,
|
||||||
slice_scatter_cpu,
|
slice_scatter_cpu,
|
||||||
|
@ -124,6 +124,17 @@ impl TextGeneration {
|
|||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
None => anyhow::bail!("cannot find the <eos> token"),
|
None => anyhow::bail!("cannot find the <eos> token"),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
|
||||||
|
Some(token) => token,
|
||||||
|
None => {
|
||||||
|
println!(
|
||||||
|
"Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
|
||||||
|
);
|
||||||
|
eos_token
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let start_gen = std::time::Instant::now();
|
let start_gen = std::time::Instant::now();
|
||||||
for index in 0..sample_len {
|
for index in 0..sample_len {
|
||||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||||
@ -146,7 +157,7 @@ impl TextGeneration {
|
|||||||
let next_token = self.logits_processor.sample(&logits)?;
|
let next_token = self.logits_processor.sample(&logits)?;
|
||||||
tokens.push(next_token);
|
tokens.push(next_token);
|
||||||
generated_tokens += 1;
|
generated_tokens += 1;
|
||||||
if next_token == eos_token {
|
if next_token == eos_token || next_token == eot_token {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||||
@ -350,6 +361,31 @@ fn main() -> Result<()> {
|
|||||||
args.repeat_last_n,
|
args.repeat_last_n,
|
||||||
&device,
|
&device,
|
||||||
);
|
);
|
||||||
pipeline.run(&args.prompt, args.sample_len)?;
|
|
||||||
|
let prompt = match args.which {
|
||||||
|
Which::Base2B
|
||||||
|
| Which::Base7B
|
||||||
|
| Which::Instruct2B
|
||||||
|
| Which::Instruct7B
|
||||||
|
| Which::InstructV1_1_2B
|
||||||
|
| Which::InstructV1_1_7B
|
||||||
|
| Which::CodeBase2B
|
||||||
|
| Which::CodeBase7B
|
||||||
|
| Which::CodeInstruct2B
|
||||||
|
| Which::CodeInstruct7B
|
||||||
|
| Which::BaseV2_2B
|
||||||
|
| Which::InstructV2_2B
|
||||||
|
| Which::BaseV2_9B
|
||||||
|
| Which::InstructV2_9B
|
||||||
|
| Which::BaseV3_1B => args.prompt,
|
||||||
|
Which::InstructV3_1B => {
|
||||||
|
format!(
|
||||||
|
"<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
|
||||||
|
args.prompt
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
pipeline.run(&prompt, args.sample_len)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
18
candle-examples/examples/quantized-gemma/README.md
Normal file
18
candle-examples/examples/quantized-gemma/README.md
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
# candle-quantized-gemma
|
||||||
|
|
||||||
|
Candle implementation of quantized Gemma.
|
||||||
|
|
||||||
|
## Running an example
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "
|
||||||
|
|
||||||
|
> ```python
|
||||||
|
> def fibonacci(n):
|
||||||
|
> """Calculates the nth Fibonacci number using recursion."""
|
||||||
|
> if n <= 1:
|
||||||
|
> return n
|
||||||
|
> else:
|
||||||
|
> return fibonacci(n-1) + fibonacci(n-2
|
||||||
|
> ```
|
||||||
|
```
|
344
candle-examples/examples/quantized-gemma/main.rs
Normal file
344
candle-examples/examples/quantized-gemma/main.rs
Normal file
@ -0,0 +1,344 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
use std::io::Write;
|
||||||
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
use candle::quantized::gguf_file;
|
||||||
|
use candle::Tensor;
|
||||||
|
use candle_transformers::generation::{LogitsProcessor, Sampling};
|
||||||
|
|
||||||
|
use candle_examples::token_output_stream::TokenOutputStream;
|
||||||
|
use candle_transformers::models::quantized_gemma3::ModelWeights;
|
||||||
|
|
||||||
|
const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "gemma3-4b-it")]
|
||||||
|
Gemma3_4bIt,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// GGUF file to load, typically a .gguf file generated by quantization
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
|
||||||
|
/// and 'chat' for an interactive model where history of previous prompts and generated tokens
|
||||||
|
/// is preserved.
|
||||||
|
#[arg(long)]
|
||||||
|
prompt: Option<String>,
|
||||||
|
|
||||||
|
/// The length of the sample to generate (in tokens).
|
||||||
|
#[arg(short = 'n', long, default_value_t = 1000)]
|
||||||
|
sample_len: usize,
|
||||||
|
|
||||||
|
/// The tokenizer config in json format.
|
||||||
|
#[arg(long)]
|
||||||
|
tokenizer: Option<String>,
|
||||||
|
|
||||||
|
/// The temperature used to generate samples, use 0 for greedy sampling.
|
||||||
|
#[arg(long, default_value_t = 0.8)]
|
||||||
|
temperature: f64,
|
||||||
|
|
||||||
|
/// Nucleus sampling probability cutoff.
|
||||||
|
#[arg(long)]
|
||||||
|
top_p: Option<f64>,
|
||||||
|
|
||||||
|
/// Only sample among the top K samples.
|
||||||
|
#[arg(long)]
|
||||||
|
top_k: Option<usize>,
|
||||||
|
|
||||||
|
/// The seed to use when generating random samples.
|
||||||
|
#[arg(long, default_value_t = 299792458)]
|
||||||
|
seed: u64,
|
||||||
|
|
||||||
|
/// Enable tracing (generates a trace-timestamp.json file).
|
||||||
|
#[arg(long)]
|
||||||
|
tracing: bool,
|
||||||
|
|
||||||
|
/// Process prompt elements separately.
|
||||||
|
#[arg(long)]
|
||||||
|
split_prompt: bool,
|
||||||
|
|
||||||
|
/// Run on CPU rather than GPU even if a GPU is available.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||||
|
#[arg(long, default_value_t = 1.1)]
|
||||||
|
repeat_penalty: f32,
|
||||||
|
|
||||||
|
/// The context size to consider for the repeat penalty.
|
||||||
|
#[arg(long, default_value_t = 64)]
|
||||||
|
repeat_last_n: usize,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "gemma3-4b-it")]
|
||||||
|
which: Which,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Args {
|
||||||
|
fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
|
||||||
|
let tokenizer_path = match &self.tokenizer {
|
||||||
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
|
None => {
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
let repo = "google/gemma-3-4b-it";
|
||||||
|
println!("DEBUG: Downloading tokenizer from {}", repo);
|
||||||
|
let api = api.model(repo.to_string());
|
||||||
|
api.get("tokenizer.json")?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;
|
||||||
|
|
||||||
|
Ok(tokenizer)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model(&self) -> anyhow::Result<std::path::PathBuf> {
|
||||||
|
let model_path = match &self.model {
|
||||||
|
Some(config) => std::path::PathBuf::from(config),
|
||||||
|
None => {
|
||||||
|
let (repo, filename) = match self.which {
|
||||||
|
Which::Gemma3_4bIt => (
|
||||||
|
"google/gemma-3-4b-it-qat-q4_0-gguf",
|
||||||
|
"gemma-3-4b-it-q4_0.gguf",
|
||||||
|
),
|
||||||
|
};
|
||||||
|
let api = hf_hub::api::sync::Api::new()?;
|
||||||
|
api.repo(hf_hub::Repo::with_revision(
|
||||||
|
repo.to_string(),
|
||||||
|
hf_hub::RepoType::Model,
|
||||||
|
"main".to_string(),
|
||||||
|
))
|
||||||
|
.get(filename)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Ok(model_path)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn format_size(size_in_bytes: usize) -> String {
|
||||||
|
if size_in_bytes < 1_000 {
|
||||||
|
format!("{}B", size_in_bytes)
|
||||||
|
} else if size_in_bytes < 1_000_000 {
|
||||||
|
format!("{:.2}KB", size_in_bytes as f64 / 1e3)
|
||||||
|
} else if size_in_bytes < 1_000_000_000 {
|
||||||
|
format!("{:.2}MB", size_in_bytes as f64 / 1e6)
|
||||||
|
} else {
|
||||||
|
format!("{:.2}GB", size_in_bytes as f64 / 1e9)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum Prompt {
|
||||||
|
Interactive,
|
||||||
|
Chat,
|
||||||
|
One(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> anyhow::Result<()> {
|
||||||
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
|
let args = Args::parse();
|
||||||
|
let _guard = if args.tracing {
|
||||||
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
|
Some(guard)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||||
|
candle::utils::with_avx(),
|
||||||
|
candle::utils::with_neon(),
|
||||||
|
candle::utils::with_simd128(),
|
||||||
|
candle::utils::with_f16c()
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||||
|
args.temperature, args.repeat_penalty, args.repeat_last_n
|
||||||
|
);
|
||||||
|
|
||||||
|
let model_path = args.model()?;
|
||||||
|
let mut file = std::fs::File::open(&model_path)?;
|
||||||
|
let start = std::time::Instant::now();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
|
||||||
|
let mut model = {
|
||||||
|
let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
|
||||||
|
let mut total_size_in_bytes = 0;
|
||||||
|
for (_, tensor) in model.tensor_infos.iter() {
|
||||||
|
let elem_count = tensor.shape.elem_count();
|
||||||
|
total_size_in_bytes +=
|
||||||
|
elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"loaded {:?} tensors ({}) in {:.2}s",
|
||||||
|
model.tensor_infos.len(),
|
||||||
|
&format_size(total_size_in_bytes),
|
||||||
|
start.elapsed().as_secs_f32(),
|
||||||
|
);
|
||||||
|
ModelWeights::from_gguf(model, &mut file, &device)?
|
||||||
|
};
|
||||||
|
println!("model built");
|
||||||
|
|
||||||
|
let tokenizer = args.tokenizer()?;
|
||||||
|
|
||||||
|
let mut tos = TokenOutputStream::new(tokenizer);
|
||||||
|
println!(
|
||||||
|
"DEBUG: Tokenizer vocabulary size: {}",
|
||||||
|
tos.tokenizer().get_vocab(true).len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let prompt = match args.prompt.as_deref() {
|
||||||
|
Some("chat") => Prompt::Chat,
|
||||||
|
Some("interactive") => Prompt::Interactive,
|
||||||
|
Some(s) => Prompt::One(s.to_string()),
|
||||||
|
None => Prompt::One(DEFAULT_PROMPT.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut pre_prompt_tokens = vec![];
|
||||||
|
for _ in 0.. {
|
||||||
|
let prompt_str = match &prompt {
|
||||||
|
Prompt::One(prompt) => prompt.clone(),
|
||||||
|
Prompt::Interactive | Prompt::Chat => {
|
||||||
|
print!("> ");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
let mut prompt = String::new();
|
||||||
|
std::io::stdin().read_line(&mut prompt)?;
|
||||||
|
if prompt.ends_with('\n') {
|
||||||
|
prompt.pop();
|
||||||
|
if prompt.ends_with('\r') {
|
||||||
|
prompt.pop();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Format for Gemma 3 chat/instruction format
|
||||||
|
format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
print!("{}", &prompt_str);
|
||||||
|
|
||||||
|
let tokens = tos
|
||||||
|
.tokenizer()
|
||||||
|
.encode(prompt_str, true)
|
||||||
|
.map_err(anyhow::Error::msg)?;
|
||||||
|
let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();
|
||||||
|
|
||||||
|
let to_sample = args.sample_len.saturating_sub(1);
|
||||||
|
let max_seq_len = 8192; // Gemma 3 context length
|
||||||
|
let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
|
||||||
|
let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
|
||||||
|
prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
|
||||||
|
} else {
|
||||||
|
prompt_tokens
|
||||||
|
};
|
||||||
|
let mut all_tokens = vec![];
|
||||||
|
let mut logits_processor = {
|
||||||
|
let temperature = args.temperature;
|
||||||
|
let sampling = if temperature <= 0. {
|
||||||
|
Sampling::ArgMax
|
||||||
|
} else {
|
||||||
|
match (args.top_k, args.top_p) {
|
||||||
|
(None, None) => Sampling::All { temperature },
|
||||||
|
(Some(k), None) => Sampling::TopK { k, temperature },
|
||||||
|
(None, Some(p)) => Sampling::TopP { p, temperature },
|
||||||
|
(Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
|
||||||
|
}
|
||||||
|
};
|
||||||
|
LogitsProcessor::from_sampling(args.seed, sampling)
|
||||||
|
};
|
||||||
|
|
||||||
|
let start_prompt_processing = std::time::Instant::now();
|
||||||
|
let mut next_token = if !args.split_prompt {
|
||||||
|
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, 0)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
logits_processor.sample(&logits)?
|
||||||
|
} else {
|
||||||
|
let mut next_token = 0;
|
||||||
|
for (pos, token) in prompt_tokens.iter().enumerate() {
|
||||||
|
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, pos)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
next_token = logits_processor.sample(&logits)?
|
||||||
|
}
|
||||||
|
next_token
|
||||||
|
};
|
||||||
|
let prompt_dt = start_prompt_processing.elapsed();
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if let Some(t) = tos.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For Gemma 3, use the correct end of sequence token
|
||||||
|
let eos_token = *tos
|
||||||
|
.tokenizer()
|
||||||
|
.get_vocab(true)
|
||||||
|
.get("<end_of_turn>")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let start_post_prompt = std::time::Instant::now();
|
||||||
|
let mut sampled = 0;
|
||||||
|
for index in 0..to_sample {
|
||||||
|
let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
|
||||||
|
let logits = model.forward(&input, prompt_tokens.len() + index)?;
|
||||||
|
let logits = logits.squeeze(0)?;
|
||||||
|
let logits = if args.repeat_penalty == 1. {
|
||||||
|
logits
|
||||||
|
} else {
|
||||||
|
let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
|
||||||
|
candle_transformers::utils::apply_repeat_penalty(
|
||||||
|
&logits,
|
||||||
|
args.repeat_penalty,
|
||||||
|
&all_tokens[start_at..],
|
||||||
|
)?
|
||||||
|
};
|
||||||
|
next_token = logits_processor.sample(&logits)?;
|
||||||
|
all_tokens.push(next_token);
|
||||||
|
if let Some(t) = tos.next_token(next_token)? {
|
||||||
|
print!("{t}");
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
}
|
||||||
|
sampled += 1;
|
||||||
|
if next_token == eos_token {
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
}
|
||||||
|
if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
|
||||||
|
print!("{rest}");
|
||||||
|
}
|
||||||
|
std::io::stdout().flush()?;
|
||||||
|
let dt = start_post_prompt.elapsed();
|
||||||
|
println!(
|
||||||
|
"\n\n{:4} prompt tokens processed: {:.2} token/s",
|
||||||
|
prompt_tokens.len(),
|
||||||
|
prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"{sampled:4} tokens generated: {:.2} token/s",
|
||||||
|
sampled as f64 / dt.as_secs_f64(),
|
||||||
|
);
|
||||||
|
|
||||||
|
match prompt {
|
||||||
|
Prompt::One(_) => break,
|
||||||
|
Prompt::Interactive => {}
|
||||||
|
Prompt::Chat => {
|
||||||
|
pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-flash-attn"
|
name = "candle-flash-attn"
|
||||||
version = "0.9.0-alpha.4"
|
version = "0.9.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Flash attention layer for the candle ML framework."
|
description = "Flash attention layer for the candle ML framework."
|
||||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
|
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0" }
|
||||||
half = { version = "2.3.1", features = ["num-traits"] }
|
half = { version = "2.3.1", features = ["num-traits"] }
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-kernels"
|
name = "candle-kernels"
|
||||||
version = "0.9.0-alpha.4"
|
version = "0.9.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "CUDA kernels for Candle"
|
description = "CUDA kernels for Candle"
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#include<stdint.h>
|
#include<stdint.h>
|
||||||
#include "cuda_fp16.h"
|
#include "cuda_fp16.h"
|
||||||
|
#include "cuda_utils.cuh"
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
__device__ void fill_with(T *buf, T value, const size_t numel) {
|
__device__ void fill_with(T *buf, T value, const size_t numel) {
|
||||||
@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
|
|||||||
COPY2D_OP(uint32_t, copy2d_u32)
|
COPY2D_OP(uint32_t, copy2d_u32)
|
||||||
COPY2D_OP(int64_t, copy2d_i64)
|
COPY2D_OP(int64_t, copy2d_i64)
|
||||||
|
|
||||||
|
#define CONST_SET_OP(TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const size_t numel, \
|
||||||
|
const size_t num_dims, \
|
||||||
|
const size_t *info, \
|
||||||
|
const TYPENAME inp, \
|
||||||
|
TYPENAME *out \
|
||||||
|
) { \
|
||||||
|
const size_t *dims = info; \
|
||||||
|
const size_t *strides = info + num_dims; \
|
||||||
|
if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
out[i] = inp; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
else { \
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
|
||||||
|
unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
|
||||||
|
out[strided_i] = inp; \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
} \
|
||||||
|
|
||||||
|
CONST_SET_OP(float, const_set_f32)
|
||||||
|
CONST_SET_OP(double, const_set_f64)
|
||||||
|
CONST_SET_OP(uint8_t, const_set_u8)
|
||||||
|
CONST_SET_OP(uint32_t, const_set_u32)
|
||||||
|
CONST_SET_OP(int64_t, const_set_i64)
|
||||||
|
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
|
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
|
||||||
COPY2D_OP(__half, copy2d_f16)
|
COPY2D_OP(__half, copy2d_f16)
|
||||||
|
CONST_SET_OP(__half, const_set_f16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 800
|
#if __CUDA_ARCH__ >= 800
|
||||||
#include <cuda_bf16.h>
|
#include <cuda_bf16.h>
|
||||||
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
|
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
|
||||||
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
|
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
|
||||||
|
CONST_SET_OP(__nv_bfloat16, const_set_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
@ -23,6 +23,7 @@ __device__ void index_select(
|
|||||||
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
unsigned int left_i = dst_i / (ids_dim_size * right_size);
|
||||||
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
unsigned int id_i = dst_i / right_size % ids_dim_size;
|
||||||
unsigned int right_i = dst_i % right_size;
|
unsigned int right_i = dst_i % right_size;
|
||||||
|
assert(ids[id_i] < src_dim_size);
|
||||||
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
|
||||||
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
|
||||||
out[dst_i] = inp[strided_i];
|
out[dst_i] = inp[strided_i];
|
||||||
@ -57,6 +58,7 @@ __device__ void gather(
|
|||||||
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||||
size_t post = i % right_size;
|
size_t post = i % right_size;
|
||||||
size_t idx = ids[i];
|
size_t idx = ids[i];
|
||||||
|
assert(idx < src_dim_size);
|
||||||
size_t pre = i / (right_size * ids_dim_size);
|
size_t pre = i / (right_size * ids_dim_size);
|
||||||
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
size_t src_i = (pre * src_dim_size + idx) * right_size + post;
|
||||||
out[i] = inp[src_i];
|
out[i] = inp[src_i];
|
||||||
@ -92,6 +94,7 @@ __device__ void index_add(
|
|||||||
const size_t post = i % right_size;
|
const size_t post = i % right_size;
|
||||||
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
for (unsigned int j = 0; j < ids_dim_size; ++j) {
|
||||||
const size_t idx = ids[j];
|
const size_t idx = ids[j];
|
||||||
|
assert(idx < dst_dim_size);
|
||||||
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
out[dst_i] += inp[src_i];
|
out[dst_i] += inp[src_i];
|
||||||
@ -111,6 +114,30 @@ extern "C" __global__ void FN_NAME( \
|
|||||||
const size_t right_size \
|
const size_t right_size \
|
||||||
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||||
|
|
||||||
|
template<typename T, typename I>
|
||||||
|
__device__ void scatter(
|
||||||
|
const I *ids,
|
||||||
|
const T *inp,
|
||||||
|
T *out,
|
||||||
|
const size_t left_size,
|
||||||
|
const size_t src_dim_size,
|
||||||
|
const size_t dst_dim_size,
|
||||||
|
const size_t right_size
|
||||||
|
) {
|
||||||
|
const size_t numel = left_size * right_size;
|
||||||
|
for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
|
||||||
|
const size_t pre = i / right_size;
|
||||||
|
const size_t post = i % right_size;
|
||||||
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
|
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||||
|
const size_t idx = ids[src_i];
|
||||||
|
assert(idx < dst_dim_size);
|
||||||
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
|
out[dst_i] = inp[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename T, typename I>
|
template<typename T, typename I>
|
||||||
__device__ void scatter_add(
|
__device__ void scatter_add(
|
||||||
const I *ids,
|
const I *ids,
|
||||||
@ -128,12 +155,24 @@ __device__ void scatter_add(
|
|||||||
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
const size_t src_i = (pre * src_dim_size + j) * right_size + post;
|
||||||
const size_t idx = ids[src_i];
|
const size_t idx = ids[src_i];
|
||||||
|
assert(idx < dst_dim_size);
|
||||||
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
|
||||||
out[dst_i] += inp[src_i];
|
out[dst_i] += inp[src_i];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
|
extern "C" __global__ void FN_NAME( \
|
||||||
|
const INDEX_TYPENAME *ids, \
|
||||||
|
const TYPENAME *inp, \
|
||||||
|
TYPENAME *out, \
|
||||||
|
const size_t left_size, \
|
||||||
|
const size_t src_dim_size, \
|
||||||
|
const size_t dst_dim_size, \
|
||||||
|
const size_t right_size \
|
||||||
|
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \
|
||||||
|
|
||||||
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
|
||||||
extern "C" __global__ void FN_NAME( \
|
extern "C" __global__ void FN_NAME( \
|
||||||
const INDEX_TYPENAME *ids, \
|
const INDEX_TYPENAME *ids, \
|
||||||
@ -159,6 +198,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
|
|||||||
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
|
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
|
||||||
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
|
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
|
||||||
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
|
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
|
||||||
|
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
|
||||||
|
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
|
||||||
|
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if __CUDA_ARCH__ >= 530
|
#if __CUDA_ARCH__ >= 530
|
||||||
@ -174,6 +216,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
|
|||||||
SA_OP(__half, int64_t, sa_i64_f16)
|
SA_OP(__half, int64_t, sa_i64_f16)
|
||||||
SA_OP(__half, uint32_t, sa_u32_f16)
|
SA_OP(__half, uint32_t, sa_u32_f16)
|
||||||
SA_OP(__half, uint8_t, sa_u8_f16)
|
SA_OP(__half, uint8_t, sa_u8_f16)
|
||||||
|
S_OP(__half, int64_t, s_i64_f16)
|
||||||
|
S_OP(__half, uint32_t, s_u32_f16)
|
||||||
|
S_OP(__half, uint8_t, s_u8_f16)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
IS_OP(float, int64_t, is_i64_f32)
|
IS_OP(float, int64_t, is_i64_f32)
|
||||||
@ -247,3 +292,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
|
|||||||
SA_OP(uint8_t, uint8_t, sa_u8_u8)
|
SA_OP(uint8_t, uint8_t, sa_u8_u8)
|
||||||
SA_OP(uint32_t, uint8_t, sa_u8_u32)
|
SA_OP(uint32_t, uint8_t, sa_u8_u32)
|
||||||
SA_OP(int64_t, uint8_t, sa_u8_i64)
|
SA_OP(int64_t, uint8_t, sa_u8_i64)
|
||||||
|
|
||||||
|
S_OP(float, int64_t, s_i64_f32)
|
||||||
|
S_OP(double, int64_t, s_i64_f64)
|
||||||
|
S_OP(uint8_t, int64_t, s_i64_u8)
|
||||||
|
S_OP(int64_t, int64_t, s_i64_i64)
|
||||||
|
S_OP(uint32_t, int64_t, s_i64_u32)
|
||||||
|
|
||||||
|
S_OP(float, uint32_t, s_u32_f32)
|
||||||
|
S_OP(double, uint32_t, s_u32_f64)
|
||||||
|
S_OP(uint8_t, uint32_t, s_u32_u8)
|
||||||
|
S_OP(int64_t, uint32_t, s_u32_i64)
|
||||||
|
S_OP(uint32_t, uint32_t, s_u32_u32)
|
||||||
|
|
||||||
|
S_OP(float, uint8_t, s_u8_f32)
|
||||||
|
S_OP(double, uint8_t, s_u8_f64)
|
||||||
|
S_OP(uint8_t, uint8_t, s_u8_u8)
|
||||||
|
S_OP(uint32_t, uint8_t, s_u8_u32)
|
||||||
|
S_OP(int64_t, uint8_t, s_u8_i64)
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-metal-kernels"
|
name = "candle-metal-kernels"
|
||||||
version = "0.9.0-alpha.4"
|
version = "0.9.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "Metal kernels for Candle"
|
description = "Metal kernels for Candle"
|
||||||
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
metal = { version = "0.27.0", features = ["mps"] }
|
metal = { version = "0.27.0", features = ["mps"] }
|
||||||
|
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
|
||||||
once_cell = "1.18.0"
|
once_cell = "1.18.0"
|
||||||
thiserror = "1"
|
thiserror = "1"
|
||||||
tracing = "0.1.37"
|
tracing = "0.1.37"
|
||||||
|
@ -4,20 +4,20 @@ using namespace metal;
|
|||||||
|
|
||||||
template<typename T> METAL_FUNC void fill_with(
|
template<typename T> METAL_FUNC void fill_with(
|
||||||
device T *out,
|
device T *out,
|
||||||
constant float &value,
|
constant T &value,
|
||||||
constant size_t &numel,
|
constant size_t &numel,
|
||||||
uint tid [[thread_position_in_grid]]
|
uint tid [[thread_position_in_grid]]
|
||||||
) {
|
) {
|
||||||
if (tid >= numel) {
|
if (tid >= numel) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
out[tid] = static_cast<T>(value);
|
out[tid] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
#define FILL_OP(NAME, T) \
|
#define FILL_OP(NAME, T) \
|
||||||
kernel void fill_##NAME( \
|
kernel void fill_##NAME( \
|
||||||
device T *out, \
|
device T *out, \
|
||||||
constant float &value, \
|
constant T &value, \
|
||||||
constant size_t &numel, \
|
constant size_t &numel, \
|
||||||
uint tid [[thread_position_in_grid]] \
|
uint tid [[thread_position_in_grid]] \
|
||||||
) { \
|
) { \
|
||||||
|
@ -104,6 +104,31 @@ kernel void NAME( \
|
|||||||
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
|
METAL_FUNC void scatter(
|
||||||
|
constant size_t &dst_size,
|
||||||
|
constant size_t &left_size,
|
||||||
|
constant size_t &src_dim_size,
|
||||||
|
constant size_t &right_size,
|
||||||
|
constant size_t &dst_dim_size,
|
||||||
|
const device TYPENAME *input,
|
||||||
|
const device INDEX_TYPENAME *input_ids,
|
||||||
|
device TYPENAME *output,
|
||||||
|
uint tid [[ thread_position_in_grid ]]
|
||||||
|
) {
|
||||||
|
if (tid >= dst_size) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const size_t right_rank_i = tid % right_size;
|
||||||
|
const size_t left_rank_i = tid / right_size;
|
||||||
|
for (unsigned int j = 0; j < src_dim_size; ++j) {
|
||||||
|
const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
|
||||||
|
const INDEX_TYPENAME idx = input_ids[src_i];
|
||||||
|
const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
|
||||||
|
output[dst_i] = input[src_i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template<typename TYPENAME, typename INDEX_TYPENAME>
|
template<typename TYPENAME, typename INDEX_TYPENAME>
|
||||||
METAL_FUNC void scatter_add(
|
METAL_FUNC void scatter_add(
|
||||||
constant size_t &dst_size,
|
constant size_t &dst_size,
|
||||||
@ -129,6 +154,21 @@ METAL_FUNC void scatter_add(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
|
kernel void NAME( \
|
||||||
|
constant size_t &dst_size, \
|
||||||
|
constant size_t &left_size, \
|
||||||
|
constant size_t &src_dim_size, \
|
||||||
|
constant size_t &right_size, \
|
||||||
|
constant size_t &dst_dim_size, \
|
||||||
|
const device TYPENAME *input, \
|
||||||
|
const device INDEX_TYPENAME *input_ids, \
|
||||||
|
device TYPENAME *output, \
|
||||||
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
|
||||||
|
}
|
||||||
|
|
||||||
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
|
||||||
kernel void NAME( \
|
kernel void NAME( \
|
||||||
constant size_t &dst_size, \
|
constant size_t &dst_size, \
|
||||||
@ -235,6 +275,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
|||||||
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
SCATTER_OP(s_u32_f32, uint32_t, float)
|
||||||
|
SCATTER_OP(s_u8_f32, uint8_t, float)
|
||||||
|
SCATTER_OP(s_i64_f32, int64_t, float)
|
||||||
|
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
|
||||||
|
SCATTER_OP(s_u32_f16, uint32_t, half)
|
||||||
|
SCATTER_OP(s_u8_f16, uint8_t, half)
|
||||||
|
SCATTER_OP(s_i64_f16, int64_t, half)
|
||||||
|
#if defined(__HAVE_BFLOAT__)
|
||||||
|
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
|
||||||
|
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
|
||||||
|
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
|
||||||
|
#endif
|
||||||
|
|
||||||
// i64
|
// i64
|
||||||
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
|
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
|
||||||
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
INDEX_ADD_OP(ia_i64_f32, int64_t, float)
|
||||||
|
@ -161,7 +161,7 @@ macro_rules! ops{
|
|||||||
pub mod unary {
|
pub mod unary {
|
||||||
ops!(
|
ops!(
|
||||||
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
|
||||||
tanh, recip, silu, sign, sigmoid
|
tanh, recip, silu, sign, sigmoid, const_set
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
pub mod binary {
|
pub mod binary {
|
||||||
@ -419,6 +419,82 @@ pub fn call_copy2d(
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_const_set_contiguous_tiled(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
kernel_name: unary::contiguous_tiled::Kernel,
|
||||||
|
length: usize,
|
||||||
|
input: impl EncoderParam,
|
||||||
|
output: BufferOffset,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
|
let encoder = ep.encoder();
|
||||||
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
let tile_size = 2;
|
||||||
|
let tiles = length.div_ceil(tile_size);
|
||||||
|
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(encoder, (length, input, &output));
|
||||||
|
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
|
||||||
|
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_const_set_contiguous(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
kernel_name: unary::contiguous::Kernel,
|
||||||
|
length: usize,
|
||||||
|
input: impl EncoderParam,
|
||||||
|
output: BufferOffset,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
|
||||||
|
let encoder = ep.encoder();
|
||||||
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
|
||||||
|
set_params!(encoder, (length, input, &output));
|
||||||
|
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
pub fn call_const_set_strided(
|
||||||
|
device: &Device,
|
||||||
|
ep: impl EncoderProvider,
|
||||||
|
kernels: &Kernels,
|
||||||
|
name: unary::strided::Kernel,
|
||||||
|
shape: &[usize],
|
||||||
|
input: impl EncoderParam,
|
||||||
|
strides: &[usize],
|
||||||
|
output: BufferOffset,
|
||||||
|
) -> Result<(), MetalKernelError> {
|
||||||
|
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||||
|
|
||||||
|
let length: usize = shape.iter().product();
|
||||||
|
let num_dims: usize = shape.len();
|
||||||
|
let encoder = ep.encoder();
|
||||||
|
let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
|
||||||
|
let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
|
||||||
|
|
||||||
|
encoder.set_compute_pipeline_state(&pipeline);
|
||||||
|
set_params!(encoder, (length, num_dims, shape, strides, input, &output));
|
||||||
|
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||||
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_unary_contiguous_tiled(
|
pub fn call_unary_contiguous_tiled(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
@ -1371,7 +1447,7 @@ pub fn call_gather(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn call_scatter_add(
|
pub fn call_scatter(
|
||||||
device: &Device,
|
device: &Device,
|
||||||
ep: impl EncoderProvider,
|
ep: impl EncoderProvider,
|
||||||
kernels: &Kernels,
|
kernels: &Kernels,
|
||||||
@ -1381,7 +1457,7 @@ pub fn call_scatter_add(
|
|||||||
dim: usize,
|
dim: usize,
|
||||||
input: BufferOffset,
|
input: BufferOffset,
|
||||||
ids: BufferOffset,
|
ids: BufferOffset,
|
||||||
output: &Buffer,
|
output: BufferOffset,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let left_size: usize = src_shape[..dim].iter().product();
|
let left_size: usize = src_shape[..dim].iter().product();
|
||||||
let right_size: usize = src_shape[dim + 1..].iter().product();
|
let right_size: usize = src_shape[dim + 1..].iter().product();
|
||||||
@ -1406,7 +1482,7 @@ pub fn call_scatter_add(
|
|||||||
dst_dim_size,
|
dst_dim_size,
|
||||||
&input,
|
&input,
|
||||||
&ids,
|
&ids,
|
||||||
output
|
&output
|
||||||
)
|
)
|
||||||
);
|
);
|
||||||
|
|
||||||
@ -1414,7 +1490,7 @@ pub fn call_scatter_add(
|
|||||||
|
|
||||||
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
|
||||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
|
||||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@ -2570,7 +2646,7 @@ pub fn call_const_fill(
|
|||||||
name: &'static str,
|
name: &'static str,
|
||||||
length: usize,
|
length: usize,
|
||||||
output: &Buffer,
|
output: &Buffer,
|
||||||
v: f32,
|
v: impl EncoderParam,
|
||||||
) -> Result<(), MetalKernelError> {
|
) -> Result<(), MetalKernelError> {
|
||||||
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
|
||||||
let encoder = ep.encoder();
|
let encoder = ep.encoder();
|
||||||
|
@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
|
|||||||
let input_buffer = new_buffer(&device, input);
|
let input_buffer = new_buffer(&device, input);
|
||||||
let ids_buffer = new_buffer(&device, ids);
|
let ids_buffer = new_buffer(&device, ids);
|
||||||
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
|
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
|
||||||
call_scatter_add(
|
call_scatter(
|
||||||
&device,
|
&device,
|
||||||
command_buffer,
|
command_buffer,
|
||||||
&kernels,
|
&kernels,
|
||||||
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn const_fill() {
|
fn const_fill() {
|
||||||
fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
|
fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
|
||||||
let dev = device();
|
let dev = device();
|
||||||
let kernels = Kernels::new();
|
let kernels = Kernels::new();
|
||||||
let command_queue = dev.new_command_queue();
|
let command_queue = dev.new_command_queue();
|
||||||
@ -2357,11 +2357,15 @@ fn const_fill() {
|
|||||||
command_buffer.wait_until_completed();
|
command_buffer.wait_until_completed();
|
||||||
read_to_vec::<T>(&buffer, len)
|
read_to_vec::<T>(&buffer, len)
|
||||||
}
|
}
|
||||||
fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
|
fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
|
||||||
|
name: &'static str,
|
||||||
|
f: F,
|
||||||
|
) {
|
||||||
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
|
||||||
let value = rand::thread_rng().gen_range(1. ..19.);
|
let value = rand::thread_rng().gen_range(1. ..19.);
|
||||||
|
let value = f(value);
|
||||||
let v = constant_fill::<T>(name, len, value);
|
let v = constant_fill::<T>(name, len, value);
|
||||||
assert_eq!(v, vec![f(value); len])
|
assert_eq!(v, vec![value; len])
|
||||||
}
|
}
|
||||||
test::<u8, _>("fill_u8", |v| v as u8);
|
test::<u8, _>("fill_u8", |v| v as u8);
|
||||||
test::<u32, _>("fill_u32", |v| v as u32);
|
test::<u32, _>("fill_u32", |v| v as u32);
|
||||||
|
@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {
|
|||||||
|
|
||||||
#define TILE_SIZE 2
|
#define TILE_SIZE 2
|
||||||
|
|
||||||
|
#define CONST_SET(TYPENAME, FN_NAME) \
|
||||||
|
kernel void FN_NAME( \
|
||||||
|
constant size_t &dim, \
|
||||||
|
constant TYPENAME &input, \
|
||||||
|
device TYPENAME *output, \
|
||||||
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
if (tid >= dim) { \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
|
output[tid] = input; \
|
||||||
|
} \
|
||||||
|
kernel void FN_NAME##_##strided( \
|
||||||
|
constant size_t &dim, \
|
||||||
|
constant size_t &num_dims, \
|
||||||
|
constant size_t *dims, \
|
||||||
|
constant size_t *strides, \
|
||||||
|
constant TYPENAME &input, \
|
||||||
|
device TYPENAME *output, \
|
||||||
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
if (tid >= dim) { \
|
||||||
|
return; \
|
||||||
|
} \
|
||||||
|
output[get_strided_index(tid, num_dims, dims, strides)] = input; \
|
||||||
|
} \
|
||||||
|
kernel void FN_NAME##_##tiled( \
|
||||||
|
constant size_t &dim, \
|
||||||
|
constant TYPENAME &input, \
|
||||||
|
device TYPENAME *output, \
|
||||||
|
uint tid [[ thread_position_in_grid ]] \
|
||||||
|
) { \
|
||||||
|
for (uint i = 0; i < TILE_SIZE; i++) { \
|
||||||
|
const uint idx = tid * TILE_SIZE + i; \
|
||||||
|
output[idx] = input; \
|
||||||
|
} \
|
||||||
|
}
|
||||||
|
|
||||||
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
|
||||||
kernel void FN_NAME( \
|
kernel void FN_NAME( \
|
||||||
constant size_t &dim, \
|
constant size_t &dim, \
|
||||||
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
|
|||||||
COPY2D(copy2d_u8, uint8_t)
|
COPY2D(copy2d_u8, uint8_t)
|
||||||
COPY2D(copy2d_u32, uint32_t)
|
COPY2D(copy2d_u32, uint32_t)
|
||||||
|
|
||||||
|
CONST_SET(float, const_set_f32)
|
||||||
|
CONST_SET(half, const_set_f16)
|
||||||
|
CONST_SET(uint8_t, const_set_u8)
|
||||||
|
CONST_SET(uint32_t, const_set_u32)
|
||||||
|
|
||||||
UNARY_OP(cos)
|
UNARY_OP(cos)
|
||||||
UNARY_OP(sin)
|
UNARY_OP(sin)
|
||||||
UNARY_OP(sqr)
|
UNARY_OP(sqr)
|
||||||
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
|
|||||||
#if __METAL_VERSION__ >= 220
|
#if __METAL_VERSION__ >= 220
|
||||||
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
UNARY(id, int64_t, copy_i64, copy_i64_strided)
|
||||||
COPY2D(copy2d_i64, int64_t)
|
COPY2D(copy2d_i64, int64_t)
|
||||||
|
CONST_SET(int64_t, const_set_i64)
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(__HAVE_BFLOAT__)
|
#if defined(__HAVE_BFLOAT__)
|
||||||
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
|
|||||||
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
|
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);
|
||||||
|
|
||||||
COPY2D(copy2d_bf16, bfloat)
|
COPY2D(copy2d_bf16, bfloat)
|
||||||
|
CONST_SET(bfloat, const_set_bf16)
|
||||||
#endif
|
#endif
|
||||||
|
@ -88,9 +88,13 @@ primitive!(bool);
|
|||||||
primitive!(usize);
|
primitive!(usize);
|
||||||
primitive!(i32);
|
primitive!(i32);
|
||||||
primitive!(i64);
|
primitive!(i64);
|
||||||
|
primitive!(u8);
|
||||||
primitive!(u32);
|
primitive!(u32);
|
||||||
primitive!(u64);
|
primitive!(u64);
|
||||||
primitive!(f32);
|
primitive!(f32);
|
||||||
|
primitive!(f64);
|
||||||
|
primitive!(half::bf16);
|
||||||
|
primitive!(half::f16);
|
||||||
|
|
||||||
pub struct BufferOffset<'a> {
|
pub struct BufferOffset<'a> {
|
||||||
pub buffer: &'a Buffer,
|
pub buffer: &'a Buffer,
|
||||||
|
@ -71,6 +71,8 @@ impl candle::Module for PReLU {
|
|||||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
let weight = if self.is_scalar {
|
let weight = if self.is_scalar {
|
||||||
self.weight.reshape(())?
|
self.weight.reshape(())?
|
||||||
|
} else if xs.shape() == self.weight.shape() {
|
||||||
|
self.weight.clone()
|
||||||
} else if xs.rank() >= 2 {
|
} else if xs.rank() >= 2 {
|
||||||
let num_channels = xs.dim(1)?;
|
let num_channels = xs.dim(1)?;
|
||||||
let num_weights = self.weight.elem_count();
|
let num_weights = self.weight.elem_count();
|
||||||
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
|
|||||||
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
|
candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
|
||||||
}
|
}
|
||||||
let mut s = vec![1; xs.rank()];
|
let mut s = vec![1; xs.rank()];
|
||||||
s[1] = self.weight.elem_count();
|
s[1] = num_weights;
|
||||||
self.weight.reshape(s)?
|
self.weight.reshape(s)?
|
||||||
} else {
|
} else {
|
||||||
self.weight.clone()
|
self.weight.clone()
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "candle-onnx"
|
name = "candle-onnx"
|
||||||
version = "0.9.0-alpha.4"
|
version = "0.9.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
description = "ONNX support for Candle"
|
description = "ONNX support for Candle"
|
||||||
@ -10,8 +10,8 @@ categories = ["science"]
|
|||||||
license = "MIT OR Apache-2.0"
|
license = "MIT OR Apache-2.0"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
|
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0" }
|
||||||
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
|
candle-nn = { path = "../candle-nn", version = "0.9.0" }
|
||||||
prost = "0.12.1"
|
prost = "0.12.1"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
|
@ -1,7 +1,9 @@
|
|||||||
use crate::onnx::attribute_proto::AttributeType;
|
use crate::onnx::attribute_proto::AttributeType;
|
||||||
use crate::onnx::tensor_proto::DataType;
|
use crate::onnx::tensor_proto::DataType;
|
||||||
use crate::onnx::{self, GraphProto};
|
use crate::onnx::{self, GraphProto};
|
||||||
|
use candle::Module;
|
||||||
use candle::{bail, DType, Device, Result, Tensor};
|
use candle::{bail, DType, Device, Result, Tensor};
|
||||||
|
use candle_nn::activation::PReLU;
|
||||||
use std::collections::{HashMap, HashSet};
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
pub type Value = Tensor;
|
pub type Value = Tensor;
|
||||||
@ -991,6 +993,14 @@ fn simple_eval_(
|
|||||||
let output = input.relu()?;
|
let output = input.relu()?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"PRelu" => {
|
||||||
|
// https://onnx.ai/onnx/operators/onnx__PRelu.html
|
||||||
|
let input = get(&node.input[0])?;
|
||||||
|
let slope = get(&node.input[1])?;
|
||||||
|
|
||||||
|
let output = PReLU::new(slope.clone(), false).forward(input)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
"Ceil" => {
|
"Ceil" => {
|
||||||
let input = get(&node.input[0])?;
|
let input = get(&node.input[0])?;
|
||||||
let output = input.ceil()?;
|
let output = input.ceil()?;
|
||||||
|
@ -1846,6 +1846,64 @@ fn test_relu_operation() -> Result<()> {
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// "PRelu"
|
||||||
|
#[test]
|
||||||
|
fn test_prelu_operation() -> Result<()> {
|
||||||
|
let manual_graph = create_model_proto_with_graph(Some(GraphProto {
|
||||||
|
node: vec![NodeProto {
|
||||||
|
op_type: "PRelu".to_string(),
|
||||||
|
domain: "".to_string(),
|
||||||
|
attribute: vec![],
|
||||||
|
input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
|
||||||
|
output: vec![OUTPUT_Z.to_string()],
|
||||||
|
name: "".to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
}],
|
||||||
|
name: "".to_string(),
|
||||||
|
initializer: vec![],
|
||||||
|
input: vec![
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_X.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
ValueInfoProto {
|
||||||
|
name: INPUT_Y.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
output: vec![ValueInfoProto {
|
||||||
|
name: OUTPUT_Z.to_string(),
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
r#type: None,
|
||||||
|
}],
|
||||||
|
value_info: vec![],
|
||||||
|
doc_string: "".to_string(),
|
||||||
|
sparse_initializer: vec![],
|
||||||
|
quantization_annotation: vec![],
|
||||||
|
}));
|
||||||
|
let x: Tensor = Tensor::from_vec(
|
||||||
|
vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
|
||||||
|
&[2, 2],
|
||||||
|
&Device::Cpu,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;
|
||||||
|
|
||||||
|
let mut inputs: HashMap<String, Tensor> = HashMap::new();
|
||||||
|
inputs.insert(INPUT_X.to_string(), x);
|
||||||
|
inputs.insert(INPUT_Y.to_string(), y);
|
||||||
|
|
||||||
|
let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
|
||||||
|
assert_eq!(eval.len(), 1);
|
||||||
|
|
||||||
|
let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
|
||||||
|
let results = z.to_vec2::<f32>()?;
|
||||||
|
assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
// "Constant"
|
// "Constant"
|
||||||
// #[test]
|
// #[test]
|
||||||
|
|
||||||
|
@ -21,6 +21,7 @@ pub struct Config {
|
|||||||
pub num_key_value_heads: usize,
|
pub num_key_value_heads: usize,
|
||||||
pub rms_norm_eps: f64,
|
pub rms_norm_eps: f64,
|
||||||
pub rope_theta: f64,
|
pub rope_theta: f64,
|
||||||
|
pub rope_local_base_freq: f64,
|
||||||
pub vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
pub final_logit_softcapping: Option<f64>,
|
pub final_logit_softcapping: Option<f64>,
|
||||||
pub attn_logit_softcapping: Option<f64>,
|
pub attn_logit_softcapping: Option<f64>,
|
||||||
@ -67,12 +68,22 @@ struct RotaryEmbedding {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl RotaryEmbedding {
|
impl RotaryEmbedding {
|
||||||
fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
|
fn new(
|
||||||
|
dtype: DType,
|
||||||
|
cfg: &Config,
|
||||||
|
dev: &Device,
|
||||||
|
sliding_window: Option<usize>,
|
||||||
|
) -> Result<Self> {
|
||||||
let dim = cfg.head_dim;
|
let dim = cfg.head_dim;
|
||||||
let max_seq_len = cfg.max_position_embeddings;
|
let max_seq_len = cfg.max_position_embeddings;
|
||||||
|
let rope_freq = if sliding_window.is_some() {
|
||||||
|
cfg.rope_local_base_freq
|
||||||
|
} else {
|
||||||
|
cfg.rope_theta
|
||||||
|
};
|
||||||
let inv_freq: Vec<_> = (0..dim)
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
.step_by(2)
|
.step_by(2)
|
||||||
.map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
|
.map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
|
||||||
.collect();
|
.collect();
|
||||||
let inv_freq_len = inv_freq.len();
|
let inv_freq_len = inv_freq.len();
|
||||||
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
|
||||||
@ -162,8 +173,8 @@ impl Attention {
|
|||||||
fn new(
|
fn new(
|
||||||
rotary_emb: Arc<RotaryEmbedding>,
|
rotary_emb: Arc<RotaryEmbedding>,
|
||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
is_sliding: bool,
|
|
||||||
cfg: &Config,
|
cfg: &Config,
|
||||||
|
sliding_window: Option<usize>,
|
||||||
vb: VarBuilder,
|
vb: VarBuilder,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
let hidden_sz = cfg.hidden_size;
|
let hidden_sz = cfg.hidden_size;
|
||||||
@ -178,13 +189,13 @@ impl Attention {
|
|||||||
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
|
||||||
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
|
let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
|
||||||
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
|
let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
|
||||||
let kv_cache = if is_sliding {
|
let kv_cache = if let Some(sliding_window) = sliding_window {
|
||||||
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
|
KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
|
||||||
2,
|
|
||||||
cfg.sliding_window,
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
|
KvCache::Normal(candle_nn::kv_cache::KvCache::new(
|
||||||
|
2,
|
||||||
|
cfg.max_position_embeddings,
|
||||||
|
))
|
||||||
};
|
};
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
q_proj,
|
q_proj,
|
||||||
@ -302,21 +313,27 @@ struct DecoderLayer {
|
|||||||
pre_feedforward_layernorm: RmsNorm,
|
pre_feedforward_layernorm: RmsNorm,
|
||||||
post_feedforward_layernorm: RmsNorm,
|
post_feedforward_layernorm: RmsNorm,
|
||||||
post_attention_layernorm: RmsNorm,
|
post_attention_layernorm: RmsNorm,
|
||||||
|
sliding_window: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DecoderLayer {
|
impl DecoderLayer {
|
||||||
fn new(
|
fn new(
|
||||||
rotary_emb: Arc<RotaryEmbedding>,
|
|
||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
is_sliding: bool,
|
|
||||||
cfg: &Config,
|
cfg: &Config,
|
||||||
vb: VarBuilder,
|
vb: VarBuilder,
|
||||||
|
sliding_window: Option<usize>,
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
|
let rotary_emb = Arc::new(RotaryEmbedding::new(
|
||||||
|
vb.dtype(),
|
||||||
|
cfg,
|
||||||
|
vb.device(),
|
||||||
|
sliding_window,
|
||||||
|
)?);
|
||||||
let self_attn = Attention::new(
|
let self_attn = Attention::new(
|
||||||
rotary_emb,
|
rotary_emb,
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
is_sliding,
|
|
||||||
cfg,
|
cfg,
|
||||||
|
sliding_window,
|
||||||
vb.pp("self_attn"),
|
vb.pp("self_attn"),
|
||||||
)?;
|
)?;
|
||||||
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
let mlp = MLP::new(cfg, vb.pp("mlp"))?;
|
||||||
@ -344,6 +361,7 @@ impl DecoderLayer {
|
|||||||
pre_feedforward_layernorm,
|
pre_feedforward_layernorm,
|
||||||
post_feedforward_layernorm,
|
post_feedforward_layernorm,
|
||||||
post_attention_layernorm,
|
post_attention_layernorm,
|
||||||
|
sliding_window,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -370,6 +388,42 @@ impl DecoderLayer {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn prepare_decoder_attention_mask(
|
||||||
|
b_size: usize,
|
||||||
|
tgt_len: usize,
|
||||||
|
seqlen_offset: usize,
|
||||||
|
sliding_window: Option<usize>,
|
||||||
|
dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
|
||||||
|
(0..tgt_len)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..tgt_len).map(move |j| {
|
||||||
|
if i < j || j + sliding_window < i {
|
||||||
|
f32::NEG_INFINITY
|
||||||
|
} else {
|
||||||
|
0.
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
(0..tgt_len)
|
||||||
|
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
|
||||||
|
let mask = if seqlen_offset > 0 {
|
||||||
|
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
||||||
|
.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Model {
|
pub struct Model {
|
||||||
embed_tokens: candle_nn::Embedding,
|
embed_tokens: candle_nn::Embedding,
|
||||||
@ -388,17 +442,15 @@ impl Model {
|
|||||||
let vb_m = vb.pp("model");
|
let vb_m = vb.pp("model");
|
||||||
let embed_tokens =
|
let embed_tokens =
|
||||||
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
|
||||||
let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
|
|
||||||
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
|
||||||
let vb_l = vb_m.pp("layers");
|
let vb_l = vb_m.pp("layers");
|
||||||
for layer_idx in 0..cfg.num_hidden_layers {
|
for layer_idx in 0..cfg.num_hidden_layers {
|
||||||
let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
|
let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
|
||||||
let layer = DecoderLayer::new(
|
let layer = DecoderLayer::new(
|
||||||
rotary_emb.clone(),
|
|
||||||
use_flash_attn,
|
use_flash_attn,
|
||||||
is_sliding,
|
|
||||||
cfg,
|
cfg,
|
||||||
vb_l.pp(layer_idx),
|
vb_l.pp(layer_idx),
|
||||||
|
sliding_window.then_some(cfg.sliding_window),
|
||||||
)?;
|
)?;
|
||||||
layers.push(layer)
|
layers.push(layer)
|
||||||
}
|
}
|
||||||
@ -417,51 +469,52 @@ impl Model {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepare_decoder_attention_mask(
|
fn create_attention_masks(
|
||||||
&self,
|
&self,
|
||||||
b_size: usize,
|
batch_size: usize,
|
||||||
tgt_len: usize,
|
seq_len: usize,
|
||||||
seqlen_offset: usize,
|
seqlen_offset: usize,
|
||||||
) -> Result<Tensor> {
|
) -> Result<(Option<Tensor>, Option<Tensor>)> {
|
||||||
let mask: Vec<_> = match Some(self.sliding_window) {
|
if seq_len <= 1 {
|
||||||
None => (0..tgt_len)
|
return Ok((None, None));
|
||||||
.flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
|
}
|
||||||
.collect(),
|
|
||||||
Some(sliding_window) => (0..tgt_len)
|
let mask = prepare_decoder_attention_mask(
|
||||||
.flat_map(|i| {
|
batch_size,
|
||||||
(0..tgt_len).map(move |j| {
|
seq_len,
|
||||||
if i < j || j + sliding_window < i {
|
seqlen_offset,
|
||||||
f32::NEG_INFINITY
|
None,
|
||||||
} else {
|
self.dtype,
|
||||||
0.
|
&self.device,
|
||||||
}
|
)?;
|
||||||
})
|
|
||||||
})
|
let sliding_mask = prepare_decoder_attention_mask(
|
||||||
.collect(),
|
batch_size,
|
||||||
};
|
seq_len,
|
||||||
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
|
seqlen_offset,
|
||||||
let mask = if seqlen_offset > 0 {
|
Some(self.sliding_window),
|
||||||
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
|
self.dtype,
|
||||||
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
&self.device,
|
||||||
} else {
|
)?;
|
||||||
mask
|
|
||||||
};
|
Ok((Some(mask), Some(sliding_mask)))
|
||||||
mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
|
|
||||||
.to_dtype(self.dtype)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
|
||||||
let (b_size, seq_len) = input_ids.dims2()?;
|
let (b_size, seq_len) = input_ids.dims2()?;
|
||||||
let attention_mask = if seq_len <= 1 {
|
|
||||||
None
|
|
||||||
} else {
|
|
||||||
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
|
|
||||||
Some(mask)
|
|
||||||
};
|
|
||||||
let xs = self.embed_tokens.forward(input_ids)?;
|
let xs = self.embed_tokens.forward(input_ids)?;
|
||||||
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
|
||||||
|
|
||||||
|
let (attention_mask, sliding_attention_mask) =
|
||||||
|
self.create_attention_masks(b_size, seq_len, seqlen_offset)?;
|
||||||
|
|
||||||
for layer in self.layers.iter_mut() {
|
for layer in self.layers.iter_mut() {
|
||||||
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
|
let mask = if layer.sliding_window.is_some() {
|
||||||
|
&sliding_attention_mask
|
||||||
|
} else {
|
||||||
|
&attention_mask
|
||||||
|
};
|
||||||
|
xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
|
||||||
}
|
}
|
||||||
let logits = xs
|
let logits = xs
|
||||||
.narrow(1, seq_len - 1, 1)?
|
.narrow(1, seq_len - 1, 1)?
|
||||||
|
@ -79,6 +79,7 @@ pub mod phi3;
|
|||||||
pub mod pixtral;
|
pub mod pixtral;
|
||||||
pub mod quantized_blip;
|
pub mod quantized_blip;
|
||||||
pub mod quantized_blip_text;
|
pub mod quantized_blip_text;
|
||||||
|
pub mod quantized_gemma3;
|
||||||
pub mod quantized_llama;
|
pub mod quantized_llama;
|
||||||
pub mod quantized_llama2_c;
|
pub mod quantized_llama2_c;
|
||||||
pub mod quantized_metavoice;
|
pub mod quantized_metavoice;
|
||||||
|
466
candle-transformers/src/models/quantized_gemma3.rs
Normal file
466
candle-transformers/src/models/quantized_gemma3.rs
Normal file
@ -0,0 +1,466 @@
|
|||||||
|
//! Gemma 3 model implementation with quantization support.
|
||||||
|
//!
|
||||||
|
//! Gemma 3 is a family of multimodal language models developed by Google.
|
||||||
|
//! This implementation provides quantization for reduced memory usage and faster inference.
|
||||||
|
//!
|
||||||
|
//! Key characteristics:
|
||||||
|
//! - Group-Query Attention (GQA) with specialized key-value heads
|
||||||
|
//! - RMSNorm for layer normalization
|
||||||
|
//! - Specialized attention patterns with separate normalization for Q/K/V
|
||||||
|
//! - Feed-forward network with SwiGLU activation
|
||||||
|
//! - Support for 2/3/4/8-bit quantization
|
||||||
|
//!
|
||||||
|
//! References:
|
||||||
|
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
|
||||||
|
//!
|
||||||
|
|
||||||
|
use crate::quantized_nn::RmsNorm;
|
||||||
|
use candle::quantized::gguf_file;
|
||||||
|
use candle::quantized::QTensor;
|
||||||
|
use candle::D;
|
||||||
|
use candle::{DType, Device, IndexOp, Result, Tensor};
|
||||||
|
use candle_nn::{Embedding, Module};
|
||||||
|
|
||||||
|
pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
|
||||||
|
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
|
||||||
|
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
|
||||||
|
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
|
||||||
|
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct QMatMul {
|
||||||
|
inner: candle::quantized::QMatMul,
|
||||||
|
span: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl QMatMul {
|
||||||
|
fn from_qtensor(qtensor: QTensor) -> Result<Self> {
|
||||||
|
let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
|
||||||
|
Ok(Self { inner, span })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
self.inner.forward(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Mlp {
|
||||||
|
feed_forward_gate: QMatMul, // ffn_gate in GGUF
|
||||||
|
feed_forward_up: QMatMul, // ffn_up in GGUF
|
||||||
|
feed_forward_down: QMatMul, // ffn_down in GGUF
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Mlp {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let gate = self.feed_forward_gate.forward(xs)?;
|
||||||
|
let up = self.feed_forward_up.forward(xs)?;
|
||||||
|
let silu = candle_nn::ops::silu(&gate)?;
|
||||||
|
let gated = (silu * up)?;
|
||||||
|
self.feed_forward_down.forward(&gated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct RotaryEmbedding {
|
||||||
|
sin: Tensor,
|
||||||
|
cos: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RotaryEmbedding {
|
||||||
|
fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
|
||||||
|
let theta: Vec<_> = (0..head_dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
|
||||||
|
.collect();
|
||||||
|
let theta = Tensor::new(theta.as_slice(), device)?;
|
||||||
|
let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
|
||||||
|
.to_dtype(DType::F32)?
|
||||||
|
.reshape((MAX_SEQ_LEN, 1))?
|
||||||
|
.matmul(&theta.reshape((1, theta.elem_count()))?)?;
|
||||||
|
let cos = idx_theta.cos()?;
|
||||||
|
let sin = idx_theta.sin()?;
|
||||||
|
Ok(Self { sin, cos })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn apply_rotary_emb_qkv(
|
||||||
|
&self,
|
||||||
|
q: &Tensor,
|
||||||
|
k: &Tensor,
|
||||||
|
index_pos: usize,
|
||||||
|
) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
|
||||||
|
let cos = self.cos.narrow(0, index_pos, seq_len)?;
|
||||||
|
let sin = self.sin.narrow(0, index_pos, seq_len)?;
|
||||||
|
let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
|
||||||
|
let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
|
||||||
|
Ok((q_embed, k_embed))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct LayerWeights {
|
||||||
|
// Attention components
|
||||||
|
attention_wq: QMatMul,
|
||||||
|
attention_wk: QMatMul,
|
||||||
|
attention_wv: QMatMul,
|
||||||
|
attention_wo: QMatMul,
|
||||||
|
|
||||||
|
// Specialized normalization for Q and K
|
||||||
|
attention_q_norm: RmsNorm,
|
||||||
|
attention_k_norm: RmsNorm,
|
||||||
|
|
||||||
|
// Layer normalization
|
||||||
|
attention_norm: RmsNorm, // Applied before attention
|
||||||
|
post_attention_norm: RmsNorm, // Applied after attention
|
||||||
|
ffn_norm: RmsNorm, // Applied before feedforward
|
||||||
|
post_ffn_norm: RmsNorm, // Applied after feedforward
|
||||||
|
|
||||||
|
// Feed-forward network
|
||||||
|
mlp: Mlp,
|
||||||
|
|
||||||
|
// Attention parameters
|
||||||
|
n_head: usize, // Number of query heads
|
||||||
|
n_kv_head: usize, // Number of key-value heads
|
||||||
|
head_dim: usize, // Dimension of each head
|
||||||
|
q_dim: usize, // Total dimension for queries
|
||||||
|
|
||||||
|
sliding_window_size: Option<usize>,
|
||||||
|
|
||||||
|
rotary_embedding: RotaryEmbedding,
|
||||||
|
neg_inf: Tensor,
|
||||||
|
|
||||||
|
// Cache
|
||||||
|
kv_cache: Option<(Tensor, Tensor)>,
|
||||||
|
|
||||||
|
// Tracing
|
||||||
|
span_attn: tracing::Span,
|
||||||
|
span_mlp: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LayerWeights {
|
||||||
|
fn mask(
|
||||||
|
&self,
|
||||||
|
b_sz: usize,
|
||||||
|
seq_len: usize,
|
||||||
|
index_pos: usize,
|
||||||
|
dtype: DType,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
|
||||||
|
(0..seq_len)
|
||||||
|
.flat_map(|i| {
|
||||||
|
(0..seq_len).map(move |j| {
|
||||||
|
if i < j || j + sliding_window_size < i {
|
||||||
|
0u32
|
||||||
|
} else {
|
||||||
|
1u32
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
(0..seq_len)
|
||||||
|
.flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
|
||||||
|
let mask = if index_pos > 0 {
|
||||||
|
let mask0 = Tensor::zeros((seq_len, index_pos), DType::F32, device)?;
|
||||||
|
Tensor::cat(&[&mask0, &mask], D::Minus1)?
|
||||||
|
} else {
|
||||||
|
mask
|
||||||
|
};
|
||||||
|
mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
|
||||||
|
.to_dtype(dtype)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward_attn(
|
||||||
|
&mut self,
|
||||||
|
x: &Tensor,
|
||||||
|
mask: Option<&Tensor>,
|
||||||
|
index_pos: usize,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
let _enter = self.span_attn.enter();
|
||||||
|
let (b_sz, seq_len, _) = x.dims3()?;
|
||||||
|
|
||||||
|
let q = self.attention_wq.forward(x)?;
|
||||||
|
let k = self.attention_wk.forward(x)?;
|
||||||
|
let v = self.attention_wv.forward(x)?;
|
||||||
|
|
||||||
|
let q = q
|
||||||
|
.reshape((b_sz, seq_len, self.n_head, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let k = k
|
||||||
|
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
let v = v
|
||||||
|
.reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
|
||||||
|
.transpose(1, 2)?;
|
||||||
|
|
||||||
|
let q = self.attention_q_norm.forward(&q.contiguous()?)?;
|
||||||
|
let k = self.attention_k_norm.forward(&k.contiguous()?)?;
|
||||||
|
|
||||||
|
let (q, k) = self
|
||||||
|
.rotary_embedding
|
||||||
|
.apply_rotary_emb_qkv(&q, &k, index_pos)?;
|
||||||
|
|
||||||
|
let (k, v) = match &self.kv_cache {
|
||||||
|
None => (k, v),
|
||||||
|
Some((k_cache, v_cache)) => {
|
||||||
|
if index_pos == 0 {
|
||||||
|
(k, v)
|
||||||
|
} else {
|
||||||
|
let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on seq dim
|
||||||
|
let v = Tensor::cat(&[v_cache, &v], 2)?;
|
||||||
|
(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
self.kv_cache = Some((k.clone(), v.clone())); // update cache
|
||||||
|
|
||||||
|
// Repeat KV for GQA
|
||||||
|
let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
|
||||||
|
let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;
|
||||||
|
|
||||||
|
// Scaled Dot-Product Attention
|
||||||
|
let scale = 1.0 / (self.head_dim as f64).sqrt();
|
||||||
|
let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
|
||||||
|
if let Some(mask) = mask {
|
||||||
|
let mask = mask.broadcast_as(attn_weights.shape())?;
|
||||||
|
let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
|
||||||
|
attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
let attn_output = attn_weights.matmul(&v)?;
|
||||||
|
|
||||||
|
let attn_output = attn_output
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b_sz, seq_len, self.q_dim))?;
|
||||||
|
|
||||||
|
self.attention_wo.forward(&attn_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ModelWeights {
|
||||||
|
tok_embeddings: Embedding,
|
||||||
|
embedding_length: usize,
|
||||||
|
layers: Vec<LayerWeights>,
|
||||||
|
norm: RmsNorm,
|
||||||
|
output: QMatMul,
|
||||||
|
span: tracing::Span,
|
||||||
|
span_output: tracing::Span,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelWeights {
|
||||||
|
pub fn from_gguf<R: std::io::Seek + std::io::Read>(
|
||||||
|
ct: gguf_file::Content,
|
||||||
|
reader: &mut R,
|
||||||
|
device: &Device,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let md_get = |s: &str| match ct.metadata.get(s) {
|
||||||
|
None => candle::bail!("cannot find {s} in metadata"),
|
||||||
|
Some(v) => Ok(v),
|
||||||
|
};
|
||||||
|
|
||||||
|
let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
|
||||||
|
let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
|
||||||
|
let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
|
||||||
|
let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
|
||||||
|
let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
|
||||||
|
let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
|
||||||
|
let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
|
||||||
|
let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;
|
||||||
|
|
||||||
|
let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
|
||||||
|
.and_then(|m| Ok(m.to_u32()? as usize))
|
||||||
|
.unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);
|
||||||
|
|
||||||
|
let rope_freq_base = md_get("gemma3.rope.freq_base")
|
||||||
|
.and_then(|m| m.to_f32())
|
||||||
|
.unwrap_or(DEFAULT_ROPE_FREQUENCY);
|
||||||
|
|
||||||
|
let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
|
||||||
|
.and_then(|m| m.to_f32())
|
||||||
|
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);
|
||||||
|
|
||||||
|
// Unused in Llama.cpp so we aren't using it here.
|
||||||
|
let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
|
||||||
|
.and_then(|m| m.to_f32())
|
||||||
|
.unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);
|
||||||
|
|
||||||
|
// Compute the dimensions for queries, keys, and values
|
||||||
|
// These are the total dimensions when projected across all heads
|
||||||
|
let q_dim = head_count * key_length;
|
||||||
|
|
||||||
|
let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
|
||||||
|
|
||||||
|
// Load token embeddings and output projection
|
||||||
|
let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
|
||||||
|
let tok_embeddings = tok_embeddings.dequantize(device)?;
|
||||||
|
let norm = RmsNorm::from_qtensor(
|
||||||
|
ct.tensor(reader, "output_norm.weight", device)?,
|
||||||
|
rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
let output = match ct.tensor(reader, "output.weight", device) {
|
||||||
|
Ok(tensor) => tensor,
|
||||||
|
Err(_) => ct.tensor(reader, "token_embd.weight", device)?, // Use tied weights if output.weight doesn't exist
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut layers = Vec::with_capacity(block_count);
|
||||||
|
for layer_idx in 0..block_count {
|
||||||
|
let prefix = format!("blk.{layer_idx}");
|
||||||
|
|
||||||
|
let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
|
||||||
|
let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
|
||||||
|
let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
|
||||||
|
let attention_wo =
|
||||||
|
ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
|
||||||
|
|
||||||
|
let attention_q_norm = RmsNorm::from_qtensor(
|
||||||
|
ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
|
||||||
|
rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let attention_k_norm = RmsNorm::from_qtensor(
|
||||||
|
ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
|
||||||
|
rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let attention_norm = RmsNorm::from_qtensor(
|
||||||
|
ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
|
||||||
|
rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let post_attention_norm = RmsNorm::from_qtensor(
|
||||||
|
ct.tensor(
|
||||||
|
reader,
|
||||||
|
&format!("{prefix}.post_attention_norm.weight"),
|
||||||
|
device,
|
||||||
|
)?,
|
||||||
|
rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let ffn_norm = RmsNorm::from_qtensor(
|
||||||
|
ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
|
||||||
|
rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let post_ffn_norm = RmsNorm::from_qtensor(
|
||||||
|
ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
|
||||||
|
rms_norm_eps,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let feed_forward_gate =
|
||||||
|
ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
|
||||||
|
let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
|
||||||
|
let feed_forward_down =
|
||||||
|
ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
|
||||||
|
|
||||||
|
let mlp = Mlp {
|
||||||
|
feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
|
||||||
|
feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
|
||||||
|
feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
// Sliding window pattern hardcoded to 6 because it's not explicitly defined
|
||||||
|
let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
|
||||||
|
let sliding_window_size = is_sliding.then_some(sliding_window_size);
|
||||||
|
let layer_rope_frequency = if is_sliding {
|
||||||
|
rope_freq_base_sliding
|
||||||
|
} else {
|
||||||
|
rope_freq_base
|
||||||
|
};
|
||||||
|
|
||||||
|
let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;
|
||||||
|
|
||||||
|
// Tracing spans
|
||||||
|
let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
|
||||||
|
let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");
|
||||||
|
|
||||||
|
layers.push(LayerWeights {
|
||||||
|
attention_wq: QMatMul::from_qtensor(attention_wq)?,
|
||||||
|
attention_wk: QMatMul::from_qtensor(attention_wk)?,
|
||||||
|
attention_wv: QMatMul::from_qtensor(attention_wv)?,
|
||||||
|
attention_wo: QMatMul::from_qtensor(attention_wo)?,
|
||||||
|
attention_q_norm,
|
||||||
|
attention_k_norm,
|
||||||
|
attention_norm,
|
||||||
|
post_attention_norm,
|
||||||
|
ffn_norm,
|
||||||
|
post_ffn_norm,
|
||||||
|
mlp,
|
||||||
|
n_head: head_count,
|
||||||
|
n_kv_head: head_count_kv,
|
||||||
|
head_dim: key_length,
|
||||||
|
q_dim,
|
||||||
|
sliding_window_size,
|
||||||
|
rotary_embedding,
|
||||||
|
neg_inf: neg_inf.clone(),
|
||||||
|
kv_cache: None,
|
||||||
|
span_attn,
|
||||||
|
span_mlp,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
let span = tracing::span!(tracing::Level::TRACE, "model");
|
||||||
|
let span_output = tracing::span!(tracing::Level::TRACE, "output");
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
|
||||||
|
embedding_length,
|
||||||
|
layers,
|
||||||
|
norm,
|
||||||
|
output: QMatMul::from_qtensor(output)?,
|
||||||
|
span,
|
||||||
|
span_output,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
|
||||||
|
let (b_sz, seq_len) = x.dims2()?;
|
||||||
|
let _enter = self.span.enter();
|
||||||
|
|
||||||
|
let mut layer_in = self.tok_embeddings.forward(x)?;
|
||||||
|
layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;
|
||||||
|
|
||||||
|
for layer in self.layers.iter_mut() {
|
||||||
|
let attention_mask = if seq_len == 1 {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Attention block
|
||||||
|
let residual = &layer_in;
|
||||||
|
let x = layer.attention_norm.forward(&layer_in)?;
|
||||||
|
let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
|
||||||
|
let x = layer.post_attention_norm.forward(&x)?;
|
||||||
|
let x = (x + residual)?;
|
||||||
|
|
||||||
|
// Feed-forward block
|
||||||
|
let _enter = layer.span_mlp.enter();
|
||||||
|
let residual = &x;
|
||||||
|
let x = layer.ffn_norm.forward(&x)?;
|
||||||
|
let x = layer.mlp.forward(&x)?;
|
||||||
|
let x = layer.post_ffn_norm.forward(&x)?;
|
||||||
|
let x = (x + residual)?;
|
||||||
|
drop(_enter);
|
||||||
|
|
||||||
|
layer_in = x;
|
||||||
|
}
|
||||||
|
|
||||||
|
let _enter = self.span_output.enter();
|
||||||
|
|
||||||
|
let x = layer_in.i((.., seq_len - 1, ..))?;
|
||||||
|
let x = self.norm.forward(&x)?;
|
||||||
|
let output = self.output.forward(&x)?;
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user