mirror of https://github.com/huggingface/candle.git
synced 2025-06-16 18:48:51 +00:00

Compare commits: 0.9.0-alph ... 0.9.0
13 Commits
SHA1
fbaf0b0e32
a2e925462c
3827685524
3aeb9575c7
6ff0a6999c
82def7ae38
99bd69f383
a4c56a958e
b2904a830b
21055b5697
9dbaf958dc
ce5f8dd129
9954981327
Cargo.toml (20 changed lines)
@@ -20,7 +20,7 @@ exclude = [
 resolver = "2"

 [workspace.package]
-version = "0.9.0-alpha.4"
+version = "0.9.0"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -33,17 +33,17 @@ ab_glyph = "0.2.23"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
-candle-datasets = { path = "./candle-datasets", version = "0.9.0-alpha.4" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0-alpha.4" }
-candle-kernels = { path = "./candle-kernels", version = "0.9.0-alpha.4" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0-alpha.4" }
-candle-nn = { path = "./candle-nn", version = "0.9.0-alpha.4" }
-candle-onnx = { path = "./candle-onnx", version = "0.9.0-alpha.4" }
-candle-transformers = { path = "./candle-transformers", version = "0.9.0-alpha.4" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.9.0" }
+candle-datasets = { path = "./candle-datasets", version = "0.9.0" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.9.0" }
+candle-kernels = { path = "./candle-kernels", version = "0.9.0" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.0" }
+candle-nn = { path = "./candle-nn", version = "0.9.0" }
+candle-onnx = { path = "./candle-onnx", version = "0.9.0" }
+candle-transformers = { path = "./candle-transformers", version = "0.9.0" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.16.0", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.16.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
 gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.4.1"
@@ -71,15 +71,27 @@ pub trait BackendStorage: Sized {
     fn upsample_nearest2d(&self, _: &Layout, _: usize, _: usize) -> Result<Self>;

     fn gather(&self, _: &Layout, _: &Self, _: &Layout, _: usize) -> Result<Self>;
-    fn scatter_add(
-        &self,
+
+    fn scatter_set(
+        &mut self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: usize,
-    ) -> Result<Self>;
+    ) -> Result<()>;
+
+    fn scatter_add_set(
+        &mut self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<()>;

     fn index_select(&self, _: &Self, _: &Layout, _: &Layout, _: usize) -> Result<Self>;
     fn index_add(
         &self,
@@ -113,6 +125,8 @@ pub trait BackendStorage: Sized {
         _src_offset: usize,
         _dst_offset: usize,
     ) -> Result<()>;
+
+    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()>;
 }

 pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
@@ -127,8 +141,6 @@ pub trait BackendDevice: Sized + std::fmt::Debug + Clone {

     fn zeros_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;

-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage>;
-
     /// # Safety
     /// This function is unsafe as it doesn't initialize the underlying data store.
     /// The caller should ensure that the data is properly initialized as early as possible
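The core API change of this release is visible here: BackendStorage loses the allocating scatter_add (&self -> Result<Self>) in favor of the in-place scatter_set / scatter_add_set (&mut self -> Result<()>), gains a const_set fill hook, and BackendDevice drops ones_impl. The allocating tensor-level ops are rebuilt on top by copying first and then mutating the copy. A minimal flat-buffer sketch of that recovery (illustrative only, not the candle implementation):

// The allocating op is rebuilt from the in-place one: copy the destination
// (copy_strided_src in candle), then run the in-place update on the copy.
fn scatter_add_alloc(dst: &[f32], ids: &[usize], src: &[f32]) -> Vec<f32> {
    let mut out = dst.to_vec();
    for (i, &id) in ids.iter().enumerate() {
        out[id] += src[i]; // scatter_add_set on the owned copy
    }
    out
}

fn main() {
    assert_eq!(scatter_add_alloc(&[1.0, 1.0], &[0, 0], &[2.0, 3.0]), [6.0, 1.0]);
}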
@@ -53,6 +53,7 @@ impl Tensor {
         } else if let Some(op) = node.op() {
             match op {
                 Op::IndexAdd(t1, t2, t3, _)
+                | Op::Scatter(t1, t2, t3, _)
                 | Op::ScatterAdd(t1, t2, t3, _)
                 | Op::CustomOp3(t1, t2, t3, _)
                 | Op::WhereCond(t1, t2, t3) => {
@@ -419,7 +420,7 @@ impl Tensor {
                     let sum_grad = grads.or_insert(arg)?;
                     *sum_grad = sum_grad.scatter_add(indexes, &grad, *dim)?;
                 }
-                Op::ScatterAdd(init, indexes, src, dim) => {
+                Op::Scatter(init, indexes, src, dim) => {
                     let init_sum_grad = grads.or_insert(init)?;
                     *init_sum_grad = init_sum_grad.add(&grad)?;
@@ -427,6 +428,16 @@ impl Tensor {
                     let src_sum_grad = grads.or_insert(src)?;
                     *src_sum_grad = src_sum_grad.add(&src_grad)?;
                 }
+                Op::ScatterAdd(init, indexes, src, dim) => {
+                    let init_sum_grad = grads.or_insert(init)?;
+                    let mask = init.ones_like()?;
+                    let mask = mask.scatter(indexes, &mask.zeros_like()?, *dim)?;
+                    *init_sum_grad = init_sum_grad.add(&grad.mul(&mask)?)?;
+
+                    let src_grad = grad.gather(indexes, *dim)?;
+                    let src_sum_grad = grads.or_insert(src)?;
+                    *src_sum_grad = src_sum_grad.add(&src_grad)?;
+                }
                 Op::IndexAdd(init, indexes, src, dim) => {
                     let init_sum_grad = grads.or_insert(init)?;
                     *init_sum_grad = init_sum_grad.add(&grad)?;
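The mask construction in this backprop change is how a gradient gets zeroed at every position a scatter wrote to: scattering zeros into a tensor of ones leaves 1 exactly where the destination value survived, so multiplying the incoming gradient by the mask drops the overwritten contributions. A toy 1-D model of the mask (hypothetical helper, for illustration):

fn scatter_grad_mask(len: usize, ids: &[usize]) -> Vec<f32> {
    // init.ones_like() then mask.scatter(ids, zeros, dim): written slots get 0.
    let mut mask = vec![1.0f32; len];
    for &id in ids {
        mask[id] = 0.0;
    }
    mask
}

fn main() {
    assert_eq!(scatter_grad_mask(5, &[1, 3]), [1.0, 0.0, 1.0, 0.0, 1.0]);
}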
@@ -7,7 +7,7 @@ use rayon::prelude::*;

 mod utils;
 pub use utils::{
-    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2U8,
+    binary_map, binary_map_vec, unary_map, unary_map_vec, Map1, Map1Any, Map2, Map2InPlace, Map2U8,
 };

 const USE_IM2COL_CONV1D: bool = true;
@@ -554,26 +554,65 @@ impl<I: IntDType> Map1 for IndexSelect<'_, I> {
     }
 }

-struct ScatterAdd<'a, I: IntDType> {
+trait ElemUpdate {
+    fn f<T: WithDType>(dst: &mut T, src: T);
+}
+
+struct Set;
+struct Add;
+
+impl ElemUpdate for Set {
+    fn f<T: WithDType>(dst: &mut T, src: T) {
+        *dst = src
+    }
+}
+
+impl ElemUpdate for Add {
+    fn f<T: WithDType>(dst: &mut T, src: T) {
+        *dst += src
+    }
+}
+
+struct Scatter<'a, I: IntDType, M: ElemUpdate> {
     ids: &'a [I],
     ids_l: &'a Layout,
     dim: usize,
+    _phantom: std::marker::PhantomData<M>,
 }

-impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
-    const OP: &'static str = "scatter-add";
-    fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, src: &[T], src_l: &Layout) -> Result<Vec<T>> {
-        let dst_len = l1.shape().elem_count();
-        let mut dst = vec![T::zero(); dst_len];
-        copy_strided_src_(v1, &mut dst, 0, l1);
+impl<'a, I: IntDType, M: ElemUpdate> Scatter<'a, I, M> {
+    fn new(ids: &'a [I], ids_l: &'a Layout, dim: usize) -> Self {
+        Self {
+            ids,
+            ids_l,
+            dim,
+            _phantom: Default::default(),
+        }
+    }
+}
+
+impl<I: IntDType, M: ElemUpdate> Map2InPlace for Scatter<'_, I, M> {
+    const OP: &'static str = "scatter";
+    fn f<T: WithDType>(
+        &self,
+        dst: &mut [T],
+        dst_l: &Layout,
+        src: &[T],
+        src_l: &Layout,
+    ) -> Result<()> {
+        let dst = match dst_l.contiguous_offsets() {
+            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
+            Some((o1, o2)) => &mut dst[o1..o2],
+        };
+
         let src = match src_l.contiguous_offsets() {
-            None => Err(Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+            None => Err(Error::RequiresContiguous { op: "scatter" }.bt())?,
             Some((o1, o2)) => &src[o1..o2],
         };

         let dim = self.dim;
         let ids_dims = self.ids_l.dims();
-        let dst_dims = l1.dims();
+        let dst_dims = dst_l.dims();
         let dst_dim_len = dst_dims[dim];
         let dst_right_len: usize = dst_dims[dim + 1..].iter().product();

@@ -602,12 +641,12 @@ impl<I: IntDType> Map2 for ScatterAdd<'_, I> {
                         .bt())?
                     }
                     let dst_idx = start_dst_idx + index * dst_right_len + right_i;
-                    dst[dst_idx] += src[ids_idx]
+                    M::f(&mut dst[dst_idx], src[ids_idx])
                 }
             }
         }

-        Ok(dst)
+        Ok(())
     }
 }

@@ -2381,19 +2420,36 @@ impl BackendStorage for CpuStorage {
         }
     }

-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         l: &Layout,
         ids: &Self,
         ids_l: &Layout,
         src: &Self,
         src_l: &Layout,
         dim: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
         match ids {
-            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
-            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
+            Self::U8(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::U32(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::I64(ids) => Scatter::<_, Set>::new(ids, ids_l, dim).map(self, l, src, src_l),
             _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter").bt()),
         }
     }

+    fn scatter_add_set(
+        &mut self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<()> {
+        match ids {
+            Self::U8(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::U32(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            Self::I64(ids) => Scatter::<_, Add>::new(ids, ids_l, dim).map(self, l, src, src_l),
+            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
+        }
+    }
@@ -2454,6 +2510,48 @@ impl BackendStorage for CpuStorage {
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Ok(self.clone())
     }

+    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
+        use crate::scalar::Scalar;
+        fn set<T: crate::WithDType>(src: &mut [T], l: &Layout, s: T) {
+            match l.strided_blocks() {
+                crate::StridedBlocks::SingleBlock { start_offset, len } => {
+                    src[start_offset..start_offset + len].fill(s)
+                }
+                crate::StridedBlocks::MultipleBlocks {
+                    block_start_index,
+                    block_len: 1,
+                } => {
+                    for src_index in block_start_index {
+                        src[src_index] = s
+                    }
+                }
+                crate::StridedBlocks::MultipleBlocks {
+                    block_start_index,
+                    block_len,
+                } => {
+                    for src_index in block_start_index {
+                        src[src_index..src_index + block_len].fill(s)
+                    }
+                }
+            }
+        }
+        match (self, s) {
+            (Self::BF16(storage), Scalar::BF16(v)) => set(storage, l, v),
+            (Self::F16(storage), Scalar::F16(v)) => set(storage, l, v),
+            (Self::F32(storage), Scalar::F32(v)) => set(storage, l, v),
+            (Self::F64(storage), Scalar::F64(v)) => set(storage, l, v),
+            (Self::U8(storage), Scalar::U8(v)) => set(storage, l, v),
+            (Self::U32(storage), Scalar::U32(v)) => set(storage, l, v),
+            (Self::I64(storage), Scalar::I64(v)) => set(storage, l, v),
+            (st, s) => crate::bail!(
+                "const_set dtype mismatch, expected {:?} but got {:?}",
+                st.dtype(),
+                s
+            ),
+        }
+        Ok(())
+    }
 }

 impl BackendDevice for CpuDevice {
@@ -2628,20 +2726,6 @@ impl BackendDevice for CpuDevice {
         Ok(storage)
     }

-    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
-        let elem_count = shape.elem_count();
-        let storage = match dtype {
-            DType::U8 => CpuStorage::U8(vec![1u8; elem_count]),
-            DType::U32 => CpuStorage::U32(vec![1u32; elem_count]),
-            DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
-            DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
-            DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
-            DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
-            DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
-        };
-        Ok(storage)
-    }
-
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<CpuStorage> {
         let elem_count = shape.elem_count();
         let storage = match dtype {
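The CPU rewrite above folds the dedicated ScatterAdd map into a single Scatter<I, M> traversal parameterized by an ElemUpdate strategy: Set overwrites, Add accumulates, and the PhantomData field only carries the zero-sized strategy type so both variants monomorphize from the same loop. A self-contained sketch of the pattern (f32-only for brevity):

trait ElemUpdate {
    fn f(dst: &mut f32, src: f32);
}

struct Set;
struct Add;

impl ElemUpdate for Set {
    fn f(dst: &mut f32, src: f32) {
        *dst = src
    }
}

impl ElemUpdate for Add {
    fn f(dst: &mut f32, src: f32) {
        *dst += src
    }
}

// One loop, two behaviors: M::f is inlined per instantiation.
fn scatter_1d<M: ElemUpdate>(dst: &mut [f32], ids: &[usize], src: &[f32]) {
    for (i, &id) in ids.iter().enumerate() {
        M::f(&mut dst[id], src[i]);
    }
}

fn main() {
    let mut d = [0.0f32; 4];
    scatter_1d::<Add>(&mut d, &[1, 1, 3], &[1.0, 2.0, 3.0]);
    assert_eq!(d, [0.0, 3.0, 0.0, 3.0]);
    scatter_1d::<Set>(&mut d, &[0, 1], &[9.0, 9.0]);
    assert_eq!(d, [9.0, 9.0, 0.0, 3.0]);
}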
@@ -58,6 +58,30 @@ pub trait Map2 {
     }
 }

+pub trait Map2InPlace {
+    const OP: &'static str;
+    fn f<T: WithDType>(&self, v1: &mut [T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<()>;
+
+    fn map(&self, v1: &mut C, l1: &Layout, v2: &C, l2: &Layout) -> Result<()> {
+        match (v1, v2) {
+            (C::U8(v1), C::U8(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::U32(v1), C::U32(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::I64(v1), C::I64(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::BF16(v1), C::BF16(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::F16(v1), C::F16(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::F32(v1), C::F32(v2)) => self.f(v1, l1, v2, l2)?,
+            (C::F64(v1), C::F64(v2)) => self.f(v1, l1, v2, l2)?,
+            (v1, v2) => Err(Error::DTypeMismatchBinaryOp {
+                lhs: v1.dtype(),
+                rhs: v2.dtype(),
+                op: Self::OP,
+            }
+            .bt())?,
+        };
+        Ok(())
+    }
+}
+
 pub trait Map2U8 {
     const OP: &'static str;
     fn f<T: WithDType>(&self, v1: &[T], l1: &Layout, v2: &[T], l2: &Layout) -> Result<Vec<u8>>;
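Map2InPlace is the in-place sibling of Map2: the generic body receives a &mut destination slice and returns Result<()>, and the provided map method does the usual dtype dispatch over the storage enum. A condensed standalone model with two dtypes instead of seven (illustrative names, not the candle types):

enum Storage {
    F32(Vec<f32>),
    I64(Vec<i64>),
}

trait Map2InPlace {
    fn f<T: Copy + std::ops::AddAssign>(&self, dst: &mut [T], src: &[T]);

    // Only matching dtypes are combined; anything else is a dtype error.
    fn map(&self, dst: &mut Storage, src: &Storage) -> Result<(), String> {
        match (dst, src) {
            (Storage::F32(d), Storage::F32(s)) => Ok(self.f(d, s)),
            (Storage::I64(d), Storage::I64(s)) => Ok(self.f(d, s)),
            _ => Err("dtype mismatch".to_string()),
        }
    }
}

struct AddInPlace;
impl Map2InPlace for AddInPlace {
    fn f<T: Copy + std::ops::AddAssign>(&self, dst: &mut [T], src: &[T]) {
        for (d, s) in dst.iter_mut().zip(src) {
            *d += *s;
        }
    }
}

fn main() {
    let mut dst = Storage::F32(vec![1.0, 2.0]);
    let src = Storage::F32(vec![10.0, 20.0]);
    AddInPlace.map(&mut dst, &src).unwrap();
    match dst {
        Storage::F32(v) => assert_eq!(v, vec![11.0, 22.0]),
        _ => unreachable!(),
    }
}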
@@ -2,7 +2,7 @@ use crate::backend::BackendDevice;
 use crate::{CpuStorage, CpuStorageRef, DType, Layout, Result, Shape};
 pub use candle_kernels as kernels;
 pub use cudarc;
-use cudarc::driver::{CudaFunction, LaunchConfig, PushKernelArg};
+use cudarc::driver::CudaFunction;
 use half::{bf16, f16};
 use std::collections::HashMap;
 use std::sync::{Arc, Mutex};
@@ -188,100 +188,6 @@ impl CudaDevice {
         self.id
     }

-    fn const_impl(&self, v: f64, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
-        let elem_count = shape.elem_count();
-        let cfg = LaunchConfig::for_num_elems(elem_count as u32);
-        let slice = match dtype {
-            DType::U8 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u8>(elem_count)? };
-                let func = self.get_or_load_func("fill_u8", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as u8;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::U8(data)
-            }
-            DType::U32 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<u32>(elem_count)? };
-                let func = self.get_or_load_func("fill_u32", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as u32;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::U32(data)
-            }
-            DType::I64 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<i64>(elem_count)? };
-                let func = self.get_or_load_func("fill_i64", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as i64;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::I64(data)
-            }
-            DType::BF16 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<bf16>(elem_count)? };
-                let func = self.get_or_load_func("fill_bf16", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = bf16::from_f64(v);
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::BF16(data)
-            }
-            DType::F16 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f16>(elem_count)? };
-                let func = self.get_or_load_func("fill_f16", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = f16::from_f64(v);
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::F16(data)
-            }
-            DType::F32 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f32>(elem_count)? };
-                let func = self.get_or_load_func("fill_f32", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                let v = v as f32;
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::F32(data)
-            }
-            DType::F64 => {
-                // SAFETY: Set later by running the fill kernel.
-                let data = unsafe { self.alloc::<f64>(elem_count) }?;
-                let func = self.get_or_load_func("fill_f64", &kernels::FILL)?;
-                let mut builder = self.stream.launch_builder(&func);
-                builder.arg(&data);
-                builder.arg(&v);
-                builder.arg(&elem_count);
-                unsafe { builder.launch(cfg) }.w()?;
-                CudaStorageSlice::F64(data)
-            }
-        };
-        Ok(CudaStorage {
-            slice,
-            device: self.clone(),
-        })
-    }
-
     pub fn get_or_load_custom_func(
         &self,
         fn_name: &str,
@@ -504,10 +410,6 @@ impl BackendDevice for CudaDevice {
         })
     }

-    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<CudaStorage> {
-        self.const_impl(1., shape, dtype)
-    }
-
     unsafe fn alloc_uninit(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
         let elem_count = shape.elem_count();
         let slice = match dtype {
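The deleted const_impl was seven near-identical fill-kernel launches, one per dtype, forced by its untyped f64 argument; with values now travelling as a dtype-tagged Scalar, the single const_set entry point added further down replaces all of them. A toy illustration of why the tag collapses the duplication (bytes stand in for device memory; not the candle code):

enum Scalar {
    F32(f32),
    F64(f64),
}

// One entry point: the tag picks the element width, no per-dtype wrapper fns.
fn fill_bytes(v: Scalar, n: usize) -> Vec<u8> {
    match v {
        Scalar::F32(x) => x.to_le_bytes().repeat(n),
        Scalar::F64(x) => x.to_le_bytes().repeat(n),
    }
}

fn main() {
    assert_eq!(fill_bytes(Scalar::F32(1.0), 2).len(), 8);
    assert_eq!(fill_bytes(Scalar::F64(1.0), 2).len(), 16);
}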
@@ -2,7 +2,7 @@
 //!
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
-use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, Shape, WithDType};
+use crate::{builder_arg as barg, CpuStorage, DType, Layout, Result, WithDType};
 pub use candle_kernels as kernels;
 pub use cudarc;
 use cudarc::cublas::{Gemm, GemmConfig, StridedBatchedConfig};
@@ -34,6 +34,21 @@ impl<T: DeviceRepr> SlicePtrOrNull<T> {
     }
 }

+impl crate::scalar::Scalar {
+    pub fn builder_arg<'a, 'b: 'a>(&'b self, builder: &mut cudarc::driver::LaunchArgs<'a>) {
+        use crate::scalar::Scalar;
+        match self {
+            Scalar::U8(v) => builder.arg(v),
+            Scalar::U32(v) => builder.arg(v),
+            Scalar::I64(v) => builder.arg(v),
+            Scalar::F32(v) => builder.arg(v),
+            Scalar::F64(v) => builder.arg(v),
+            Scalar::F16(v) => builder.arg(v),
+            Scalar::BF16(v) => builder.arg(v),
+        };
+    }
+}
+
 impl SlicePtrOrNull<usize> {
     pub fn params_from_layout(dev: &CudaDevice, l: &Layout) -> Result<Self> {
         let ds = if l.is_contiguous() {
@@ -395,7 +410,7 @@ impl Map1 for IndexSelect<'_> {
             CudaStorageSlice::U8(slice) => ("is_u8", slice_ptr(slice, ids_l.start_offset())),
             CudaStorageSlice::I64(slice) => ("is_i64", slice_ptr(slice, ids_l.start_offset())),
             _ => Err(CudaError::UnexpectedDType {
-                msg: "index_select ids should be u8 or u32",
+                msg: "index_select ids should be u8, u32, or i64",
                 expected: DType::U32,
                 got: self.0.dtype(),
             })
@@ -492,7 +507,7 @@ impl Map2InPlace for IndexAdd<'_> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
-        dst_shape: &Shape,
+        dst_l: &Layout,
         src: &CudaSlice<T>,
         src_l: &Layout,
         dev: &CudaDevice,
@@ -514,6 +529,10 @@ impl Map2InPlace for IndexAdd<'_> {
                 got: ids.dtype(),
             })?,
         };
+        let dst = match dst_l.contiguous_offsets() {
+            Some((o1, o2)) => dst.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
+        };
         let src = match src_l.contiguous_offsets() {
             Some((o1, o2)) => src.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "index-add" }.bt())?,
@@ -521,7 +540,7 @@ impl Map2InPlace for IndexAdd<'_> {
         let left_sz: usize = src_l.dims()[..dim].iter().product();
         let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
         let src_dim_sz = src_l.dims()[dim];
-        let dst_dim_sz = dst_shape.dims()[dim];
+        let dst_dim_sz = dst_l.dims()[dim];
         let ids_dim_sz = ids_l.dims()[0];
         let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
@@ -529,7 +548,59 @@ impl Map2InPlace for IndexAdd<'_> {
         barg!(builder, ids);
         barg!(builder, ids_dim_sz);
         builder.arg(&src);
-        builder.arg(dst);
+        builder.arg(&dst);
         barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
         // SAFETY: ffi.
         unsafe { builder.launch(cfg) }.w()?;
         Ok(())
     }
 }

+struct Scatter<'a>(&'a CudaStorage, &'a Layout, usize);
+impl Map2InPlace for Scatter<'_> {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        dst: &mut CudaSlice<T>,
+        dst_l: &Layout,
+        src: &CudaSlice<T>,
+        src_l: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<()> {
+        let ids = &self.0;
+        let ids_l = &self.1;
+        let dim = self.2;
+        let (ids_o1, _) = match ids_l.contiguous_offsets() {
+            Some(o12) => o12,
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let (name, (ids, _guard)) = match &ids.slice {
+            CudaStorageSlice::U32(slice) => ("s_u32", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::I64(slice) => ("s_i64", slice_ptr(slice, ids_o1)),
+            CudaStorageSlice::U8(slice) => ("s_u8", slice_ptr(slice, ids_o1)),
+            _ => Err(CudaError::UnexpectedDType {
+                msg: "scatter ids should be u8/u32/i64",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let dst = match dst_l.contiguous_offsets() {
+            Some((o1, o2)) => dst.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let src = match src_l.contiguous_offsets() {
+            Some((o1, o2)) => src.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter" }.bt())?,
+        };
+        let left_sz: usize = src_l.dims()[..dim].iter().product();
+        let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
+        let src_dim_sz = src_l.dims()[dim];
+        let dst_dim_sz = dst_l.dims()[dim];
+        let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
+        let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
+        let mut builder = func.builder();
+        barg!(builder, ids);
+        builder.arg(&src);
+        builder.arg(&dst);
+        barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
+        // SAFETY: ffi.
+        unsafe { builder.launch(cfg) }.w()?;
@@ -542,7 +613,7 @@ impl Map2InPlace for ScatterAdd<'_> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
-        dst_shape: &Shape,
+        dst_l: &Layout,
         src: &CudaSlice<T>,
         src_l: &Layout,
         dev: &CudaDevice,
@@ -564,6 +635,10 @@ impl Map2InPlace for ScatterAdd<'_> {
                 got: ids.dtype(),
             })?,
         };
+        let dst = match dst_l.contiguous_offsets() {
+            Some((o1, o2)) => dst.slice(o1..o2),
+            None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
+        };
         let src = match src_l.contiguous_offsets() {
             Some((o1, o2)) => src.slice(o1..o2),
             None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
@@ -571,13 +646,13 @@ impl Map2InPlace for ScatterAdd<'_> {
         let left_sz: usize = src_l.dims()[..dim].iter().product();
         let right_sz: usize = src_l.dims()[dim + 1..].iter().product();
         let src_dim_sz = src_l.dims()[dim];
-        let dst_dim_sz = dst_shape.dims()[dim];
+        let dst_dim_sz = dst_l.dims()[dim];
         let cfg = LaunchConfig::for_num_elems((left_sz * right_sz) as u32);
         let func = dev.get_or_load_func(&kernel_name::<T>(name), &kernels::INDEXING)?;
         let mut builder = func.builder();
         barg!(builder, ids);
         builder.arg(&src);
-        builder.arg(dst);
+        builder.arg(&dst);
         barg!(builder, left_sz, src_dim_sz, dst_dim_sz, right_sz);
         // SAFETY: ffi.
         unsafe { builder.launch(cfg) }.w()?;
@@ -1235,6 +1310,36 @@ impl BackendStorage for CudaStorage {
         &self.device
     }

+    fn const_set(&mut self, s: crate::scalar::Scalar, layout: &Layout) -> Result<()> {
+        let dev = &self.device;
+        let shape = layout.shape();
+        let dims = shape.dims();
+        let el_count = shape.elem_count();
+        let cfg = LaunchConfig::for_num_elems(el_count as u32);
+        let ds = SlicePtrOrNull::params_from_layout(dev, layout)?;
+        let src_o = layout.start_offset();
+        let ((src, _guard_src), kernel_name) = match &mut self.slice {
+            S::U8(s) => (slice_ptr(s, src_o), "const_set_u8"),
+            S::U32(s) => (slice_ptr(s, src_o), "const_set_u32"),
+            S::I64(s) => (slice_ptr(s, src_o), "const_set_i64"),
+            S::BF16(s) => (slice_ptr(s, src_o), "const_set_bf16"),
+            S::F16(s) => (slice_ptr(s, src_o), "const_set_f16"),
+            S::F32(s) => (slice_ptr(s, src_o), "const_set_f32"),
+            S::F64(s) => (slice_ptr(s, src_o), "const_set_f64"),
+        };
+
+        let func = dev.get_or_load_func(kernel_name, &kernels::FILL)?;
+        let mut builder = func.builder();
+        barg!(builder, el_count);
+        barg!(builder, dims.len());
+        ds.builder_arg(&mut builder);
+        s.builder_arg(&mut builder);
+        barg!(builder, src);
+        // SAFETY: ffi.
+        unsafe { builder.launch(cfg) }.w()?;
+        Ok(())
+    }
+
     fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         let shape = layout.shape();
         let dims = shape.dims();
@@ -1793,20 +1898,29 @@ impl BackendStorage for CudaStorage {
         let slice = Gather(ids, ids_l, dim).map(&self.slice, &device, l)?;
         Ok(Self { slice, device })
     }
-    fn scatter_add(
-        &self,
+
+    fn scatter_set(
+        &mut self,
         l: &Layout,
         ids: &Self,
         ids_l: &Layout,
         src: &Self,
         src_l: &Layout,
         dim: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
         let device = self.device().clone();
-        let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
-        self.copy_strided_src(&mut acc, 0, l)?;
-        ScatterAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
-        Ok(acc)
+        Scatter(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
     }
+
+    fn scatter_add_set(
+        &mut self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<()> {
+        let device = self.device().clone();
+        ScatterAdd(ids, ids_l, dim).map(&mut self.slice, l, &src.slice, src_l, &device)
+    }
+
     fn index_add(
         &self,
@@ -1820,7 +1934,7 @@ impl BackendStorage for CudaStorage {
         let device = self.device().clone();
         let mut acc = unsafe { device.alloc_uninit(l.shape(), self.dtype())? };
         self.copy_strided_src(&mut acc, 0, l)?;
-        IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l.shape(), &src.slice, src_l, &device)?;
+        IndexAdd(ids, ids_l, dim).map(&mut acc.slice, l, &src.slice, src_l, &device)?;
         Ok(acc)
     }
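Both CUDA scatter variants launch one thread per (left, right) coordinate pair; each thread walks the whole scatter dimension serially, which is what keeps the in-place update free of intra-kernel races for the Set case. A CPU model of the kernel's index arithmetic, assuming the contiguous [left, dim, right] layouts that the contiguity checks above enforce (a sketch, not the actual kernel):

// One "thread" per (l, r); j walks the scatter dimension of src.
fn scatter_model(
    dst: &mut [f32], dst_dim: usize,
    src: &[f32], src_dim: usize,
    ids: &[usize],
    left: usize, right: usize,
) {
    for l in 0..left {
        for r in 0..right {
            for j in 0..src_dim {
                let src_idx = (l * src_dim + j) * right + r;
                let id = ids[src_idx];
                assert!(id < dst_dim, "scatter index out of bounds");
                let dst_idx = (l * dst_dim + id) * right + r;
                dst[dst_idx] = src[src_idx]; // s_* kernels; sa_* would use +=
            }
        }
    }
}

fn main() {
    let mut dst = vec![0.0; 4];
    scatter_model(&mut dst, 4, &[1.0, 2.0], 2, &[3, 0], 1, 1);
    assert_eq!(dst, [2.0, 0.0, 0.0, 1.0]);
}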
@@ -1,5 +1,5 @@
 /// Helper functions to plug cuda kernels in candle.
-use crate::{Layout, Result, Shape, WithDType};
+use crate::{Layout, Result, WithDType};
 pub use cudarc;
 use cudarc::driver::{CudaSlice, DeviceRepr, ValidAsZeroBits};

@@ -96,7 +96,7 @@ pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
         dst: &mut CudaSlice<T>,
-        dst_shape: &Shape,
+        dst_l: &Layout,
         src: &CudaSlice<T>,
         src_l: &Layout,
         dev: &CudaDevice,
@@ -105,19 +105,19 @@ pub trait Map2InPlace {
     fn map(
         &self,
         dst: &mut S,
-        dst_s: &Shape,
+        dst_l: &Layout,
         src: &S,
         src_l: &Layout,
         d: &CudaDevice,
     ) -> Result<()> {
         match (dst, src) {
-            (S::U8(dst), S::U8(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::U32(dst), S::U32(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::I64(dst), S::I64(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F16(dst), S::F16(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F32(dst), S::F32(src)) => self.f(dst, dst_s, src, src_l, d),
-            (S::F64(dst), S::F64(src)) => self.f(dst, dst_s, src, src_l, d),
+            (S::U8(dst), S::U8(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::U32(dst), S::U32(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::I64(dst), S::I64(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::BF16(dst), S::BF16(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::F16(dst), S::F16(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::F32(dst), S::F32(src)) => self.f(dst, dst_l, src, src_l, d),
+            (S::F64(dst), S::F64(src)) => self.f(dst, dst_l, src, src_l, d),
             _ => Err(CudaError::InternalError("dtype mismatch in binary op"))?,
         }
     }
@@ -292,23 +292,6 @@ impl Device {
         self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
     }

-    pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
-        match self {
-            Device::Cpu => {
-                let storage = CpuDevice.ones_impl(shape, dtype)?;
-                Ok(Storage::Cpu(storage))
-            }
-            Device::Cuda(device) => {
-                let storage = device.ones_impl(shape, dtype)?;
-                Ok(Storage::Cuda(storage))
-            }
-            Device::Metal(device) => {
-                let storage = device.ones_impl(shape, dtype)?;
-                Ok(Storage::Metal(storage))
-            }
-        }
-    }
-
     pub(crate) fn zeros(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
         match self {
             Device::Cpu => {
@@ -107,6 +107,7 @@ pub trait WithDType:

     fn from_f64(v: f64) -> Self;
     fn to_f64(self) -> f64;
+    fn to_scalar(self) -> crate::scalar::Scalar;
     fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_>;
     fn to_cpu_storage_owned(data: Vec<Self>) -> CpuStorage;

@@ -131,6 +132,10 @@ macro_rules! with_dtype {
             $to_f64(self)
         }

+        fn to_scalar(self) -> crate::scalar::Scalar {
+            crate::scalar::Scalar::$dtype(self)
+        }
+
         fn cpu_storage_ref(data: &[Self]) -> CpuStorageRef<'_> {
             CpuStorageRef::$dtype(data)
         }
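to_scalar is the bridge between the generic WithDType world and the dtype-tagged Scalar: the with_dtype! macro emits one arm per supported type. A reduced standalone model of the hook (two types instead of seven):

#[derive(Debug, PartialEq)]
enum Scalar {
    U32(u32),
    F32(f32),
}

trait WithDType {
    // Lift a plain value into a dtype-tagged Scalar; generic front-end code
    // can then hand the backends a value with no generic parameter attached.
    fn to_scalar(self) -> Scalar;
}

impl WithDType for u32 {
    fn to_scalar(self) -> Scalar {
        Scalar::U32(self)
    }
}

impl WithDType for f32 {
    fn to_scalar(self) -> Scalar {
        Scalar::F32(self)
    }
}

fn main() {
    assert_eq!(42u32.to_scalar(), Scalar::U32(42));
}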
@@ -37,6 +37,10 @@ impl crate::backend::BackendStorage for CudaStorage {
         fail!()
     }

+    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -124,15 +128,27 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }

-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
         Err(Error::NotCompiledWithCudaSupport)
     }

+    fn scatter_add_set(
+        &mut self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithCudaSupport)
+    }
+
@@ -214,10 +230,6 @@ impl crate::backend::BackendDevice for CudaDevice {
         Err(Error::NotCompiledWithCudaSupport)
     }

-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithCudaSupport)
     }
@@ -41,6 +41,10 @@ impl crate::backend::BackendStorage for MetalStorage {
         fail!()
     }

+    fn const_set(&mut self, _: crate::scalar::Scalar, _: &Layout) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
     fn to_cpu_storage(&self) -> Result<CpuStorage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
@@ -128,15 +132,27 @@ impl crate::backend::BackendStorage for MetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }

-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: &Self,
         _: &Layout,
         _: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
         Err(Error::NotCompiledWithMetalSupport)
     }

+    fn scatter_add_set(
+        &mut self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: &Self,
+        _: &Layout,
+        _: usize,
+    ) -> Result<()> {
+        Err(Error::NotCompiledWithMetalSupport)
+    }
+
@@ -218,10 +234,6 @@ impl crate::backend::BackendDevice for MetalDevice {
         Err(Error::NotCompiledWithMetalSupport)
     }

-    fn ones_impl(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
     unsafe fn alloc_uninit(&self, _shape: &Shape, _dtype: DType) -> Result<Self::Storage> {
         Err(Error::NotCompiledWithMetalSupport)
     }
@@ -413,6 +413,100 @@ impl BackendStorage for MetalStorage {
         self.binary(name, rhs, lhs_l, rhs_l)
     }

+    fn const_set(&mut self, s: crate::scalar::Scalar, l: &Layout) -> Result<()> {
+        use crate::scalar::Scalar;
+        fn set<S: crate::WithDType + candle_metal_kernels::utils::EncoderParam>(
+            self_: &mut MetalStorage,
+            s: S,
+            l: &Layout,
+        ) -> Result<()> {
+            let device = self_.device();
+            let dtype = self_.dtype;
+            let shape = l.shape();
+            let el_count = shape.elem_count();
+            let command_buffer = device.command_buffer()?;
+            command_buffer.set_label("const-set");
+            let dst = buffer_o(&self_.buffer, l, self_.dtype);
+
+            match (el_count % 2, dtype, l.is_contiguous()) {
+                (0, DType::BF16 | DType::F16, true) => {
+                    use candle_metal_kernels::unary::contiguous_tiled;
+                    let kernel_name = match dtype {
+                        DType::F16 => contiguous_tiled::const_set::HALF,
+                        DType::BF16 => contiguous_tiled::const_set::BFLOAT,
+                        _ => crate::bail!("internal bug in const_set"),
+                    };
+                    candle_metal_kernels::call_const_set_contiguous_tiled(
+                        &device.device,
+                        &command_buffer,
+                        &device.kernels,
+                        kernel_name,
+                        el_count,
+                        s,
+                        dst,
+                    )
+                    .map_err(MetalError::from)?;
+                }
+                (_, _, true) => {
+                    use candle_metal_kernels::unary::contiguous;
+                    let kernel_name = match dtype {
+                        DType::F16 => contiguous::const_set::HALF,
+                        DType::BF16 => contiguous::const_set::BFLOAT,
+                        DType::F32 => contiguous::const_set::FLOAT,
+                        DType::I64 => contiguous::const_set::I64,
+                        DType::U32 => contiguous::const_set::U32,
+                        DType::U8 => contiguous::const_set::U8,
+                        DType::F64 => crate::bail!("unsupported const-set f64"),
+                    };
+                    candle_metal_kernels::call_const_set_contiguous(
+                        &device.device,
+                        &command_buffer,
+                        &device.kernels,
+                        kernel_name,
+                        el_count,
+                        s,
+                        dst,
+                    )
+                    .map_err(MetalError::from)?;
+                }
+                (_, _, false) => {
+                    use candle_metal_kernels::unary::strided;
+                    let kernel_name = match dtype {
+                        DType::F16 => strided::const_set::HALF,
+                        DType::BF16 => strided::const_set::BFLOAT,
+                        DType::F32 => strided::const_set::FLOAT,
+                        DType::I64 => strided::const_set::I64,
+                        DType::U32 => strided::const_set::U32,
+                        DType::U8 => strided::const_set::U8,
+                        DType::F64 => crate::bail!("unsupported const-set f64"),
+                    };
+                    candle_metal_kernels::call_const_set_strided(
+                        &device.device,
+                        &command_buffer,
+                        &device.kernels,
+                        kernel_name,
+                        l.dims(),
+                        s,
+                        l.stride(),
+                        dst,
+                    )
+                    .map_err(MetalError::from)?;
+                }
+            }
+            Ok(())
+        }
+        match (self.dtype, s) {
+            (DType::U8, Scalar::U8(s)) => set(self, s, l),
+            (DType::U32, Scalar::U32(s)) => set(self, s, l),
+            (DType::I64, Scalar::I64(s)) => set(self, s, l),
+            (DType::F16, Scalar::F16(s)) => set(self, s, l),
+            (DType::BF16, Scalar::BF16(s)) => set(self, s, l),
+            (DType::F32, Scalar::F32(s)) => set(self, s, l),
+            (DType::F64, Scalar::F64(s)) => set(self, s, l),
+            _ => crate::bail!("dtype mismatch, expected {:?}, got {:?}", self.dtype, s),
+        }
+    }
+
     fn to_dtype(&self, layout: &Layout, dtype: DType) -> Result<Self> {
         let device = self.device();
         let shape = layout.shape();
@@ -1332,18 +1426,65 @@ impl BackendStorage for MetalStorage {
         Ok(Self::new(buffer, device.clone(), dst_el, dtype))
     }

-    fn scatter_add(
-        &self,
+    fn scatter_set(
+        &mut self,
         l: &Layout,
         ids: &Self,
         ids_l: &Layout,
         src: &Self,
         src_l: &Layout,
         dim: usize,
-    ) -> Result<Self> {
-        let mut acc = self.device.zeros_impl(l.shape(), self.dtype())?;
-        self.copy_strided_src(&mut acc, 0, l)?;
-        if !ids_l.is_contiguous() || !src_l.is_contiguous() {
+    ) -> Result<()> {
+        if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
+            return Err(crate::Error::RequiresContiguous { op: "scatter" }.bt());
+        };
+        let name = match (ids.dtype, self.dtype) {
+            (DType::U8, DType::F32) => "s_u8_f32",
+            (DType::U8, DType::F16) => "s_u8_f16",
+            (DType::U8, DType::BF16) => "s_u8_bf16",
+            (DType::U32, DType::U32) => "s_u32_u32",
+            (DType::U32, DType::F32) => "s_u32_f32",
+            (DType::U32, DType::F16) => "s_u32_f16",
+            (DType::U32, DType::BF16) => "s_u32_bf16",
+            (DType::I64, DType::F32) => "s_i64_f32",
+            (DType::I64, DType::F16) => "s_i64_f16",
+            (DType::I64, DType::BF16) => "s_i64_bf16",
+            _ => Err(MetalError::UnexpectedDType {
+                msg: "scatter ids should be u8/u32/i64",
+                expected: DType::U32,
+                got: ids.dtype(),
+            })?,
+        };
+        let command_buffer = self.device.command_buffer()?;
+        let dst = buffer_o(&self.buffer, l, self.dtype);
+        let src = buffer_o(&src.buffer, src_l, src.dtype);
+        let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
+        candle_metal_kernels::call_scatter(
+            &self.device.device,
+            &command_buffer,
+            &self.device.kernels,
+            name,
+            src_l.dims(),
+            l.dims(),
+            dim,
+            src,
+            ids,
+            dst,
+        )
+        .map_err(MetalError::from)?;
+        Ok(())
+    }
+
+    fn scatter_add_set(
+        &mut self,
+        l: &Layout,
+        ids: &Self,
+        ids_l: &Layout,
+        src: &Self,
+        src_l: &Layout,
+        dim: usize,
+    ) -> Result<()> {
+        if !l.is_contiguous() || !ids_l.is_contiguous() || !src_l.is_contiguous() {
+            return Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt());
+        };
         let name = match (ids.dtype, self.dtype) {
@@ -1364,9 +1505,10 @@ impl BackendStorage for MetalStorage {
             })?,
         };
         let command_buffer = self.device.command_buffer()?;
+        let dst = buffer_o(&self.buffer, l, self.dtype);
         let src = buffer_o(&src.buffer, src_l, src.dtype);
         let ids = buffer_o(&ids.buffer, ids_l, ids.dtype);
-        candle_metal_kernels::call_scatter_add(
+        candle_metal_kernels::call_scatter(
             &self.device.device,
             &command_buffer,
             &self.device.kernels,
@@ -1376,10 +1518,10 @@ impl BackendStorage for MetalStorage {
             dim,
             src,
             ids,
-            &acc.buffer,
+            dst,
         )
         .map_err(MetalError::from)?;
-        Ok(acc)
+        Ok(())
     }

     fn index_select(&self, ids: &Self, src_l: &Layout, ids_l: &Layout, dim: usize) -> Result<Self> {
@@ -1965,40 +2107,6 @@ impl BackendDevice for MetalDevice {
         ))
     }

-    fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
-        let name = match dtype {
-            DType::U8 => "fill_u8",
-            DType::U32 => "fill_u32",
-            DType::I64 => "fill_i64",
-            DType::F16 => "fill_f16",
-            DType::BF16 => "fill_bf16",
-            DType::F32 => "fill_f32",
-            DType::F64 => {
-                let cpu_storage = crate::cpu_backend::CpuDevice.ones_impl(shape, dtype)?;
-                return self.storage_from_cpu_storage(&cpu_storage);
-            }
-        };
-        let buffer = self.new_buffer(shape.elem_count(), dtype, "alloc-ones")?;
-        let command_buffer = self.command_buffer()?;
-        candle_metal_kernels::call_const_fill(
-            &self.device,
-            &command_buffer,
-            &self.kernels,
-            name,
-            shape.elem_count(),
-            &buffer,
-            1.,
-        )
-        .map_err(MetalError::from)?;
-
-        Ok(MetalStorage::new(
-            buffer,
-            self.clone(),
-            shape.elem_count(),
-            dtype,
-        ))
-    }
-
     fn storage_from_slice<T: crate::WithDType>(&self, s: &[T]) -> Result<Self::Storage> {
         let (count, buffer) = match T::cpu_storage_ref(s) {
             CpuStorageRef::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
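The Metal const_set picks one of three kernel families: a tiled contiguous kernel that only applies to even element counts of the 16-bit float types, a plain contiguous kernel, and a strided fallback. The selection logic above, restated as a standalone decision function (a sketch of the dispatch, not the kernels themselves):

#[derive(Debug, PartialEq)]
enum Kernel {
    ContiguousTiled,
    Contiguous,
    Strided,
}

// Mirrors the match (el_count % 2, dtype, l.is_contiguous()) in the diff.
fn pick_kernel(el_count: usize, is_half: bool, contiguous: bool) -> Kernel {
    match (el_count % 2, is_half, contiguous) {
        (0, true, true) => Kernel::ContiguousTiled,
        (_, _, true) => Kernel::Contiguous,
        (_, _, false) => Kernel::Strided,
    }
}

fn main() {
    assert_eq!(pick_kernel(8, true, true), Kernel::ContiguousTiled);
    assert_eq!(pick_kernel(7, true, true), Kernel::Contiguous);
    assert_eq!(pick_kernel(8, false, false), Kernel::Strided);
}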
@@ -80,6 +80,7 @@ pub enum Op {
     Reduce(Tensor, ReduceOp, Vec<usize>),
     Matmul(Tensor, Tensor),
     Gather(Tensor, Tensor, usize),
+    Scatter(Tensor, Tensor, Tensor, usize),
     ScatterAdd(Tensor, Tensor, Tensor, usize),
     IndexSelect(Tensor, Tensor, usize),
     IndexAdd(Tensor, Tensor, Tensor, usize),
@@ -1,6 +1,74 @@
 //! TensorScalar Enum and Trait
 //!
-use crate::{Result, Tensor, WithDType};
+use crate::{DType, Result, Tensor, WithDType};
+use half::{bf16, f16};
+
+#[derive(Debug, Clone, Copy, PartialEq)]
+pub enum Scalar {
+    U8(u8),
+    U32(u32),
+    I64(i64),
+    BF16(bf16),
+    F16(f16),
+    F32(f32),
+    F64(f64),
+}
+
+impl<T: WithDType> From<T> for Scalar {
+    fn from(value: T) -> Self {
+        value.to_scalar()
+    }
+}
+
+impl Scalar {
+    pub fn zero(dtype: DType) -> Self {
+        match dtype {
+            DType::U8 => Scalar::U8(0),
+            DType::U32 => Scalar::U32(0),
+            DType::I64 => Scalar::I64(0),
+            DType::BF16 => Scalar::BF16(bf16::ZERO),
+            DType::F16 => Scalar::F16(f16::ZERO),
+            DType::F32 => Scalar::F32(0.0),
+            DType::F64 => Scalar::F64(0.0),
+        }
+    }
+
+    pub fn one(dtype: DType) -> Self {
+        match dtype {
+            DType::U8 => Scalar::U8(1),
+            DType::U32 => Scalar::U32(1),
+            DType::I64 => Scalar::I64(1),
+            DType::BF16 => Scalar::BF16(bf16::ONE),
+            DType::F16 => Scalar::F16(f16::ONE),
+            DType::F32 => Scalar::F32(1.0),
+            DType::F64 => Scalar::F64(1.0),
+        }
+    }
+
+    pub fn dtype(&self) -> DType {
+        match self {
+            Scalar::U8(_) => DType::U8,
+            Scalar::U32(_) => DType::U32,
+            Scalar::I64(_) => DType::I64,
+            Scalar::BF16(_) => DType::BF16,
+            Scalar::F16(_) => DType::F16,
+            Scalar::F32(_) => DType::F32,
+            Scalar::F64(_) => DType::F64,
+        }
+    }
+
+    pub fn to_f64(&self) -> f64 {
+        match self {
+            Scalar::U8(v) => *v as f64,
+            Scalar::U32(v) => *v as f64,
+            Scalar::I64(v) => *v as f64,
+            Scalar::BF16(v) => v.to_f64(),
+            Scalar::F16(v) => v.to_f64(),
+            Scalar::F32(v) => *v as f64,
+            Scalar::F64(v) => *v,
+        }
+    }
+}

 pub enum TensorScalar {
     Tensor(Tensor),
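With Scalar public, call sites can write tensor.const_set(42u32.into()) as the tests further down do: the blanket From<T: WithDType> impl routes through to_scalar. A usage sketch, assuming the candle 0.9 API shown in this diff:

use candle_core::{scalar::Scalar, DType, Device, Result, Tensor};

fn demo() -> Result<()> {
    // From<T: WithDType> lifts a plain value into a dtype-tagged Scalar.
    let s: Scalar = 42u32.into();
    assert_eq!(s.dtype(), DType::U32);
    assert_eq!(s.to_f64(), 42.0);

    // zero/one build the canonical constants for any dtype.
    let t = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    t.const_set(Scalar::one(DType::F32))?;
    Ok(())
}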
@@ -1,5 +1,6 @@
 use crate::backend::BackendStorage;
 use crate::op::{self, CmpOp, ReduceOp};
+use crate::scalar::Scalar;
 use crate::{CpuStorage, CudaStorage, DType, Device, Error, Layout, MetalStorage, Result, Shape};
 use crate::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};

@@ -73,6 +74,14 @@ impl Storage {
         }
     }

+    pub(crate) fn const_set(&mut self, v: Scalar, l: &Layout) -> Result<()> {
+        match self {
+            Storage::Cpu(storage) => storage.const_set(v, l),
+            Storage::Cuda(storage) => storage.const_set(v, l),
+            Storage::Metal(storage) => storage.const_set(v, l),
+        }
+    }
+
     pub(crate) fn affine(&self, layout: &Layout, mul: f64, add: f64) -> Result<Self> {
         match self {
             Storage::Cpu(storage) => {
@@ -619,32 +628,56 @@ impl Storage {
         }
     }

-    pub(crate) fn scatter_add(
-        &self,
+    pub(crate) fn scatter_set(
+        &mut self,
         l: &Layout,
         indexes: &Self,
         indexes_l: &Layout,
         source: &Self,
         source_l: &Layout,
         d: usize,
-    ) -> Result<Self> {
+    ) -> Result<()> {
+        self.same_device(indexes, "scatter-set")?;
+        self.same_device(source, "scatter-set")?;
+        match (self, indexes, source) {
+            (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
+                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
+            }
+            (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
+                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
+            }
+            (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
+                s.scatter_set(l, indexes, indexes_l, source, source_l, d)?;
+            }
+            _ => unreachable!(),
+        }
+        Ok(())
+    }
+
+    pub(crate) fn scatter_add(
+        &mut self,
+        l: &Layout,
+        indexes: &Self,
+        indexes_l: &Layout,
+        source: &Self,
+        source_l: &Layout,
+        d: usize,
+    ) -> Result<()> {
         self.same_device(indexes, "scatter-add")?;
         self.same_device(source, "scatter-add")?;
         match (self, indexes, source) {
             (Self::Cpu(s), Self::Cpu(indexes), Self::Cpu(source)) => {
-                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
-                Ok(Self::Cpu(storage))
+                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
             }
             (Self::Cuda(s), Self::Cuda(indexes), Self::Cuda(source)) => {
-                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
-                Ok(Self::Cuda(storage))
+                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
             }
             (Self::Metal(s), Self::Metal(indexes), Self::Metal(source)) => {
-                let storage = s.scatter_add(l, indexes, indexes_l, source, source_l, d)?;
-                Ok(Self::Metal(storage))
+                s.scatter_add_set(l, indexes, indexes_l, source, source_l, d)?;
             }
             _ => unreachable!(),
         }
+        Ok(())
     }

     pub(crate) fn index_add(
@@ -3,7 +3,7 @@
 use crate::backend::{BackendDevice, BackendStorage};
 use crate::op::{BackpropOp, BinaryOp, CmpOp, Op, ReduceOp, UnaryOp};
 use crate::scalar::TensorOrScalar;
-use crate::shape::{Dim, Dims};
+use crate::shape::{Dim, Dims, ShapeWithOneHole};
 use crate::{bail, storage::Storage, DType, Device, Error, Layout, Result, Shape};
 use std::sync::{Arc, RwLock};

@@ -185,7 +185,9 @@ impl Tensor {
     ) -> Result<Self> {
         let none = BackpropOp::none();
         let shape = shape.into();
-        let storage = device.ones(&shape, dtype)?;
+        let mut storage = unsafe { device.alloc_uninit(&shape, dtype)? };
+        let layout = Layout::contiguous(shape.clone());
+        storage.const_set(crate::scalar::Scalar::one(dtype), &layout)?;
         Ok(from_storage(storage, shape, none, is_variable))
     }
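ones_impl no longer needs per-backend support: it allocates uninitialized storage and immediately fills it through const_set over a contiguous layout, which is what lets every ones_impl above be deleted. The same two-step shape on plain memory, with MaybeUninit standing in for alloc_uninit (a sketch, not the candle code):

use std::mem::MaybeUninit;

fn ones(n: usize) -> Vec<f32> {
    let mut buf: Vec<MaybeUninit<f32>> = Vec::with_capacity(n);
    // alloc_uninit: reserve space without initializing it.
    unsafe { buf.set_len(n) };
    for slot in buf.iter_mut() {
        slot.write(1.0); // const_set(Scalar::one(dtype)) over the full layout
    }
    // SAFETY: every element was written by the loop above.
    buf.into_iter().map(|s| unsafe { s.assume_init() }).collect()
}

fn main() {
    assert_eq!(ones(3), [1.0, 1.0, 1.0]);
}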
@@ -202,6 +204,18 @@ impl Tensor {
         Self::ones_impl(shape, dtype, device, false)
     }

+    pub fn const_set(&self, value: crate::scalar::Scalar) -> Result<()> {
+        self.storage_mut().const_set(value, self.layout())
+    }
+
+    pub fn zero_set(&self) -> Result<()> {
+        self.const_set(crate::scalar::Scalar::zero(self.dtype()))
+    }
+
+    pub fn one_set(&self) -> Result<()> {
+        self.const_set(crate::scalar::Scalar::one(self.dtype()))
+    }
+
     /// Creates a new tensor filled with ones with same shape, dtype, and device as the other tensor.
     ///
     /// ```rust
@@ -452,17 +466,13 @@ impl Tensor {
         Self::from_vec_impl(data, len, device, false)
     }

-    pub(crate) fn from_vec_impl<S: Into<Shape>, D: crate::WithDType>(
+    pub(crate) fn from_vec_impl<S: ShapeWithOneHole, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
         is_variable: bool,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let buffer_size = data.len();
-        if buffer_size != shape.elem_count() {
-            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
-        }
+        let shape = shape.into_shape(data.len())?;
         let storage = device.storage_owned(data)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, shape, none, is_variable))
@@ -481,7 +491,7 @@ impl Tensor {
     ///     ]);
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn from_vec<S: Into<Shape>, D: crate::WithDType>(
+    pub fn from_vec<S: ShapeWithOneHole, D: crate::WithDType>(
         data: Vec<D>,
         shape: S,
         device: &Device,
@@ -502,17 +512,12 @@ impl Tensor {
     ///     ]);
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn from_slice<S: Into<Shape>, D: crate::WithDType>(
+    pub fn from_slice<S: ShapeWithOneHole, D: crate::WithDType>(
         array: &[D],
         shape: S,
         device: &Device,
     ) -> Result<Self> {
-        let shape = shape.into();
-        let n: usize = shape.elem_count();
-        let buffer_size: usize = array.len();
-        if buffer_size != n {
-            return Err(Error::ShapeMismatch { buffer_size, shape }.bt());
-        }
+        let shape = shape.into_shape(array.len())?;
         let storage = device.storage_from_slice(array)?;
         let none = BackpropOp::none();
         Ok(from_storage(storage, shape, none, false))
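from_vec, from_vec_impl, and from_slice now take ShapeWithOneHole rather than Into<Shape>, so one dimension may be left as a hole and inferred from the element count, as reshape already allowed; into_shape also subsumes the old ShapeMismatch check. A usage sketch, assuming the candle 0.9 API shown in this diff:

use candle_core::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let data: Vec<f32> = (0..12).map(|v| v as f32).collect();
    // The () hole is inferred from the element count: 12 / 4 = 3.
    let t = Tensor::from_vec(data, (4, ()), &Device::Cpu)?;
    assert_eq!(t.dims(), &[4, 3]);
    Ok(())
}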
@@ -1349,8 +1354,7 @@ impl Tensor {
         self.index_select(ids, 0)
     }

-    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
-        let dim = dim.to_index(self.shape(), "scatter-add")?;
+    fn scatter_checks(&self, indexes: &Self, source: &Self, dim: usize) -> Result<()> {
         let source_dims = source.dims();
         let self_dims = self.dims();
         let mismatch = if source_dims.len() != self_dims.len() {
@@ -1367,7 +1371,7 @@ impl Tensor {
         };
         if mismatch {
             Err(Error::ShapeMismatchBinaryOp {
-                op: "scatter-add (self, src)",
+                op: "scatter (self, src)",
                 lhs: self.shape().clone(),
                 rhs: source.shape().clone(),
             }
@@ -1375,13 +1379,44 @@ impl Tensor {
         }
         if indexes.dims() != source.dims() {
             Err(Error::ShapeMismatchBinaryOp {
-                op: "scatter-add (indexes, src)",
+                op: "scatter (indexes, src)",
                 lhs: indexes.shape().clone(),
                 rhs: source.shape().clone(),
             }
             .bt())?
         }
-        let storage = self.storage().scatter_add(
+        Ok(())
+    }
+
+    pub fn scatter<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "scatter")?;
+        self.scatter_checks(indexes, source, dim)?;
+        let shape = self.shape();
+        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let layout = Layout::contiguous(shape);
+        storage.scatter_set(
+            &layout,
+            &indexes.storage(),
+            indexes.layout(),
+            &source.storage(),
+            source.layout(),
+            dim,
+        )?;
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::Scatter(t1, t2, t3, dim)
+        });
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
+    pub fn scatter_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
+        if self.same_storage(source) {
+            crate::bail!("cannot use slice_set when self and src share their storage")
+        }
+        let dim = dim.to_index(self.shape(), "scatter-set")?;
+        self.scatter_checks(indexes, source, dim)?;
+        self.storage_mut().scatter_set(
+            self.layout(),
             &indexes.storage(),
             indexes.layout(),
@@ -1389,12 +1424,48 @@ impl Tensor {
             source.layout(),
             dim,
         )?;
-        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
-            Op::ScatterAdd(t1, t2, t3, dim)
-        });
-        Ok(from_storage(storage, self.shape(), op, false))
+        Ok(())
+    }
+
+    pub fn scatter_add<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<Self> {
+        let dim = dim.to_index(self.shape(), "scatter-add")?;
+        self.scatter_checks(indexes, source, dim)?;
+        let shape = self.shape();
+        let mut storage = unsafe { self.device().alloc_uninit(shape, self.dtype())? };
+        self.storage()
+            .copy_strided_src(&mut storage, 0, self.layout())?;
+        let layout = Layout::contiguous(shape);
+        storage.scatter_add(
+            &layout,
+            &indexes.storage(),
+            indexes.layout(),
+            &source.storage(),
+            source.layout(),
+            dim,
+        )?;
+        let op = BackpropOp::new3(self, indexes, source, |t1, t2, t3| {
+            Op::ScatterAdd(t1, t2, t3, dim)
+        });
+        Ok(from_storage(storage, self.shape(), op, false))
+    }
+
+    pub fn scatter_add_set<D: Dim>(&self, indexes: &Self, source: &Self, dim: D) -> Result<()> {
+        if self.same_storage(source) {
+            crate::bail!("cannot use slice_set when self and src share their storage")
+        }
+        let dim = dim.to_index(self.shape(), "scatter-add-set")?;
+        self.scatter_checks(indexes, source, dim)?;
+        self.storage_mut().scatter_add(
+            self.layout(),
+            &indexes.storage(),
+            indexes.layout(),
+            &source.storage(),
+            source.layout(),
+            dim,
+        )?;
+        Ok(())
     }

     /// Embeds the values of the `src` tensor into the `self` tensor on the specified dimension.
     pub fn slice_scatter<D: Dim>(&self, src: &Self, dim: D, start: usize) -> Result<Self> {
         let dim = dim.to_index(self.shape(), "slice-scatter")?;
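After this change the scatter family has four entry points: scatter and scatter_add allocate a result and participate in backprop via Op::Scatter / Op::ScatterAdd, while scatter_set and scatter_add_set mutate self in place (and refuse shared storage). A usage sketch with made-up index values, assuming the API shown in this diff:

use candle_core::{DType, Device, Result, Tensor};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    let init = Tensor::ones((4, 5), DType::F32, &dev)?;
    let src = Tensor::arange(0f32, 12f32, &dev)?.reshape((4, 3))?;
    // Index values are illustrative; each must be < 5 along dim 1.
    let ids = Tensor::new(&[[0u32, 1, 2], [3, 4, 0], [3, 3, 1], [2, 0, 4]], &dev)?;

    let set = init.scatter(&ids, &src, 1)?; // overwrite at ids, new tensor
    let added = init.scatter_add(&ids, &src, 1)?; // accumulate at ids, new tensor
    assert_eq!(set.dims(), added.dims());

    init.scatter_set(&ids, &src, 1)?; // in-place, no autograd op recorded
    Ok(())
}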
@@ -2197,7 +2268,7 @@ impl Tensor {
     ///
     /// # Ok::<(), candle_core::Error>(())
     /// ```
-    pub fn reshape<S: crate::shape::ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
+    pub fn reshape<S: ShapeWithOneHole>(&self, s: S) -> Result<Tensor> {
         let shape = s.into_shape(self.elem_count())?;
         if shape.elem_count() != self.elem_count() {
             return Err(Error::ShapeMismatchBinaryOp {
@ -241,7 +241,7 @@ impl Tensor {
    /// `self` and `src` must have the same shape except on dimension `dim` where the `self` size
    /// has to be greater than or equal to `offset` plus the `src` size.
    ///
    /// Note that this modifies `self` in place and as such is not compatibel with
    /// Note that this modifies `self` in place and as such is not compatible with
    /// back-propagation.
    pub fn slice_set<D: Dim>(&self, src: &Self, dim: D, offset: usize) -> Result<()> {
        let dim = dim.to_index(self.shape(), "slice-set")?;
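For reference, a small sketch of the `slice_set` contract described in the doc comment above (hypothetical shapes, CPU device):

```rust
use candle_core::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let dst = Tensor::zeros((3, 4), DType::F32, &dev)?;
    let src = Tensor::ones((1, 4), DType::F32, &dev)?;
    // Shapes must agree on every dimension but `dim`; offset + src size <= dst size.
    dst.slice_set(&src, 0, 1)?;
    assert_eq!(dst.to_vec2::<f32>()?[1], vec![1.0; 4]);
    Ok(())
}
```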
@ -25,10 +25,12 @@ fn ones(device: &Device) -> Result<()> {
        Tensor::ones((2, 3), DType::F32, device)?.to_vec2::<f32>()?,
        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
    );
    assert_eq!(
        Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
        [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
    );
    if !device.is_metal() {
        assert_eq!(
            Tensor::ones((2, 3), DType::F64, device)?.to_vec2::<f64>()?,
            [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0]],
        );
    }
    assert_eq!(
        Tensor::ones((2, 3), DType::F16, device)?.to_vec2::<half::f16>()?,
        [
@ -63,6 +65,26 @@ fn ones(device: &Device) -> Result<()> {
}

fn full(device: &Device) -> Result<()> {
    let tensor = Tensor::zeros((3, 4), DType::U32, device)?;
    tensor.const_set(42u32.into())?;
    assert_eq!(
        tensor.to_vec2::<u32>()?,
        [[42, 42, 42, 42], [42, 42, 42, 42], [42, 42, 42, 42]]
    );
    tensor.i((.., 2))?.const_set(1337u32.into())?;
    assert_eq!(
        tensor.to_vec2::<u32>()?,
        [[42, 42, 1337, 42], [42, 42, 1337, 42], [42, 42, 1337, 42]]
    );
    tensor.i((2, ..))?.const_set(1u32.into())?;
    assert_eq!(
        tensor.to_vec2::<u32>()?,
        [[42, 42, 1337, 42], [42, 42, 1337, 42], [1, 1, 1, 1]]
    );
    Ok(())
}

fn const_set(device: &Device) -> Result<()> {
    assert_eq!(
        Tensor::full(42u32, (2, 3), device)?.to_vec2::<u32>()?,
        [[42, 42, 42], [42, 42, 42]],
@ -826,6 +848,31 @@ fn embeddings(device: &Device) -> Result<()> {
    Ok(())
}

#[test]
fn index_select_fail() -> Result<()> {
    // Check that an error is properly reported on out of bounds.
    let ids = Tensor::new(&[4u32, 2u32, 1u32], &Device::Cpu)?;
    let t = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &Device::Cpu)?;
    let hs = t.index_select(&ids, 0);
    assert!(hs.is_err());
    Ok(())
}

// The test below triggers an unwinding panic as there is a panic within the
// #[cfg(feature = "cuda")]
// #[test]
// #[should_panic]
// fn index_select_fail_gpu() {
//     // Check that a panic happens for out of bounds in cuda
//     if let Ok(device) = Device::new_cuda(0) {
//         if let Ok(ids) = Tensor::new(&[4u32, 2u32, 1u32], &device) {
//             if let Ok(t) = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], &device) {
//                 let _ = t.index_select(&ids, 0);
//             }
//         }
//     }
// }

fn cmp(device: &Device) -> Result<()> {
    let t1 = Tensor::new(&[[0f32, 1f32], [2f32, 3f32], [4f32, 5f32]], device)?;
    let t2 = Tensor::new(&[[1f32, 0f32], [3f32, 3f32], [4f32, 7f32]], device)?;
@ -980,7 +1027,7 @@ fn slice_scatter(device: &Device) -> Result<()> {
    Ok(())
}

fn scatter_add(device: &Device) -> Result<()> {
fn scatter(device: &Device) -> Result<()> {
    let t = Tensor::arange(0f32, 12f32, device)?.reshape((4, 3))?;
    assert_eq!(
        t.to_vec2::<f32>()?,
@ -1004,6 +1051,17 @@ fn scatter_add(device: &Device) -> Result<()> {
        ]
    );

    let hs = init.scatter(&ids, &t, 1)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [0.0, 1.0, 2.0, 1.0, 1.0],
            [5.0, 1.0, 1.0, 3.0, 4.0],
            [1.0, 8.0, 1.0, 7.0, 1.0],
            [10.0, 1.0, 9.0, 1.0, 11.0]
        ]
    );

    let init = Tensor::ones((6, 3), DType::F32, device)?;
    let hs = init.scatter_add(&ids, &t, 0)?;
    assert_eq!(
@ -1017,6 +1075,30 @@ fn scatter_add(device: &Device) -> Result<()> {
            [1.0, 1.0, 1.0]
        ]
    );
    let hs = init.scatter(&ids, &t, 0)?;
    assert_eq!(
        hs.to_vec2::<f32>()?,
        &[
            [0.0, 10.0, 5.0],
            [1.0, 1.0, 8.0],
            [9.0, 1.0, 2.0],
            [6.0, 7.0, 1.0],
            [1.0, 4.0, 11.0],
            [1.0, 1.0, 1.0]
        ]
    );
    init.scatter_set(&ids, &t, 0)?;
    assert_eq!(
        init.to_vec2::<f32>()?,
        &[
            [0.0, 10.0, 5.0],
            [1.0, 1.0, 8.0],
            [9.0, 1.0, 2.0],
            [6.0, 7.0, 1.0],
            [1.0, 4.0, 11.0],
            [1.0, 1.0, 1.0]
        ]
    );
    Ok(())
}

@ -1484,6 +1566,7 @@ fn zero_dim(device: &Device) -> Result<()> {
test_device!(zeros, zeros_cpu, zeros_gpu, zeros_metal);
test_device!(ones, ones_cpu, ones_gpu, ones_metal);
test_device!(full, full_cpu, full_gpu, full_metal);
test_device!(const_set, cs_cpu, cs_gpu, cs_metal);
test_device!(arange, arange_cpu, arange_gpu, arange_metal);
test_device!(add_mul, add_mul_cpu, add_mul_gpu, add_mul_metal);
test_device!(tensor_2d, tensor_2d_cpu, tensor_2d_gpu, tensor_2d_metal);
@ -1515,12 +1598,7 @@ test_device!(
);
test_device!(index_add, index_add_cpu, index_add_gpu, index_add_metal);
test_device!(gather, gather_cpu, gather_gpu, gather_metal);
test_device!(
    scatter_add,
    scatter_add_cpu,
    scatter_add_gpu,
    scatter_add_metal
);
test_device!(scatter, scatter_cpu, scatter_gpu, scatter_metal);
test_device!(
    slice_scatter,
    slice_scatter_cpu,
@ -124,6 +124,17 @@ impl TextGeneration {
            Some(token) => token,
            None => anyhow::bail!("cannot find the <eos> token"),
        };

        let eot_token = match self.tokenizer.get_token("<end_of_turn>") {
            Some(token) => token,
            None => {
                println!(
                    "Warning: <end_of_turn> token not found in tokenizer, using <eos> as a backup"
                );
                eos_token
            }
        };

        let start_gen = std::time::Instant::now();
        for index in 0..sample_len {
            let context_size = if index > 0 { 1 } else { tokens.len() };
@ -146,7 +157,7 @@ impl TextGeneration {
            let next_token = self.logits_processor.sample(&logits)?;
            tokens.push(next_token);
            generated_tokens += 1;
            if next_token == eos_token {
            if next_token == eos_token || next_token == eot_token {
                break;
            }
            if let Some(t) = self.tokenizer.next_token(next_token)? {
@ -350,6 +361,31 @@ fn main() -> Result<()> {
        args.repeat_last_n,
        &device,
    );
    pipeline.run(&args.prompt, args.sample_len)?;

    let prompt = match args.which {
        Which::Base2B
        | Which::Base7B
        | Which::Instruct2B
        | Which::Instruct7B
        | Which::InstructV1_1_2B
        | Which::InstructV1_1_7B
        | Which::CodeBase2B
        | Which::CodeBase7B
        | Which::CodeInstruct2B
        | Which::CodeInstruct7B
        | Which::BaseV2_2B
        | Which::InstructV2_2B
        | Which::BaseV2_9B
        | Which::InstructV2_9B
        | Which::BaseV3_1B => args.prompt,
        Which::InstructV3_1B => {
            format!(
                "<start_of_turn> user\n{}<end_of_turn>\n<start_of_turn> model\n",
                args.prompt
            )
        }
    };

    pipeline.run(&prompt, args.sample_len)?;
    Ok(())
}

18
candle-examples/examples/quantized-gemma/README.md
Normal file
@ -0,0 +1,18 @@
# candle-quantized-gemma

Candle implementation of quantized Gemma.

## Running an example

```bash
$ cargo run --example quantized-gemma -- --prompt "Write a function to calculate fibonacci numbers. "

> ```python
> def fibonacci(n):
>     """Calculates the nth Fibonacci number using recursion."""
>     if n <= 1:
>         return n
>     else:
>         return fibonacci(n-1) + fibonacci(n-2
> ```
```

344
candle-examples/examples/quantized-gemma/main.rs
Normal file
@ -0,0 +1,344 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use clap::{Parser, ValueEnum};
use std::io::Write;
use tokenizers::Tokenizer;

use candle::quantized::gguf_file;
use candle::Tensor;
use candle_transformers::generation::{LogitsProcessor, Sampling};

use candle_examples::token_output_stream::TokenOutputStream;
use candle_transformers::models::quantized_gemma3::ModelWeights;

const DEFAULT_PROMPT: &str = "Write a function to calculate fibonacci num";

#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
enum Which {
    #[value(name = "gemma3-4b-it")]
    Gemma3_4bIt,
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
    /// GGUF file to load, typically a .gguf file generated by quantization
    #[arg(long)]
    model: Option<String>,

    /// The initial prompt, use 'interactive' for entering multiple prompts in an interactive way
    /// and 'chat' for an interactive mode where history of previous prompts and generated tokens
    /// is preserved.
    #[arg(long)]
    prompt: Option<String>,

    /// The length of the sample to generate (in tokens).
    #[arg(short = 'n', long, default_value_t = 1000)]
    sample_len: usize,

    /// The tokenizer config in json format.
    #[arg(long)]
    tokenizer: Option<String>,

    /// The temperature used to generate samples, use 0 for greedy sampling.
    #[arg(long, default_value_t = 0.8)]
    temperature: f64,

    /// Nucleus sampling probability cutoff.
    #[arg(long)]
    top_p: Option<f64>,

    /// Only sample among the top K samples.
    #[arg(long)]
    top_k: Option<usize>,

    /// The seed to use when generating random samples.
    #[arg(long, default_value_t = 299792458)]
    seed: u64,

    /// Enable tracing (generates a trace-timestamp.json file).
    #[arg(long)]
    tracing: bool,

    /// Process prompt elements separately.
    #[arg(long)]
    split_prompt: bool,

    /// Run on CPU rather than GPU even if a GPU is available.
    #[arg(long)]
    cpu: bool,

    /// Penalty to be applied for repeating tokens, 1. means no penalty.
    #[arg(long, default_value_t = 1.1)]
    repeat_penalty: f32,

    /// The context size to consider for the repeat penalty.
    #[arg(long, default_value_t = 64)]
    repeat_last_n: usize,

    /// The model size to use.
    #[arg(long, default_value = "gemma3-4b-it")]
    which: Which,
}

impl Args {
    fn tokenizer(&self) -> anyhow::Result<Tokenizer> {
        let tokenizer_path = match &self.tokenizer {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let api = hf_hub::api::sync::Api::new()?;
                let repo = "google/gemma-3-4b-it";
                println!("DEBUG: Downloading tokenizer from {}", repo);
                let api = api.model(repo.to_string());
                api.get("tokenizer.json")?
            }
        };
        println!("DEBUG: Loading tokenizer from {:?}", tokenizer_path);
        let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(anyhow::Error::msg)?;

        Ok(tokenizer)
    }

    fn model(&self) -> anyhow::Result<std::path::PathBuf> {
        let model_path = match &self.model {
            Some(config) => std::path::PathBuf::from(config),
            None => {
                let (repo, filename) = match self.which {
                    Which::Gemma3_4bIt => (
                        "google/gemma-3-4b-it-qat-q4_0-gguf",
                        "gemma-3-4b-it-q4_0.gguf",
                    ),
                };
                let api = hf_hub::api::sync::Api::new()?;
                api.repo(hf_hub::Repo::with_revision(
                    repo.to_string(),
                    hf_hub::RepoType::Model,
                    "main".to_string(),
                ))
                .get(filename)?
            }
        };
        Ok(model_path)
    }
}

fn format_size(size_in_bytes: usize) -> String {
    if size_in_bytes < 1_000 {
        format!("{}B", size_in_bytes)
    } else if size_in_bytes < 1_000_000 {
        format!("{:.2}KB", size_in_bytes as f64 / 1e3)
    } else if size_in_bytes < 1_000_000_000 {
        format!("{:.2}MB", size_in_bytes as f64 / 1e6)
    } else {
        format!("{:.2}GB", size_in_bytes as f64 / 1e9)
    }
}

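A quick sanity check of the helper's decimal (SI) thresholds, assuming `format_size` above is in scope:

```rust
fn main() {
    assert_eq!(format_size(999), "999B");
    assert_eq!(format_size(2_500_000), "2.50MB");
    assert_eq!(format_size(3_000_000_000), "3.00GB");
}
```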
#[derive(Debug)]
enum Prompt {
    Interactive,
    Chat,
    One(String),
}

fn main() -> anyhow::Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;

    let args = Args::parse();
    let _guard = if args.tracing {
        let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
        tracing_subscriber::registry().with(chrome_layer).init();
        Some(guard)
    } else {
        None
    };

    println!(
        "avx: {}, neon: {}, simd128: {}, f16c: {}",
        candle::utils::with_avx(),
        candle::utils::with_neon(),
        candle::utils::with_simd128(),
        candle::utils::with_f16c()
    );
    println!(
        "temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
        args.temperature, args.repeat_penalty, args.repeat_last_n
    );

    let model_path = args.model()?;
    let mut file = std::fs::File::open(&model_path)?;
    let start = std::time::Instant::now();
    let device = candle_examples::device(args.cpu)?;

    let mut model = {
        let model = gguf_file::Content::read(&mut file).map_err(|e| e.with_path(&model_path))?;
        let mut total_size_in_bytes = 0;
        for (_, tensor) in model.tensor_infos.iter() {
            let elem_count = tensor.shape.elem_count();
            total_size_in_bytes +=
                elem_count * tensor.ggml_dtype.type_size() / tensor.ggml_dtype.block_size();
        }
        println!(
            "loaded {:?} tensors ({}) in {:.2}s",
            model.tensor_infos.len(),
            &format_size(total_size_in_bytes),
            start.elapsed().as_secs_f32(),
        );
        ModelWeights::from_gguf(model, &mut file, &device)?
    };
    println!("model built");

    let tokenizer = args.tokenizer()?;

    let mut tos = TokenOutputStream::new(tokenizer);
    println!(
        "DEBUG: Tokenizer vocabulary size: {}",
        tos.tokenizer().get_vocab(true).len()
    );

    let prompt = match args.prompt.as_deref() {
        Some("chat") => Prompt::Chat,
        Some("interactive") => Prompt::Interactive,
        Some(s) => Prompt::One(s.to_string()),
        None => Prompt::One(DEFAULT_PROMPT.to_string()),
    };

    let mut pre_prompt_tokens = vec![];
    for _ in 0.. {
        let prompt_str = match &prompt {
            Prompt::One(prompt) => prompt.clone(),
            Prompt::Interactive | Prompt::Chat => {
                print!("> ");
                std::io::stdout().flush()?;
                let mut prompt = String::new();
                std::io::stdin().read_line(&mut prompt)?;
                if prompt.ends_with('\n') {
                    prompt.pop();
                    if prompt.ends_with('\r') {
                        prompt.pop();
                    }
                }
                // Format for Gemma 3 chat/instruction format
                format!("<start_of_turn> user\n{prompt}<end_of_turn>\n<start_of_turn> model\n")
            }
        };
        print!("{}", &prompt_str);

        let tokens = tos
            .tokenizer()
            .encode(prompt_str, true)
            .map_err(anyhow::Error::msg)?;
        let prompt_tokens = [&pre_prompt_tokens, tokens.get_ids()].concat();

        let to_sample = args.sample_len.saturating_sub(1);
        let max_seq_len = 8192; // Gemma 3 context length
        let prompt_tokens = if prompt_tokens.len() + to_sample > max_seq_len - 10 {
            let to_remove = prompt_tokens.len() + to_sample + 10 - max_seq_len;
            prompt_tokens[prompt_tokens.len().saturating_sub(to_remove)..].to_vec()
        } else {
            prompt_tokens
        };
        let mut all_tokens = vec![];
        let mut logits_processor = {
            let temperature = args.temperature;
            let sampling = if temperature <= 0. {
                Sampling::ArgMax
            } else {
                match (args.top_k, args.top_p) {
                    (None, None) => Sampling::All { temperature },
                    (Some(k), None) => Sampling::TopK { k, temperature },
                    (None, Some(p)) => Sampling::TopP { p, temperature },
                    (Some(k), Some(p)) => Sampling::TopKThenTopP { k, p, temperature },
                }
            };
            LogitsProcessor::from_sampling(args.seed, sampling)
        };

        let start_prompt_processing = std::time::Instant::now();
        let mut next_token = if !args.split_prompt {
            let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, 0)?;
            let logits = logits.squeeze(0)?;
            logits_processor.sample(&logits)?
        } else {
            let mut next_token = 0;
            for (pos, token) in prompt_tokens.iter().enumerate() {
                let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
                let logits = model.forward(&input, pos)?;
                let logits = logits.squeeze(0)?;
                next_token = logits_processor.sample(&logits)?
            }
            next_token
        };
        let prompt_dt = start_prompt_processing.elapsed();
        all_tokens.push(next_token);
        if let Some(t) = tos.next_token(next_token)? {
            print!("{t}");
            std::io::stdout().flush()?;
        }

        // For Gemma 3, use the correct end of sequence token
        let eos_token = *tos
            .tokenizer()
            .get_vocab(true)
            .get("<end_of_turn>")
            .unwrap();

        let start_post_prompt = std::time::Instant::now();
        let mut sampled = 0;
        for index in 0..to_sample {
            let input = Tensor::new(&[next_token], &device)?.unsqueeze(0)?;
            let logits = model.forward(&input, prompt_tokens.len() + index)?;
            let logits = logits.squeeze(0)?;
            let logits = if args.repeat_penalty == 1. {
                logits
            } else {
                let start_at = all_tokens.len().saturating_sub(args.repeat_last_n);
                candle_transformers::utils::apply_repeat_penalty(
                    &logits,
                    args.repeat_penalty,
                    &all_tokens[start_at..],
                )?
            };
            next_token = logits_processor.sample(&logits)?;
            all_tokens.push(next_token);
            if let Some(t) = tos.next_token(next_token)? {
                print!("{t}");
                std::io::stdout().flush()?;
            }
            sampled += 1;
            if next_token == eos_token {
                break;
            };
        }
        if let Some(rest) = tos.decode_rest().map_err(candle::Error::msg)? {
            print!("{rest}");
        }
        std::io::stdout().flush()?;
        let dt = start_post_prompt.elapsed();
        println!(
            "\n\n{:4} prompt tokens processed: {:.2} token/s",
            prompt_tokens.len(),
            prompt_tokens.len() as f64 / prompt_dt.as_secs_f64(),
        );
        println!(
            "{sampled:4} tokens generated: {:.2} token/s",
            sampled as f64 / dt.as_secs_f64(),
        );

        match prompt {
            Prompt::One(_) => break,
            Prompt::Interactive => {}
            Prompt::Chat => {
                pre_prompt_tokens = [prompt_tokens.as_slice(), all_tokens.as_slice()].concat()
            }
        }
    }

    Ok(())
}

@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.9.0-alpha.4"
version = "0.9.0"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.9.0" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]

@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.9.0-alpha.4"
version = "0.9.0"
edition = "2021"

description = "CUDA kernels for Candle"

@ -1,5 +1,6 @@
#include<stdint.h>
#include "cuda_fp16.h"
#include "cuda_utils.cuh"

template<typename T>
__device__ void fill_with(T *buf, T value, const size_t numel) {
@ -36,13 +37,45 @@ COPY2D_OP(uint8_t, copy2d_u8)
COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#define CONST_SET_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t numel, \
    const size_t num_dims, \
    const size_t *info, \
    const TYPENAME inp, \
    TYPENAME *out \
) { \
    const size_t *dims = info; \
    const size_t *strides = info + num_dims; \
    if (info == nullptr || is_contiguous(num_dims, dims, strides)) { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            out[i] = inp; \
        } \
    } \
    else { \
        for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) { \
            unsigned strided_i = get_strided_index(i, num_dims, dims, strides); \
            out[strided_i] = inp; \
        } \
    } \
} \

CONST_SET_OP(float, const_set_f32)
CONST_SET_OP(double, const_set_f64)
CONST_SET_OP(uint8_t, const_set_u8)
CONST_SET_OP(uint32_t, const_set_u32)
CONST_SET_OP(int64_t, const_set_i64)


#if __CUDA_ARCH__ >= 530
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
CONST_SET_OP(__half, const_set_f16)
#endif

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
CONST_SET_OP(__nv_bfloat16, const_set_bf16)
#endif

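The strided branch above leans on `get_strided_index` from `cuda_utils.cuh`. A CPU sketch of that index mapping (my own Rust mirror of its expected semantics, not part of the diff):

```rust
/// Map a linear element index to a storage offset given dims and strides,
/// mirroring the CUDA helper the const_set kernels call.
fn get_strided_index(mut i: usize, dims: &[usize], strides: &[usize]) -> usize {
    let mut offset = 0;
    for (&d, &s) in dims.iter().zip(strides.iter()).rev() {
        offset += (i % d) * s;
        i /= d;
    }
    offset
}

fn main() {
    // A 2x3 tensor stored column-major: logical element 1 (row 0, col 1)
    // lives at storage offset 2.
    assert_eq!(get_strided_index(1, &[2, 3], &[1, 2]), 2);
}
```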
@ -23,6 +23,7 @@ __device__ void index_select(
    unsigned int left_i = dst_i / (ids_dim_size * right_size);
    unsigned int id_i = dst_i / right_size % ids_dim_size;
    unsigned int right_i = dst_i % right_size;
    assert(ids[id_i] < src_dim_size);
    unsigned int src_i = left_i * (src_dim_size * right_size) + ids[id_i] * right_size + right_i;
    unsigned strided_i = b ? src_i : get_strided_index(src_i, num_dims, dims, strides);
    out[dst_i] = inp[strided_i];
@ -57,6 +58,7 @@ __device__ void gather(
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        size_t post = i % right_size;
        size_t idx = ids[i];
        assert(idx < src_dim_size);
        size_t pre = i / (right_size * ids_dim_size);
        size_t src_i = (pre * src_dim_size + idx) * right_size + post;
        out[i] = inp[src_i];
@ -92,6 +94,7 @@ __device__ void index_add(
        const size_t post = i % right_size;
        for (unsigned int j = 0; j < ids_dim_size; ++j) {
            const size_t idx = ids[j];
            assert(idx < dst_dim_size);
            const size_t src_i = (pre * ids_dim_size + j) * right_size + post;
            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];
@ -111,6 +114,30 @@ extern "C" __global__ void FN_NAME( \
    const size_t right_size \
) { index_add(ids, ids_dim_size, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \

template<typename T, typename I>
__device__ void scatter(
    const I *ids,
    const T *inp,
    T *out,
    const size_t left_size,
    const size_t src_dim_size,
    const size_t dst_dim_size,
    const size_t right_size
) {
    const size_t numel = left_size * right_size;
    for (unsigned int i = blockIdx.x * blockDim.x + threadIdx.x; i < numel; i += blockDim.x * gridDim.x) {
        const size_t pre = i / right_size;
        const size_t post = i % right_size;
        for (unsigned int j = 0; j < src_dim_size; ++j) {
            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
            const size_t idx = ids[src_i];
            assert(idx < dst_dim_size);
            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] = inp[src_i];
        }
    }
}

template<typename T, typename I>
__device__ void scatter_add(
    const I *ids,
@ -128,12 +155,24 @@ __device__ void scatter_add(
        for (unsigned int j = 0; j < src_dim_size; ++j) {
            const size_t src_i = (pre * src_dim_size + j) * right_size + post;
            const size_t idx = ids[src_i];
            assert(idx < dst_dim_size);
            const size_t dst_i = (pre * dst_dim_size + idx) * right_size + post;
            out[dst_i] += inp[src_i];
        }
    }
}

#define S_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const INDEX_TYPENAME *ids, \
    const TYPENAME *inp, \
    TYPENAME *out, \
    const size_t left_size, \
    const size_t src_dim_size, \
    const size_t dst_dim_size, \
    const size_t right_size \
) { scatter(ids, inp, out, left_size, src_dim_size, dst_dim_size, right_size); } \

#define SA_OP(TYPENAME, INDEX_TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const INDEX_TYPENAME *ids, \
@ -159,6 +198,9 @@ IA_OP(__nv_bfloat16, uint8_t, ia_u8_bf16)
SA_OP(__nv_bfloat16, int64_t, sa_i64_bf16)
SA_OP(__nv_bfloat16, uint32_t, sa_u32_bf16)
SA_OP(__nv_bfloat16, uint8_t, sa_u8_bf16)
S_OP(__nv_bfloat16, int64_t, s_i64_bf16)
S_OP(__nv_bfloat16, uint32_t, s_u32_bf16)
S_OP(__nv_bfloat16, uint8_t, s_u8_bf16)
#endif

#if __CUDA_ARCH__ >= 530
@ -174,6 +216,9 @@ IA_OP(__half, uint8_t, ia_u8_f16)
SA_OP(__half, int64_t, sa_i64_f16)
SA_OP(__half, uint32_t, sa_u32_f16)
SA_OP(__half, uint8_t, sa_u8_f16)
S_OP(__half, int64_t, s_i64_f16)
S_OP(__half, uint32_t, s_u32_f16)
S_OP(__half, uint8_t, s_u8_f16)
#endif

IS_OP(float, int64_t, is_i64_f32)
@ -247,3 +292,21 @@ SA_OP(double, uint8_t, sa_u8_f64)
SA_OP(uint8_t, uint8_t, sa_u8_u8)
SA_OP(uint32_t, uint8_t, sa_u8_u32)
SA_OP(int64_t, uint8_t, sa_u8_i64)

S_OP(float, int64_t, s_i64_f32)
S_OP(double, int64_t, s_i64_f64)
S_OP(uint8_t, int64_t, s_i64_u8)
S_OP(int64_t, int64_t, s_i64_i64)
S_OP(uint32_t, int64_t, s_i64_u32)

S_OP(float, uint32_t, s_u32_f32)
S_OP(double, uint32_t, s_u32_f64)
S_OP(uint8_t, uint32_t, s_u32_u8)
S_OP(int64_t, uint32_t, s_u32_i64)
S_OP(uint32_t, uint32_t, s_u32_u32)

S_OP(float, uint8_t, s_u8_f32)
S_OP(double, uint8_t, s_u8_f64)
S_OP(uint8_t, uint8_t, s_u8_u8)
S_OP(uint32_t, uint8_t, s_u8_u32)
S_OP(int64_t, uint8_t, s_u8_i64)

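Semantics of the new `scatter` kernel versus the existing `scatter_add`, restated as a CPU reference loop (illustrative Rust, not part of the kernels):

```rust
/// For each (pre, post) slice, walk the source dim and route each value
/// through the index tensor; `scatter` assigns where `scatter_add` accumulates.
fn scatter_ref(ids: &[usize], inp: &[f32], out: &mut [f32],
               left: usize, src_dim: usize, dst_dim: usize, right: usize, add: bool) {
    for pre in 0..left {
        for post in 0..right {
            for j in 0..src_dim {
                let src_i = (pre * src_dim + j) * right + post;
                let idx = ids[src_i];
                assert!(idx < dst_dim);
                let dst_i = (pre * dst_dim + idx) * right + post;
                if add { out[dst_i] += inp[src_i] } else { out[dst_i] = inp[src_i] }
            }
        }
    }
}

fn main() {
    let ids = [1, 0];
    let inp = [10.0, 20.0];
    let mut out = [0.0f32; 3];
    scatter_ref(&ids, &inp, &mut out, 1, 2, 3, 1, false);
    assert_eq!(out, [20.0, 10.0, 0.0]);
}
```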
@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.9.0-alpha.4"
version = "0.9.0"
edition = "2021"

description = "Metal kernels for Candle"
@ -12,6 +12,7 @@ license = "MIT OR Apache-2.0"

[dependencies]
metal = { version = "0.27.0", features = ["mps"] }
half = { version = "2.5.0", features = ["num-traits", "use-intrinsics", "rand_distr"] }
once_cell = "1.18.0"
thiserror = "1"
tracing = "0.1.37"

@ -4,20 +4,20 @@ using namespace metal;

template<typename T> METAL_FUNC void fill_with(
    device T *out,
    constant float &value,
    constant T &value,
    constant size_t &numel,
    uint tid [[thread_position_in_grid]]
) {
    if (tid >= numel) {
        return;
    }
    out[tid] = static_cast<T>(value);
    out[tid] = value;
}

#define FILL_OP(NAME, T) \
kernel void fill_##NAME( \
    device T *out, \
    constant float &value, \
    constant T &value, \
    constant size_t &numel, \
    uint tid [[thread_position_in_grid]] \
) { \

@ -104,6 +104,31 @@ kernel void NAME( \
    gather<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, ids_size, input, input_ids, output, tid); \
}

template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter(
    constant size_t &dst_size,
    constant size_t &left_size,
    constant size_t &src_dim_size,
    constant size_t &right_size,
    constant size_t &dst_dim_size,
    const device TYPENAME *input,
    const device INDEX_TYPENAME *input_ids,
    device TYPENAME *output,
    uint tid [[ thread_position_in_grid ]]
) {
    if (tid >= dst_size) {
        return;
    }
    const size_t right_rank_i = tid % right_size;
    const size_t left_rank_i = tid / right_size;
    for (unsigned int j = 0; j < src_dim_size; ++j) {
        const size_t src_i = (left_rank_i * src_dim_size + j) * right_size + right_rank_i;
        const INDEX_TYPENAME idx = input_ids[src_i];
        const size_t dst_i = (left_rank_i * dst_dim_size + idx) * right_size + right_rank_i;
        output[dst_i] = input[src_i];
    }
}

template<typename TYPENAME, typename INDEX_TYPENAME>
METAL_FUNC void scatter_add(
    constant size_t &dst_size,
@ -129,6 +154,21 @@ METAL_FUNC void scatter_add(
    }
}

# define SCATTER_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
    constant size_t &left_size, \
    constant size_t &src_dim_size, \
    constant size_t &right_size, \
    constant size_t &dst_dim_size, \
    const device TYPENAME *input, \
    const device INDEX_TYPENAME *input_ids, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    scatter<TYPENAME, INDEX_TYPENAME>(dst_size, left_size, src_dim_size, right_size, dst_dim_size, input, input_ids, output, tid); \
}

# define SCATTER_ADD_OP(NAME, INDEX_TYPENAME, TYPENAME) \
kernel void NAME( \
    constant size_t &dst_size, \
@ -235,6 +275,19 @@ SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif

SCATTER_OP(s_u32_f32, uint32_t, float)
SCATTER_OP(s_u8_f32, uint8_t, float)
SCATTER_OP(s_i64_f32, int64_t, float)
SCATTER_OP(s_u32_u32, uint32_t, uint32_t)
SCATTER_OP(s_u32_f16, uint32_t, half)
SCATTER_OP(s_u8_f16, uint8_t, half)
SCATTER_OP(s_i64_f16, int64_t, half)
#if defined(__HAVE_BFLOAT__)
SCATTER_OP(s_u32_bf16, uint32_t, bfloat)
SCATTER_OP(s_u8_bf16, uint8_t, bfloat)
SCATTER_OP(s_i64_bf16, int64_t, bfloat)
#endif

// i64
INDEX_ADD_OP(ia_i64_f16, int64_t, half)
INDEX_ADD_OP(ia_i64_f32, int64_t, float)

@ -161,7 +161,7 @@ macro_rules! ops{
pub mod unary {
    ops!(
        cos, sin, exp, sqr, sqrt, neg, log, gelu, abs, ceil, floor, relu, round, erf, gelu_erf,
        tanh, recip, silu, sign, sigmoid
        tanh, recip, silu, sign, sigmoid, const_set
    );
}
pub mod binary {
@ -419,6 +419,82 @@ pub fn call_copy2d(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous_tiled(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: unary::contiguous_tiled::Kernel,
    length: usize,
    input: impl EncoderParam,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let tile_size = 2;
    let tiles = length.div_ceil(tile_size);

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, input, &output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, tiles);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_const_set_contiguous(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    kernel_name: unary::contiguous::Kernel,
    length: usize,
    input: impl EncoderParam,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, kernel_name.0)?;
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();

    encoder.set_compute_pipeline_state(&pipeline);

    set_params!(encoder, (length, input, &output));

    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_const_set_strided(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
    name: unary::strided::Kernel,
    shape: &[usize],
    input: impl EncoderParam,
    strides: &[usize],
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;

    let length: usize = shape.iter().product();
    let num_dims: usize = shape.len();
    let encoder = ep.encoder();
    let encoder: &ComputeCommandEncoderRef = encoder.as_ref();
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, length);

    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(encoder, (length, num_dims, shape, strides, input, &output));
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_unary_contiguous_tiled(
    device: &Device,
@ -1371,7 +1447,7 @@ pub fn call_gather(
}

#[allow(clippy::too_many_arguments)]
pub fn call_scatter_add(
pub fn call_scatter(
    device: &Device,
    ep: impl EncoderProvider,
    kernels: &Kernels,
@ -1381,7 +1457,7 @@ pub fn call_scatter_add(
    dim: usize,
    input: BufferOffset,
    ids: BufferOffset,
    output: &Buffer,
    output: BufferOffset,
) -> Result<(), MetalKernelError> {
    let left_size: usize = src_shape[..dim].iter().product();
    let right_size: usize = src_shape[dim + 1..].iter().product();
@ -1406,7 +1482,7 @@ pub fn call_scatter_add(
            dst_dim_size,
            &input,
            &ids,
            output
            &output
        )
    );

@ -1414,7 +1490,7 @@ pub fn call_scatter_add(

    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(ids.buffer, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.use_resource(output.buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    Ok(())
}
@ -2570,7 +2646,7 @@ pub fn call_const_fill(
    name: &'static str,
    length: usize,
    output: &Buffer,
    v: f32,
    v: impl EncoderParam,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Fill, name)?;
    let encoder = ep.encoder();

@ -1574,7 +1574,7 @@ fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
    let input_buffer = new_buffer(&device, input);
    let ids_buffer = new_buffer(&device, ids);
    let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
    call_scatter_add(
    call_scatter(
        &device,
        command_buffer,
        &kernels,
@ -2343,7 +2343,7 @@ fn conv_transpose1d_u32() {

#[test]
fn const_fill() {
    fn constant_fill<T: Clone>(name: &'static str, len: usize, value: f32) -> Vec<T> {
    fn constant_fill<T: Clone + EncoderParam>(name: &'static str, len: usize, value: T) -> Vec<T> {
        let dev = device();
        let kernels = Kernels::new();
        let command_queue = dev.new_command_queue();
@ -2357,11 +2357,15 @@ fn const_fill() {
        command_buffer.wait_until_completed();
        read_to_vec::<T>(&buffer, len)
    }
    fn test<T: Clone + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(name: &'static str, f: F) {
    fn test<T: Clone + Copy + EncoderParam + PartialEq + std::fmt::Debug, F: FnOnce(f32) -> T>(
        name: &'static str,
        f: F,
    ) {
        let len = rand::thread_rng().gen_range(2..16) * rand::thread_rng().gen_range(4..16);
        let value = rand::thread_rng().gen_range(1. ..19.);
        let value = f(value);
        let v = constant_fill::<T>(name, len, value);
        assert_eq!(v, vec![f(value); len])
        assert_eq!(v, vec![value; len])
    }
    test::<u8, _>("fill_u8", |v| v as u8);
    test::<u32, _>("fill_u32", |v| v as u32);

@ -73,6 +73,44 @@ template <typename T> METAL_FUNC T sigmoid(T in) {

#define TILE_SIZE 2

#define CONST_SET(TYPENAME, FN_NAME) \
kernel void FN_NAME( \
    constant size_t &dim, \
    constant TYPENAME &input, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid >= dim) { \
        return; \
    } \
    output[tid] = input; \
} \
kernel void FN_NAME##_##strided( \
    constant size_t &dim, \
    constant size_t &num_dims, \
    constant size_t *dims, \
    constant size_t *strides, \
    constant TYPENAME &input, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid >= dim) { \
        return; \
    } \
    output[get_strided_index(tid, num_dims, dims, strides)] = input; \
} \
kernel void FN_NAME##_##tiled( \
    constant size_t &dim, \
    constant TYPENAME &input, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    for (uint i = 0; i < TILE_SIZE; i++) { \
        const uint idx = tid * TILE_SIZE + i; \
        output[idx] = input; \
    } \
}

#define UNARY(FN, TYPENAME, FN_NAME, FN_NAME_STRIDED) \
kernel void FN_NAME( \
    constant size_t &dim, \
@ -139,6 +177,11 @@ COPY2D(copy2d_f16, half)
COPY2D(copy2d_u8, uint8_t)
COPY2D(copy2d_u32, uint32_t)

CONST_SET(float, const_set_f32)
CONST_SET(half, const_set_f16)
CONST_SET(uint8_t, const_set_u8)
CONST_SET(uint32_t, const_set_u32)

UNARY_OP(cos)
UNARY_OP(sin)
UNARY_OP(sqr)
@ -171,6 +214,7 @@ UNARY(precise::tanh, half, tanh_f16, tanh_f16_strided);
#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
CONST_SET(int64_t, const_set_i64)
#endif

#if defined(__HAVE_BFLOAT__)
@ -199,4 +243,5 @@ UNARY(id, bfloat, copy_bf16, copy_bf16_strided)
UNARY(precise::tanh, bfloat, tanh_bf16, tanh_bf16_strided);

COPY2D(copy2d_bf16, bfloat)
CONST_SET(bfloat, const_set_bf16)
#endif

@ -88,9 +88,13 @@ primitive!(bool);
primitive!(usize);
primitive!(i32);
primitive!(i64);
primitive!(u8);
primitive!(u32);
primitive!(u64);
primitive!(f32);
primitive!(f64);
primitive!(half::bf16);
primitive!(half::f16);

pub struct BufferOffset<'a> {
    pub buffer: &'a Buffer,

@ -71,6 +71,8 @@ impl candle::Module for PReLU {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let weight = if self.is_scalar {
            self.weight.reshape(())?
        } else if xs.shape() == self.weight.shape() {
            self.weight.clone()
        } else if xs.rank() >= 2 {
            let num_channels = xs.dim(1)?;
            let num_weights = self.weight.elem_count();
@ -78,7 +80,7 @@ impl candle::Module for PReLU {
            candle::bail!("error in prelu: unexpected number of channels for the input, got {num_channels}, weight dim is {num_weights}")
        }
        let mut s = vec![1; xs.rank()];
        s[1] = self.weight.elem_count();
        s[1] = num_weights;
        self.weight.reshape(s)?
    } else {
        self.weight.clone()
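The branches above let a per-channel weight broadcast against inputs of higher rank. A small usage sketch (hypothetical values, CPU device):

```rust
use candle_core::{Device, Result, Tensor};
use candle_nn::{Module, PReLU};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Per-channel PReLU: 3 weights broadcast over an input of shape (N, C, L).
    let w = Tensor::new(&[0.1f32, 0.2, 0.3], &dev)?;
    let prelu = PReLU::new(w, false);
    let xs = Tensor::new(&[[[-1f32, 1.], [-2., 2.], [-3., 3.]]], &dev)?; // (1, 3, 2)
    // Negative entries are scaled by their channel's weight: -0.1, -0.4, -0.9.
    let ys = prelu.forward(&xs)?;
    println!("{ys}");
    Ok(())
}
```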
@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.9.0-alpha.4"
version = "0.9.0"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0-alpha.4" }
candle-nn = { path = "../candle-nn", version = "0.9.0-alpha.4" }
candle = { path = "../candle-core", package = "candle-core", version = "0.9.0" }
candle-nn = { path = "../candle-nn", version = "0.9.0" }
prost = "0.12.1"

[build-dependencies]

@ -1,7 +1,9 @@
use crate::onnx::attribute_proto::AttributeType;
use crate::onnx::tensor_proto::DataType;
use crate::onnx::{self, GraphProto};
use candle::Module;
use candle::{bail, DType, Device, Result, Tensor};
use candle_nn::activation::PReLU;
use std::collections::{HashMap, HashSet};

pub type Value = Tensor;
@ -991,6 +993,14 @@ fn simple_eval_(
            let output = input.relu()?;
            values.insert(node.output[0].clone(), output);
        }
        "PRelu" => {
            // https://onnx.ai/onnx/operators/onnx__PRelu.html
            let input = get(&node.input[0])?;
            let slope = get(&node.input[1])?;

            let output = PReLU::new(slope.clone(), false).forward(input)?;
            values.insert(node.output[0].clone(), output);
        }
        "Ceil" => {
            let input = get(&node.input[0])?;
            let output = input.ceil()?;

@ -1846,6 +1846,64 @@ fn test_relu_operation() -> Result<()> {
    Ok(())
}

// "PRelu"
#[test]
fn test_prelu_operation() -> Result<()> {
    let manual_graph = create_model_proto_with_graph(Some(GraphProto {
        node: vec![NodeProto {
            op_type: "PRelu".to_string(),
            domain: "".to_string(),
            attribute: vec![],
            input: vec![INPUT_X.to_string(), INPUT_Y.to_string()],
            output: vec![OUTPUT_Z.to_string()],
            name: "".to_string(),
            doc_string: "".to_string(),
        }],
        name: "".to_string(),
        initializer: vec![],
        input: vec![
            ValueInfoProto {
                name: INPUT_X.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
            ValueInfoProto {
                name: INPUT_Y.to_string(),
                doc_string: "".to_string(),
                r#type: None,
            },
        ],
        output: vec![ValueInfoProto {
            name: OUTPUT_Z.to_string(),
            doc_string: "".to_string(),
            r#type: None,
        }],
        value_info: vec![],
        doc_string: "".to_string(),
        sparse_initializer: vec![],
        quantization_annotation: vec![],
    }));
    let x: Tensor = Tensor::from_vec(
        vec![-1.0f32, 1.0f32, -2.0f32, 3.0f32],
        &[2, 2],
        &Device::Cpu,
    )?;

    let y: Tensor = Tensor::from_vec(vec![1.0f32, 1.1f32, 1.2f32, 1.3f32], &[2, 2], &Device::Cpu)?;

    let mut inputs: HashMap<String, Tensor> = HashMap::new();
    inputs.insert(INPUT_X.to_string(), x);
    inputs.insert(INPUT_Y.to_string(), y);

    let eval = candle_onnx::simple_eval(&manual_graph, inputs)?;
    assert_eq!(eval.len(), 1);

    let z = eval.get(OUTPUT_Z).expect("Output 'z' not found");
    let results = z.to_vec2::<f32>()?;
    assert_eq!(results, vec![vec![-1.0, 1.0], vec![-2.4, 3.0]]);

    Ok(())
}
// "Constant"
// #[test]

@ -21,6 +21,7 @@ pub struct Config {
    pub num_key_value_heads: usize,
    pub rms_norm_eps: f64,
    pub rope_theta: f64,
    pub rope_local_base_freq: f64,
    pub vocab_size: usize,
    pub final_logit_softcapping: Option<f64>,
    pub attn_logit_softcapping: Option<f64>,
@ -67,12 +68,22 @@ struct RotaryEmbedding {
}

impl RotaryEmbedding {
    fn new(dtype: DType, cfg: &Config, dev: &Device) -> Result<Self> {
    fn new(
        dtype: DType,
        cfg: &Config,
        dev: &Device,
        sliding_window: Option<usize>,
    ) -> Result<Self> {
        let dim = cfg.head_dim;
        let max_seq_len = cfg.max_position_embeddings;
        let rope_freq = if sliding_window.is_some() {
            cfg.rope_local_base_freq
        } else {
            cfg.rope_theta
        };
        let inv_freq: Vec<_> = (0..dim)
            .step_by(2)
            .map(|i| 1f32 / cfg.rope_theta.powf(i as f64 / dim as f64) as f32)
            .map(|i| 1f32 / rope_freq.powf(i as f64 / dim as f64) as f32)
            .collect();
        let inv_freq_len = inv_freq.len();
        let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?.to_dtype(dtype)?;
@ -162,8 +173,8 @@ impl Attention {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        use_flash_attn: bool,
        is_sliding: bool,
        cfg: &Config,
        sliding_window: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let hidden_sz = cfg.hidden_size;
@ -178,13 +189,13 @@ impl Attention {
        let o_proj = linear(num_heads * head_dim, hidden_sz, bias, vb.pp("o_proj"))?;
        let q_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("q_norm"))?;
        let k_norm = RmsNorm::new(head_dim, cfg.rms_norm_eps, vb.pp("k_norm"))?;
        let kv_cache = if is_sliding {
            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(
                2,
                cfg.sliding_window,
            ))
        let kv_cache = if let Some(sliding_window) = sliding_window {
            KvCache::Rotating(candle_nn::kv_cache::RotatingKvCache::new(2, sliding_window))
        } else {
            KvCache::Normal(candle_nn::kv_cache::KvCache::new(2, cfg.sliding_window))
            KvCache::Normal(candle_nn::kv_cache::KvCache::new(
                2,
                cfg.max_position_embeddings,
            ))
        };
        Ok(Self {
            q_proj,
@ -302,21 +313,27 @@ struct DecoderLayer {
    pre_feedforward_layernorm: RmsNorm,
    post_feedforward_layernorm: RmsNorm,
    post_attention_layernorm: RmsNorm,
    sliding_window: Option<usize>,
}

impl DecoderLayer {
    fn new(
        rotary_emb: Arc<RotaryEmbedding>,
        use_flash_attn: bool,
        is_sliding: bool,
        cfg: &Config,
        vb: VarBuilder,
        sliding_window: Option<usize>,
    ) -> Result<Self> {
        let rotary_emb = Arc::new(RotaryEmbedding::new(
            vb.dtype(),
            cfg,
            vb.device(),
            sliding_window,
        )?);
        let self_attn = Attention::new(
            rotary_emb,
            use_flash_attn,
            is_sliding,
            cfg,
            sliding_window,
            vb.pp("self_attn"),
        )?;
        let mlp = MLP::new(cfg, vb.pp("mlp"))?;
@ -344,6 +361,7 @@ impl DecoderLayer {
            pre_feedforward_layernorm,
            post_feedforward_layernorm,
            post_attention_layernorm,
            sliding_window,
        })
    }

@ -370,6 +388,42 @@ impl DecoderLayer {
    }
}

fn prepare_decoder_attention_mask(
    b_size: usize,
    tgt_len: usize,
    seqlen_offset: usize,
    sliding_window: Option<usize>,
    dtype: DType,
    device: &Device,
) -> Result<Tensor> {
    let mask: Vec<_> = if let Some(sliding_window) = sliding_window {
        (0..tgt_len)
            .flat_map(|i| {
                (0..tgt_len).map(move |j| {
                    if i < j || j + sliding_window < i {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect()
    } else {
        (0..tgt_len)
            .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0f32 }))
            .collect()
    };
    let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
    let mask = if seqlen_offset > 0 {
        let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, device)?;
        Tensor::cat(&[&mask0, &mask], D::Minus1)?
    } else {
        mask
    };
    mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
        .to_dtype(dtype)
}
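To make the mask shape concrete, here is the pattern `prepare_decoder_attention_mask` produces for `tgt_len = 4` with a sliding window of 2: position `i` may attend to `j` when `j <= i` and `i - j <= window` (illustrative sketch only):

```rust
fn main() {
    let (tgt_len, window) = (4usize, 2usize);
    for i in 0..tgt_len {
        let row: Vec<&str> = (0..tgt_len)
            .map(|j| if i < j || j + window < i { "-inf" } else { "0" })
            .collect();
        println!("{row:?}");
    }
    // ["0", "-inf", "-inf", "-inf"]
    // ["0", "0", "-inf", "-inf"]
    // ["0", "0", "0", "-inf"]
    // ["-inf", "0", "0", "0"]
}
```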
#[derive(Debug, Clone)]
pub struct Model {
    embed_tokens: candle_nn::Embedding,
@ -388,17 +442,15 @@ impl Model {
        let vb_m = vb.pp("model");
        let embed_tokens =
            candle_nn::embedding(cfg.vocab_size, cfg.hidden_size, vb_m.pp("embed_tokens"))?;
        let rotary_emb = Arc::new(RotaryEmbedding::new(vb.dtype(), cfg, vb_m.device())?);
        let mut layers = Vec::with_capacity(cfg.num_hidden_layers);
        let vb_l = vb_m.pp("layers");
        for layer_idx in 0..cfg.num_hidden_layers {
            let is_sliding = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
            let sliding_window = (layer_idx + 1) % cfg.sliding_window_pattern > 0;
            let layer = DecoderLayer::new(
                rotary_emb.clone(),
                use_flash_attn,
                is_sliding,
                cfg,
                vb_l.pp(layer_idx),
                sliding_window.then_some(cfg.sliding_window),
            )?;
            layers.push(layer)
        }
@ -417,51 +469,52 @@ impl Model {
        })
    }

    fn prepare_decoder_attention_mask(
    fn create_attention_masks(
        &self,
        b_size: usize,
        tgt_len: usize,
        batch_size: usize,
        seq_len: usize,
        seqlen_offset: usize,
    ) -> Result<Tensor> {
        let mask: Vec<_> = match Some(self.sliding_window) {
            None => (0..tgt_len)
                .flat_map(|i| (0..tgt_len).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
                .collect(),
            Some(sliding_window) => (0..tgt_len)
                .flat_map(|i| {
                    (0..tgt_len).map(move |j| {
                        if i < j || j + sliding_window < i {
                            f32::NEG_INFINITY
                        } else {
                            0.
                        }
                    })
                })
                .collect(),
        };
        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
        let mask = if seqlen_offset > 0 {
            let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_size, 1, tgt_len, tgt_len + seqlen_offset))?
            .to_dtype(self.dtype)
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        if seq_len <= 1 {
            return Ok((None, None));
        }

        let mask = prepare_decoder_attention_mask(
            batch_size,
            seq_len,
            seqlen_offset,
            None,
            self.dtype,
            &self.device,
        )?;

        let sliding_mask = prepare_decoder_attention_mask(
            batch_size,
            seq_len,
            seqlen_offset,
            Some(self.sliding_window),
            self.dtype,
            &self.device,
        )?;

        Ok((Some(mask), Some(sliding_mask)))
    }

    pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
        let (b_size, seq_len) = input_ids.dims2()?;
        let attention_mask = if seq_len <= 1 {
            None
        } else {
            let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
            Some(mask)
        };
        let xs = self.embed_tokens.forward(input_ids)?;
        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;

        let (attention_mask, sliding_attention_mask) =
            self.create_attention_masks(b_size, seq_len, seqlen_offset)?;

        for layer in self.layers.iter_mut() {
            xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
            let mask = if layer.sliding_window.is_some() {
                &sliding_attention_mask
            } else {
                &attention_mask
            };
            xs = layer.forward(&xs, mask.as_ref(), seqlen_offset)?
        }
        let logits = xs
            .narrow(1, seq_len - 1, 1)?

@ -79,6 +79,7 @@ pub mod phi3;
pub mod pixtral;
pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_gemma3;
pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_metavoice;

466
candle-transformers/src/models/quantized_gemma3.rs
Normal file
@ -0,0 +1,466 @@
//! Gemma 3 model implementation with quantization support.
//!
//! Gemma 3 is a family of multimodal language models developed by Google.
//! This implementation provides quantization for reduced memory usage and faster inference.
//!
//! Key characteristics:
//! - Group-Query Attention (GQA) with specialized key-value heads
//! - RMSNorm for layer normalization
//! - Specialized attention patterns with separate normalization for Q/K/V
//! - Feed-forward network with SwiGLU activation
//! - Support for 2/3/4/8-bit quantization
//!
//! References:
//! - [Gemma 3 Models](https://blog.google/technology/developers/gemma-3/)
//!

use crate::quantized_nn::RmsNorm;
use candle::quantized::gguf_file;
use candle::quantized::QTensor;
use candle::D;
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_nn::{Embedding, Module};

pub const MAX_SEQ_LEN: usize = 131072; // Gemma 3 supports 128K context window
pub const DEFAULT_SLIDING_WINDOW_TYPE: usize = 6;
pub const DEFAULT_ROPE_FREQUENCY: f32 = 1_000_000.;
pub const DEFAULT_ROPE_FREQUENCY_SLIDING: f32 = 10_000.;
pub const DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR: f32 = 1.;

#[derive(Debug, Clone)]
struct QMatMul {
    inner: candle::quantized::QMatMul,
    span: tracing::Span,
}

impl QMatMul {
    fn from_qtensor(qtensor: QTensor) -> Result<Self> {
        let inner = candle::quantized::QMatMul::from_qtensor(qtensor)?;
        let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
        Ok(Self { inner, span })
    }

    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}

#[derive(Debug, Clone)]
struct Mlp {
    feed_forward_gate: QMatMul, // ffn_gate in GGUF
    feed_forward_up: QMatMul,   // ffn_up in GGUF
    feed_forward_down: QMatMul, // ffn_down in GGUF
}

impl Module for Mlp {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let gate = self.feed_forward_gate.forward(xs)?;
        let up = self.feed_forward_up.forward(xs)?;
        let silu = candle_nn::ops::silu(&gate)?;
        let gated = (silu * up)?;
        self.feed_forward_down.forward(&gated)
    }
}
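`Mlp::forward` above is the usual SwiGLU-style gate, `down(silu(gate(x)) * up(x))`. A scalar sketch of the silu nonlinearity it relies on (illustrative, not the candle kernel):

```rust
// silu(x) = x * sigmoid(x)
fn silu(x: f32) -> f32 {
    x * (1.0 / (1.0 + (-x).exp()))
}

fn main() {
    // silu is zero at the origin and close to linear for large positive inputs.
    assert!(silu(0.0).abs() < 1e-6);
    assert!((silu(10.0) - 10.0).abs() < 1e-3);
}
```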

#[derive(Debug, Clone)]
struct RotaryEmbedding {
    sin: Tensor,
    cos: Tensor,
}

impl RotaryEmbedding {
    fn new(head_dim: usize, rope_frequency: f32, device: &Device) -> Result<Self> {
        let theta: Vec<_> = (0..head_dim)
            .step_by(2)
            .map(|i| 1f32 / rope_frequency.powf(i as f32 / head_dim as f32))
            .collect();
        let theta = Tensor::new(theta.as_slice(), device)?;
        let idx_theta = Tensor::arange(0, MAX_SEQ_LEN as u32, device)?
            .to_dtype(DType::F32)?
            .reshape((MAX_SEQ_LEN, 1))?
            .matmul(&theta.reshape((1, theta.elem_count()))?)?;
        let cos = idx_theta.cos()?;
        let sin = idx_theta.sin()?;
        Ok(Self { sin, cos })
    }

    fn apply_rotary_emb_qkv(
        &self,
        q: &Tensor,
        k: &Tensor,
        index_pos: usize,
    ) -> Result<(Tensor, Tensor)> {
        let (_b_sz, _h, seq_len, _n_embd) = q.dims4()?;
        let cos = self.cos.narrow(0, index_pos, seq_len)?;
        let sin = self.sin.narrow(0, index_pos, seq_len)?;
        let q_embed = candle_nn::rotary_emb::rope(&q.contiguous()?, &cos, &sin)?;
        let k_embed = candle_nn::rotary_emb::rope(&k.contiguous()?, &cos, &sin)?;
        Ok((q_embed, k_embed))
    }
}
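
// Note: `cos` and `sin` above are precomputed once per layer with shape
// (MAX_SEQ_LEN, head_dim / 2) -- one row per position, one column per
// rotation frequency. apply_rotary_emb_qkv then just narrows out the rows
// for the current positions, so incremental decoding at index_pos costs a
// slice rather than a recomputation.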

#[derive(Debug, Clone)]
struct LayerWeights {
    // Attention components
    attention_wq: QMatMul,
    attention_wk: QMatMul,
    attention_wv: QMatMul,
    attention_wo: QMatMul,

    // Specialized normalization for Q and K
    attention_q_norm: RmsNorm,
    attention_k_norm: RmsNorm,

    // Layer normalization
    attention_norm: RmsNorm,      // Applied before attention
    post_attention_norm: RmsNorm, // Applied after attention
    ffn_norm: RmsNorm,            // Applied before the feed-forward block
    post_ffn_norm: RmsNorm,       // Applied after the feed-forward block

    // Feed-forward network
    mlp: Mlp,

    // Attention parameters
    n_head: usize,    // Number of query heads
    n_kv_head: usize, // Number of key-value heads
    head_dim: usize,  // Dimension of each head
    q_dim: usize,     // Total dimension for queries

    sliding_window_size: Option<usize>,

    rotary_embedding: RotaryEmbedding,
    neg_inf: Tensor,

    // Cache
    kv_cache: Option<(Tensor, Tensor)>,

    // Tracing
    span_attn: tracing::Span,
    span_mlp: tracing::Span,
}

impl LayerWeights {
    fn mask(
        &self,
        b_sz: usize,
        seq_len: usize,
        index_pos: usize,
        dtype: DType,
        device: &Device,
    ) -> Result<Tensor> {
        let mask: Vec<_> = if let Some(sliding_window_size) = self.sliding_window_size {
            (0..seq_len)
                .flat_map(|i| {
                    (0..seq_len).map(move |j| {
                        if i < j || j + sliding_window_size < i {
                            0u32
                        } else {
                            1u32
                        }
                    })
                })
                .collect()
        } else {
            (0..seq_len)
                .flat_map(|i| (0..seq_len).map(move |j| if i < j { 0u32 } else { 1u32 }))
                .collect()
        };
        let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
        let mask = if index_pos > 0 {
            // Positions already in the KV cache sit before the current chunk,
            // so a causal mask keeps them attendable (ones, not zeros); the
            // block must also be U32 to match the mask's dtype for the cat.
            let mask0 = Tensor::ones((seq_len, index_pos), DType::U32, device)?;
            Tensor::cat(&[&mask0, &mask], D::Minus1)?
        } else {
            mask
        };
        mask.expand((b_sz, 1, seq_len, seq_len + index_pos))?
            .to_dtype(dtype)
    }
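    // Worked example of the sliding-window mask (sketch): with
    // sliding_window_size = 2, index_pos = 0, and seq_len = 4, position i may
    // attend to positions j with j <= i and j + 2 >= i, giving:
    //
    //     1 0 0 0
    //     1 1 0 0
    //     1 1 1 0
    //     0 1 1 1
    //
    // Row 3 drops position 0 because 0 + 2 < 3.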
    fn forward_attn(
        &mut self,
        x: &Tensor,
        mask: Option<&Tensor>,
        index_pos: usize,
    ) -> Result<Tensor> {
        let _enter = self.span_attn.enter();
        let (b_sz, seq_len, _) = x.dims3()?;

        let q = self.attention_wq.forward(x)?;
        let k = self.attention_wk.forward(x)?;
        let v = self.attention_wv.forward(x)?;

        let q = q
            .reshape((b_sz, seq_len, self.n_head, self.head_dim))?
            .transpose(1, 2)?;
        let k = k
            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
            .transpose(1, 2)?;
        let v = v
            .reshape((b_sz, seq_len, self.n_kv_head, self.head_dim))?
            .transpose(1, 2)?;

        let q = self.attention_q_norm.forward(&q.contiguous()?)?;
        let k = self.attention_k_norm.forward(&k.contiguous()?)?;

        let (q, k) = self
            .rotary_embedding
            .apply_rotary_emb_qkv(&q, &k, index_pos)?;

        let (k, v) = match &self.kv_cache {
            None => (k, v),
            Some((k_cache, v_cache)) => {
                if index_pos == 0 {
                    (k, v)
                } else {
                    let k = Tensor::cat(&[k_cache, &k], 2)?; // concat on the seq dim
                    let v = Tensor::cat(&[v_cache, &v], 2)?;
                    (k, v)
                }
            }
        };
        self.kv_cache = Some((k.clone(), v.clone())); // update the cache

        // Repeat KV for GQA
        let k = crate::utils::repeat_kv(k, self.n_head / self.n_kv_head)?;
        let v = crate::utils::repeat_kv(v, self.n_head / self.n_kv_head)?;

        // Scaled dot-product attention
        let scale = 1.0 / (self.head_dim as f64).sqrt();
        let mut attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

        if let Some(mask) = mask {
            let mask = mask.broadcast_as(attn_weights.shape())?;
            let neg_inf = self.neg_inf.broadcast_as(attn_weights.dims())?;
            attn_weights = mask.eq(0u32)?.where_cond(&neg_inf, &attn_weights)?;
        }

        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
        let attn_output = attn_weights.matmul(&v)?;

        let attn_output = attn_output
            .transpose(1, 2)?
            .reshape((b_sz, seq_len, self.q_dim))?;

        self.attention_wo.forward(&attn_output)
    }
}
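
// Shape sketch for forward_attn with batch b and current chunk length s
// (s_kv = cached length + s after the cache concat):
//     q: (b, n_head, s, head_dim)    k, v: (b, n_kv_head, s_kv, head_dim)
// repeat_kv tiles each KV head n_head / n_kv_head times so the attention
// matmul sees matching head counts, producing weights of shape
// (b, n_head, s, s_kv).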

#[derive(Debug, Clone)]
pub struct ModelWeights {
    tok_embeddings: Embedding,
    embedding_length: usize,
    layers: Vec<LayerWeights>,
    norm: RmsNorm,
    output: QMatMul,
    span: tracing::Span,
    span_output: tracing::Span,
}

impl ModelWeights {
    pub fn from_gguf<R: std::io::Seek + std::io::Read>(
        ct: gguf_file::Content,
        reader: &mut R,
        device: &Device,
    ) -> Result<Self> {
        let md_get = |s: &str| match ct.metadata.get(s) {
            None => candle::bail!("cannot find {s} in metadata"),
            Some(v) => Ok(v),
        };

        let head_count = md_get("gemma3.attention.head_count")?.to_u32()? as usize;
        let head_count_kv = md_get("gemma3.attention.head_count_kv")?.to_u32()? as usize;
        let block_count = md_get("gemma3.block_count")?.to_u32()? as usize;
        let embedding_length = md_get("gemma3.embedding_length")?.to_u32()? as usize;
        let key_length = md_get("gemma3.attention.key_length")?.to_u32()? as usize;
        let _value_length = md_get("gemma3.attention.value_length")?.to_u32()? as usize;
        let rms_norm_eps = md_get("gemma3.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
        let sliding_window_size = md_get("gemma3.attention.sliding_window")?.to_u32()? as usize;

        let sliding_window_type = md_get("gemma3.attention.sliding_window_type")
            .and_then(|m| Ok(m.to_u32()? as usize))
            .unwrap_or(DEFAULT_SLIDING_WINDOW_TYPE);

        let rope_freq_base = md_get("gemma3.rope.freq_base")
            .and_then(|m| m.to_f32())
            .unwrap_or(DEFAULT_ROPE_FREQUENCY);

        let rope_freq_base_sliding = md_get("gemma3.rope.local_freq_base")
            .and_then(|m| m.to_f32())
            .unwrap_or(DEFAULT_ROPE_FREQUENCY_SLIDING);

        // Unused in llama.cpp, so we don't use it here either.
        let _rope_freq_scaling_factor = md_get("gemma3.rope.scaling.factor")
            .and_then(|m| m.to_f32())
            .unwrap_or(DEFAULT_ROPE_FREQUENCY_SCALE_FACTOR);

        // Total dimension of the query projection across all heads.
        let q_dim = head_count * key_length;

        let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;

        // Load token embeddings and the output projection.
        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
        let tok_embeddings = tok_embeddings.dequantize(device)?;
        let norm = RmsNorm::from_qtensor(
            ct.tensor(reader, "output_norm.weight", device)?,
            rms_norm_eps,
        )?;
        let output = match ct.tensor(reader, "output.weight", device) {
            Ok(tensor) => tensor,
            // Fall back to the tied embedding weights when output.weight is absent.
            Err(_) => ct.tensor(reader, "token_embd.weight", device)?,
        };

        let mut layers = Vec::with_capacity(block_count);
        for layer_idx in 0..block_count {
            let prefix = format!("blk.{layer_idx}");

            let attention_wq = ct.tensor(reader, &format!("{prefix}.attn_q.weight"), device)?;
            let attention_wk = ct.tensor(reader, &format!("{prefix}.attn_k.weight"), device)?;
            let attention_wv = ct.tensor(reader, &format!("{prefix}.attn_v.weight"), device)?;
            let attention_wo =
                ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;

            let attention_q_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.attn_q_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let attention_k_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.attn_k_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let attention_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.attn_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let post_attention_norm = RmsNorm::from_qtensor(
                ct.tensor(
                    reader,
                    &format!("{prefix}.post_attention_norm.weight"),
                    device,
                )?,
                rms_norm_eps,
            )?;

            let ffn_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.ffn_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let post_ffn_norm = RmsNorm::from_qtensor(
                ct.tensor(reader, &format!("{prefix}.post_ffw_norm.weight"), device)?,
                rms_norm_eps,
            )?;

            let feed_forward_gate =
                ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
            let feed_forward_up = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
            let feed_forward_down =
                ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;

            let mlp = Mlp {
                feed_forward_gate: QMatMul::from_qtensor(feed_forward_gate)?,
                feed_forward_up: QMatMul::from_qtensor(feed_forward_up)?,
                feed_forward_down: QMatMul::from_qtensor(feed_forward_down)?,
            };

            // Layers alternate between sliding-window and global attention:
            // every sliding_window_type-th layer is global (the type defaults
            // to 6 when the metadata key is absent), the rest use the window.
            let is_sliding = (layer_idx + 1) % sliding_window_type > 0;
            let sliding_window_size = is_sliding.then_some(sliding_window_size);
            let layer_rope_frequency = if is_sliding {
                rope_freq_base_sliding
            } else {
                rope_freq_base
            };

            let rotary_embedding = RotaryEmbedding::new(key_length, layer_rope_frequency, device)?;

            // Tracing spans
            let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
            let span_mlp = tracing::span!(tracing::Level::TRACE, "attn-mlp");

            layers.push(LayerWeights {
                attention_wq: QMatMul::from_qtensor(attention_wq)?,
                attention_wk: QMatMul::from_qtensor(attention_wk)?,
                attention_wv: QMatMul::from_qtensor(attention_wv)?,
                attention_wo: QMatMul::from_qtensor(attention_wo)?,
                attention_q_norm,
                attention_k_norm,
                attention_norm,
                post_attention_norm,
                ffn_norm,
                post_ffn_norm,
                mlp,
                n_head: head_count,
                n_kv_head: head_count_kv,
                head_dim: key_length,
                q_dim,
                sliding_window_size,
                rotary_embedding,
                neg_inf: neg_inf.clone(),
                kv_cache: None,
                span_attn,
                span_mlp,
            })
        }

        let span = tracing::span!(tracing::Level::TRACE, "model");
        let span_output = tracing::span!(tracing::Level::TRACE, "output");

        Ok(Self {
            tok_embeddings: Embedding::new(tok_embeddings, embedding_length),
            embedding_length,
            layers,
            norm,
            output: QMatMul::from_qtensor(output)?,
            span,
            span_output,
        })
    }
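    // With the default sliding_window_type of 6, the pattern over layer
    // indices 0..12 is: layers 0-4 sliding, layer 5 global, layers 6-10
    // sliding, layer 11 global, and so on -- (layer_idx + 1) % 6 == 0
    // marks the global layers.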
    pub fn forward(&mut self, x: &Tensor, index_pos: usize) -> Result<Tensor> {
        let (b_sz, seq_len) = x.dims2()?;
        let _enter = self.span.enter();

        let mut layer_in = self.tok_embeddings.forward(x)?;
        // Gemma scales the embeddings by sqrt(hidden size).
        layer_in = (layer_in * (self.embedding_length as f64).sqrt())?;

        for layer in self.layers.iter_mut() {
            // A single query position can only attend backwards, so no mask
            // is needed during incremental decoding.
            let attention_mask = if seq_len == 1 {
                None
            } else {
                Some(layer.mask(b_sz, seq_len, index_pos, x.dtype(), x.device())?)
            };

            // Attention block
            let residual = &layer_in;
            let x = layer.attention_norm.forward(&layer_in)?;
            let x = layer.forward_attn(&x, attention_mask.as_ref(), index_pos)?;
            let x = layer.post_attention_norm.forward(&x)?;
            let x = (x + residual)?;

            // Feed-forward block
            let _enter = layer.span_mlp.enter();
            let residual = &x;
            let x = layer.ffn_norm.forward(&x)?;
            let x = layer.mlp.forward(&x)?;
            let x = layer.post_ffn_norm.forward(&x)?;
            let x = (x + residual)?;
            drop(_enter);

            layer_in = x;
        }

        let _enter = self.span_output.enter();

        // Only the last position is projected to logits, which is all that
        // is needed for autoregressive sampling.
        let x = layer_in.i((.., seq_len - 1, ..))?;
        let x = self.norm.forward(&x)?;
        let output = self.output.forward(&x)?;

        Ok(output)
    }
}