Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 02:58:50 +00:00)

Compare commits: 0.5.1 ... dependabot, 14 commits.
Commits (SHA1):
84cd5158ad
7abc3b8cd7
46012ed31f
f3fade3b03
ea260aeffd
0814dfd148
3ceca9901a
1df2bddccf
6f0b807ffd
d54e02d73d
45e235a747
31cf64147b
77ea479a18
72e7ca529a
@@ -43,9 +43,9 @@ candle-onnx = { path = "./candle-onnx", version = "0.5.1" }
 candle-transformers = { path = "./candle-transformers", version = "0.5.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
-cudarc = { version = "0.11.1", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
+cudarc = { version = "0.11.4", features = ["std", "cublas", "cublaslt", "curand", "driver", "nvrtc", "f16", "cuda-version-from-build-system", "dynamic-linking"], default-features=false }
 fancy-regex = "0.13.0"
-gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
+gemm = { version = "0.18.0", features = ["wasm-simd128-enable"] }
 hf-hub = "0.3.0"
 half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
 hound = "3.5.1"
@@ -10,7 +10,7 @@ pub use utils::{
 };

 const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
+const USE_COL2IM_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;

 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -2249,7 +2249,7 @@ impl BackendStorage for CpuStorage {
             && params.dilation == 1
             && params.padding == 0
             && params.output_padding == 0;
-        if USE_IM2COL_CONV1D_TR && can_use_col2im {
+        if USE_COL2IM_CONV1D_TR && can_use_col2im {
             let (b_size, c_in, l_in) = l.shape().dims3()?;
             let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
             if !kernel_l.is_contiguous() {
@@ -16,7 +16,7 @@ mod error;
 mod utils;
 pub use device::{CudaDevice, DeviceId};
 pub use error::{CudaError, WrapErr};
-pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, S};
+pub use utils::{Map1, Map1Any, Map2, Map2Any, Map2InPlace, Map3, S};

 pub enum SlicePtrOrNull<T> {
     Ptr(CudaSlice<T>),
@@ -630,6 +630,31 @@ impl<'a> Map2 for Conv2D<'a> {
     }
 }

+struct Col2Im1D {
+    stride: usize,
+}
+
+impl Map1 for Col2Im1D {
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        col: &CudaSlice<T>,
+        dev: &CudaDevice,
+        l: &Layout,
+    ) -> Result<CudaSlice<T>> {
+        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
+        let stride = self.stride;
+        let l_out = (l_in - 1) * stride + k_size;
+        let dst_el = b_size * c_out * l_out;
+        let mut im = unsafe { dev.alloc::<T>(dst_el) }.w()?;
+
+        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
+        let params = (dst_el, l_out, l_in, c_out, k_size, stride, col, &mut im);
+        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
+        unsafe { func.launch(cfg, params) }.w()?;
+        Ok(im)
+    }
+}
+
 struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
 impl<'a> Map2 for ConvTranspose1D<'a> {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
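The col2im fast path only applies when dilation, padding and output padding are all trivial; the transposed convolution then reduces to a batched matmul followed by the col2im1d scatter-add above, and the output length follows directly from the stride and kernel size. A minimal sketch of that shape arithmetic (a hypothetical helper, not part of the change):

// Output length used by the col2im1d fast path, valid only when
// dilation == 1, padding == 0 and output_padding == 0.
fn conv_transpose1d_l_out(l_in: usize, stride: usize, k_size: usize) -> usize {
    (l_in - 1) * stride + k_size
}

fn main() {
    // e.g. l_in = 4, stride = 2, k_size = 3 -> l_out = 9
    assert_eq!(conv_transpose1d_l_out(4, 2, 3), 9);
}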
@@ -1366,9 +1391,55 @@ impl BackendStorage for CudaStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
+        const USE_COL2IM_CONV1D_TR: bool = true;
+
         let device = self.device().clone();
-        let slice =
-            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
+        let can_use_col2im = kernel_l.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
+        let slice = if USE_COL2IM_CONV1D_TR && can_use_col2im {
+            let (b_size, c_in, l_in) = l.shape().dims3()?;
+            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
+            if !kernel_l.is_contiguous() {
+                crate::bail!(
+                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
+                )
+            }
+            if c_in != c_in2 {
+                crate::bail!(
+                    "convtr1d: shape mismatch on c_in {:?} {:?}",
+                    l.shape(),
+                    kernel_l.shape()
+                )
+            }
+            let col = {
+                // This merges the last two dimensions of the kernel together.
+                let kernel_l_mm = Layout::new(
+                    (b_size, c_in, k_size * c_out).into(),
+                    vec![0, k_size * c_out, 1],
+                    kernel_l.start_offset(),
+                );
+                self.matmul(
+                    kernel,
+                    (
+                        b_size,
+                        /* m */ l_in,
+                        /* n */ c_out * k_size,
+                        /* k */ c_in,
+                    ),
+                    &l.transpose(1, 2)?,
+                    &kernel_l_mm,
+                )?
+            };
+            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
+            Col2Im1D {
+                stride: params.stride,
+            }
+            .map(&col.slice, &device, &col_l)?
+        } else {
+            ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?
+        };
         Ok(Self { slice, device })
     }

@@ -54,6 +54,44 @@ pub trait Map2 {
     }
 }

+pub trait Map3 {
+    #[allow(clippy::too_many_arguments)]
+    fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
+        &self,
+        src1: &CudaSlice<T>,
+        layout1: &Layout,
+        src2: &CudaSlice<T>,
+        layout2: &Layout,
+        src3: &CudaSlice<T>,
+        layout3: &Layout,
+        dev: &CudaDevice,
+    ) -> Result<CudaSlice<T>>;
+
+    #[allow(clippy::too_many_arguments)]
+    fn map(
+        &self,
+        s1: &S,
+        l1: &Layout,
+        s2: &S,
+        l2: &Layout,
+        s3: &S,
+        l3: &Layout,
+        d: &CudaDevice,
+    ) -> Result<S> {
+        let out = match (s1, s2, s3) {
+            (S::U8(s1), S::U8(s2), S::U8(s3)) => S::U8(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::U32(s1), S::U32(s2), S::U32(s3)) => S::U32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::I64(s1), S::I64(s2), S::I64(s3)) => S::I64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::BF16(s1), S::BF16(s2), S::BF16(s3)) => S::BF16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F16(s1), S::F16(s2), S::F16(s3)) => S::F16(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F32(s1), S::F32(s2), S::F32(s3)) => S::F32(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            (S::F64(s1), S::F64(s2), S::F64(s3)) => S::F64(self.f(s1, l1, s2, l2, s3, l3, d)?),
+            _ => Err(CudaError::InternalError("dtype mismatch in ternary op"))?,
+        };
+        Ok(out)
+    }
+}
+
 pub trait Map2InPlace {
     fn f<T: DeviceRepr + WithDType + ValidAsZeroBits>(
         &self,
@@ -824,8 +824,64 @@ impl BackendStorage for MetalStorage {
         k_layout: &Layout,
         params: &ParamsConvTranspose1D,
     ) -> Result<Self> {
+        const USE_COL2IM_CONV1D_TR: bool = true;
+
+        let can_use_col2im = k_layout.is_contiguous()
+            && params.dilation == 1
+            && params.padding == 0
+            && params.output_padding == 0;
         let l_out = params.l_out();
         let dst_el = params.c_out * l_out * params.b_size;
+
+        let buffer = if USE_COL2IM_CONV1D_TR && can_use_col2im {
+            let (b_size, c_in, l_in) = layout.shape().dims3()?;
+            let (c_in2, c_out, k_size) = k_layout.shape().dims3()?;
+            if c_in != c_in2 {
+                crate::bail!(
+                    "convtr1d: shape mismatch on c_in {:?} {:?}",
+                    layout.shape(),
+                    k_layout.shape()
+                )
+            }
+            let buffer = self
+                .device
+                .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
+
+            let command_buffer = self.device.command_buffer()?;
+            let name = match self.dtype {
+                DType::F32 => "col2im1d_f32",
+                DType::U32 => "col2im1d_u32",
+                DType::U8 => "col2im1d_u8",
+                dtype => crate::bail!("metal col2im1d {dtype:?} not implemented"),
+            };
+            let col = {
+                // This merges the last two dimensions of the kernel together.
+                let kernel_l_mm = Layout::new(
+                    (b_size, c_in, k_size * c_out).into(),
+                    vec![0, k_size * c_out, 1],
+                    k_layout.start_offset(),
+                );
+                self.matmul(
+                    k,
+                    (b_size, l_in, c_out * k_size, c_in),
+                    &layout.transpose(1, 2)?,
+                    &kernel_l_mm,
+                )?
+            };
+            candle_metal_kernels::call_col2im1d(
+                &self.device.device,
+                &command_buffer,
+                &self.device.kernels,
+                name,
+                &[b_size, l_in, c_out, k_size],
+                params.k_size,
+                params.stride,
+                BufferOffset::zero_offset(&col.buffer),
+                &buffer,
+            )
+            .map_err(MetalError::from)?;
+            buffer
+        } else {
         let buffer = self
             .device
             .new_buffer(dst_el, self.dtype, "conv_transpose1d")?;
@@ -862,6 +918,8 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
+            buffer
+        };
         Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
     }

@@ -141,6 +141,8 @@ enum WhichModel {
     V2,
     #[value(name = "3")]
     V3,
+    #[value(name = "3-medium")]
+    V3Medium,
     #[value(name = "2-old")]
     V2Old,
     PuffinPhiV2,
@@ -254,6 +256,7 @@ fn main() -> Result<()> {
         WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
         WhichModel::V2 | WhichModel::V2Old => "microsoft/phi-2".to_string(),
         WhichModel::V3 => "microsoft/Phi-3-mini-4k-instruct".to_string(),
+        WhichModel::V3Medium => "microsoft/Phi-3-medium-4k-instruct".to_string(),
         WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
             "lmz/candle-quantized-phi".to_string()
         }
@@ -273,6 +276,7 @@ fn main() -> Result<()> {
         WhichModel::V2Old => "834565c23f9b28b96ccbeabe614dd906b6db551a".to_string(),
         WhichModel::V2
         | WhichModel::V3
+        | WhichModel::V3Medium
         | WhichModel::PuffinPhiV2
         | WhichModel::PhiHermes => "main".to_string(),
     }
@@ -287,7 +291,8 @@ fn main() -> Result<()> {
         | WhichModel::V1_5
         | WhichModel::V2
         | WhichModel::V2Old
-        | WhichModel::V3 => repo.get("tokenizer.json")?,
+        | WhichModel::V3
+        | WhichModel::V3Medium => repo.get("tokenizer.json")?,
         WhichModel::PuffinPhiV2 | WhichModel::PhiHermes => {
             repo.get("tokenizer-puffin-phi-v2.json")?
         }
@@ -303,14 +308,14 @@ fn main() -> Result<()> {
         WhichModel::V2 | WhichModel::V2Old => vec![repo.get("model-v2-q4k.gguf")?],
         WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
         WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
-        WhichModel::V3 => anyhow::bail!(
+        WhichModel::V3 | WhichModel::V3Medium => anyhow::bail!(
             "use the quantized or quantized-phi examples for quantized phi-v3"
         ),
     }
 } else {
     match args.model {
         WhichModel::V1 | WhichModel::V1_5 => vec![repo.get("model.safetensors")?],
-        WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 => {
+        WhichModel::V2 | WhichModel::V2Old | WhichModel::V3 | WhichModel::V3Medium => {
             candle_examples::hub_load_safetensors(
                 &repo,
                 "model.safetensors.index.json",
@@ -332,7 +337,7 @@ fn main() -> Result<()> {
         WhichModel::V2 | WhichModel::V2Old => Config::v2(),
         WhichModel::PuffinPhiV2 => Config::puffin_phi_v2(),
         WhichModel::PhiHermes => Config::phi_hermes_1_3b(),
-        WhichModel::V3 => {
+        WhichModel::V3 | WhichModel::V3Medium => {
             panic!("use the quantized or quantized-phi examples for quantized phi-v3")
         }
     };
@@ -352,7 +357,9 @@ fn main() -> Result<()> {
     let dtype = match args.dtype {
         Some(dtype) => std::str::FromStr::from_str(&dtype)?,
         None => {
-            if args.model == WhichModel::V3 && device.is_cuda() {
+            if (args.model == WhichModel::V3 || args.model == WhichModel::V3Medium)
+                && device.is_cuda()
+            {
                 DType::BF16
             } else {
                 DType::F32
@@ -368,7 +375,7 @@ fn main() -> Result<()> {
             let phi = Phi::new(&config, vb)?;
             Model::Phi(phi)
         }
-        WhichModel::V3 => {
+        WhichModel::V3 | WhichModel::V3Medium => {
             let config_filename = repo.get("config.json")?;
             let config = std::fs::read_to_string(config_filename)?;
             let config: Phi3Config = serde_json::from_str(&config)?;
@@ -217,7 +217,6 @@ fn main() -> anyhow::Result<()> {
     match args.which {
         Which::Phi2 => Model::Phi2(Phi2::from_gguf(model, &mut file, &device)?),
         Which::Phi3 => Model::Phi3(Phi3::from_gguf(
-            1,
             args.use_flash_attn,
             model,
             &mut file,
@@ -97,6 +97,50 @@ __device__ void im2col1d(
     }
 }

+template <typename T>
+__device__ void col2im1d(
+    const size_t dst_el,
+    const size_t l_out,
+    const size_t l_in,
+    const size_t c_out,
+    const size_t k_size,
+    const size_t stride,
+    const T *src,
+    T *dst
+) {
+  const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
+  // src: (b_size, l_in, c_out, l_k)
+  // dst: (b_size, c_out, l_out)
+  if (dst_i >= dst_el) {
+    return;
+  }
+
+  const size_t dst_s0 = c_out * l_out;
+  const size_t dst_s1 = l_out;
+  const size_t src_s0 = c_out * k_size * l_in;
+  const size_t src_s1 = c_out * k_size;
+  const size_t src_s2 = k_size;
+
+  size_t tmp_dst_i = dst_i;
+  const size_t b_idx = tmp_dst_i / dst_s0;
+  tmp_dst_i -= b_idx * dst_s0;
+  const size_t c_idx = tmp_dst_i / dst_s1;
+  tmp_dst_i -= c_idx * dst_s1;
+  const int l_out_idx = tmp_dst_i;
+
+  dst[dst_i] = static_cast<T>(0);
+
+  int l_in_idx = l_out_idx / stride;
+  int k0 = l_out_idx - l_in_idx * stride;
+  // l_out_idx = l_in_idx * stride + k0
+  for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
+    if (l_in_idx < l_in) {
+      const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
+      dst[dst_i] += src[src_i];
+    }
+  }
+}
+
 template <typename T>
 __device__ void im2col(
     const size_t dst_numel,
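Each thread of the kernel above owns one destination element and gathers every (l_in_idx, k0) pair satisfying l_out_idx = l_in_idx * stride + k0. A plain CPU sketch of the same accumulation, written as the equivalent forward scatter and assuming the layouts noted in the kernel comments (illustrative only, not part of the change):

// src is laid out as (b_size, l_in, c_out, k_size), dst as (b_size, c_out, l_out).
fn col2im1d_ref(src: &[f32], b_size: usize, l_in: usize, c_out: usize, k_size: usize, stride: usize) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                    let dst_i = (b * c_out + c) * l_out + l * stride + k;
                    // Overlapping windows accumulate, which is exactly what the
                    // gather loop in the CUDA/Metal kernels computes per dst element.
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }
    dst
}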
@@ -542,6 +586,20 @@ extern "C" __global__ void FN_NAME( \
     im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
 } \

+#define COL2IM1D_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const size_t dst_el, \
+    const size_t l_out, \
+    const size_t l_in, \
+    const size_t c_out, \
+    const size_t k_size, \
+    const size_t stride, \
+    const TYPENAME *src, \
+    TYPENAME *dst \
+) { \
+  col2im1d<TYPENAME>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst); \
+} \
+
 #define IM2COL_OP(TYPENAME, FN_NAME) \
 extern "C" __global__ void FN_NAME( \
     const size_t dst_numel, \
@@ -643,6 +701,7 @@ MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
 UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
 IM2COL_OP(__nv_bfloat16, im2col_bf16)
 IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
+COL2IM1D_OP(__nv_bfloat16, col2im1d_bf16)
 #endif

 #if __CUDA_ARCH__ >= 530
@@ -655,6 +714,7 @@ MAX_POOL2D_OP(__half, max_pool2d_f16)
 UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
 IM2COL_OP(__half, im2col_f16)
 IM2COL1D_OP(__half, im2col1d_f16)
+COL2IM1D_OP(__half, col2im1d_f16)
 #endif

 CONV1D_OP(float, float, conv1d_f32)
@@ -701,3 +761,8 @@ IM2COL1D_OP(float, im2col1d_f32)
 IM2COL1D_OP(double, im2col1d_f64)
 IM2COL1D_OP(uint8_t, im2col1d_u8)
 IM2COL1D_OP(uint32_t, im2col1d_u32)
+
+COL2IM1D_OP(float, col2im1d_f32)
+COL2IM1D_OP(double, col2im1d_f64)
+COL2IM1D_OP(uint8_t, col2im1d_u8)
+COL2IM1D_OP(uint32_t, col2im1d_u32)
@@ -50,6 +50,15 @@ fast_sum(const size_t src_numel, const size_t el_to_sum_per_block,
   dst[dst_id] = shr[0];
 }

+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+  for (int mask = 16; mask > 0; mask >>= 1) {
+    a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+    a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+  }
+  return a;
+}
+
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
 #pragma unroll
   for (int mask = 16; mask > 0; mask >>= 1) {
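The float2 overload reduces two accumulators (sum and sum of squares) in lockstep with the same xor butterfly as the scalar version. A small CPU simulation of that butterfly over one 32-lane warp, for intuition only:

// After the five xor steps every lane holds the sum of all 32 inputs.
fn warp_reduce_sum_sim(mut lanes: [f32; 32]) -> [f32; 32] {
    let mut mask = 16usize;
    while mask > 0 {
        let prev = lanes;
        for lane in 0..32 {
            lanes[lane] = prev[lane] + prev[lane ^ mask];
        }
        mask >>= 1;
    }
    lanes
}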
@@ -58,6 +67,70 @@ static __device__ __forceinline__ float warp_reduce_sum(float x) {
   return x;
 }

+// LayerNorm implementation adapted from ggml, accumulation is made using f32.
+// https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L477
+template <typename T>
+__device__ void layernorm(const T * x, T * dst, const T * alpha, const T * beta, const int ncols, const float eps) {
+  const int row = blockIdx.x*blockDim.y + threadIdx.y;
+  const int tid = threadIdx.x;
+  const int block_size = blockDim.x;
+
+  float2 mean_var = make_float2(0.f, 0.f);
+
+  for (int col = tid; col < ncols; col += block_size) {
+    const float xi = x[row*ncols + col];
+    mean_var.x += xi;
+    mean_var.y += xi * xi;
+  }
+
+  // sum up partial sums
+  mean_var = warp_reduce_sum(mean_var);
+  if (block_size > WARP_SIZE) {
+    __shared__ float2 s_sum[32];
+    int warp_id = threadIdx.x / WARP_SIZE;
+    int lane_id = threadIdx.x % WARP_SIZE;
+    if (lane_id == 0) {
+      s_sum[warp_id] = mean_var;
+    }
+    __syncthreads();
+    mean_var = s_sum[lane_id];
+    mean_var = warp_reduce_sum(mean_var);
+  }
+
+  const float mean = mean_var.x / ncols;
+  const float var = mean_var.y / ncols - mean * mean;
+  const float inv_std = rsqrtf(var + eps);
+
+  if (alpha == nullptr && beta == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs);
+    }
+  }
+  else if (alpha == nullptr && beta != nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float b = static_cast<float>(beta[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs + b);
+    }
+  }
+  else if (alpha != nullptr && beta == nullptr) {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs * a);
+    }
+  }
+  else {
+    for (int col = tid; col < ncols; col += block_size) {
+      float a = static_cast<float>(alpha[col]);
+      float b = static_cast<float>(beta[col]);
+      float lhs = (static_cast<float>(x[row*ncols + col]) - mean) * inv_std;
+      dst[row*ncols + col] = static_cast<T>(lhs * a + b);
+    }
+  }
+}
+
 // RmsNorm implementation adapted from ggml, accumulation is made using f32.
 // https://github.com/ggerganov/llama.cpp/blob/d59bd97065cd7ded6c4ecab54b1d5e0b1b11e318/ggml-cuda.cu#L523
 template <typename T>
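The kernel accumulates the sum and the sum of squares in one pass and derives the statistics from them (mean = Σx/n, var = Σx²/n − mean²), so no second read of the row is needed before normalizing. The per-row math as a scalar Rust sketch, mirroring the cpu_fwd code later in this change (illustrative only):

fn layer_norm_row(x: &[f32], alpha: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let n = x.len() as f32;
    // Single pass: sum and sum of squares.
    let (sum, sum2) = x.iter().fold((0f32, 0f32), |(s, s2), &v| (s + v, s2 + v * v));
    let mean = sum / n;
    let var = sum2 / n - mean * mean;
    let inv_std = (var + eps).sqrt().recip();
    x.iter()
        .zip(alpha.iter().zip(beta.iter()))
        .map(|(&v, (&a, &b))| (v - mean) * inv_std * a + b)
        .collect()
}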
@@ -461,6 +534,13 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
     rmsnorm<TYPENAME>(src, dst, alpha, n_cols, eps); \
 } \

+#define LAYERNORM_OP(TYPENAME, FN_NAME) \
+extern "C" __global__ void FN_NAME( \
+    const TYPENAME *src, TYPENAME *dst, const TYPENAME *alpha, \
+    const TYPENAME *beta, const int n_cols, const float eps) { \
+  layernorm<TYPENAME>(src, dst, alpha, beta, n_cols, eps); \
+} \
+
 #define ROPE_OP(TYPENAME, FN_NAME, FN_NAME_I, FN_NAME_THD) \
 extern "C" __global__ void FN_NAME_I( \
     const TYPENAME *src, \
@@ -496,6 +576,7 @@ fast_argmax(const size_t src_numel, const size_t el_to_sum_per_block,
 #if __CUDA_ARCH__ >= 800
 SOFTMAX_OP(__nv_bfloat16, float, softmax_bf16)
 RMSNORM_OP(__nv_bfloat16, rmsnorm_bf16)
+LAYERNORM_OP(__nv_bfloat16, layernorm_bf16)
 ROPE_OP(__nv_bfloat16, rope_bf16, rope_i_bf16, rope_thd_bf16)
 SUM_OP(__nv_bfloat16, sum_bf16)
 FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
@@ -504,6 +585,7 @@ FAST_OP(__nv_bfloat16, fast_min_bf16, fast_max_bf16, fast_argmin_bf16, fast_argmax_bf16, fast_sum_bf16)
 #if __CUDA_ARCH__ >= 530
 SOFTMAX_OP(__half, float, softmax_f16)
 RMSNORM_OP(__half, rmsnorm_f16)
+LAYERNORM_OP(__half, layernorm_f16)
 ROPE_OP(__half, rope_f16, rope_i_f16, rope_thd_f16)
 SUM_OP(__half, sum_f16)
 FAST_OP(__half, fast_min_f16, fast_max_f16, fast_argmin_f16, fast_argmax_f16, fast_sum_f16)
@@ -516,6 +598,8 @@ SOFTMAX_OP(float, float, softmax_f32)
 SOFTMAX_OP(double, double, softmax_f64)
 RMSNORM_OP(float, rmsnorm_f32)
 RMSNORM_OP(double, rmsnorm_f64)
+LAYERNORM_OP(float, layernorm_f32)
+LAYERNORM_OP(double, layernorm_f64)
 ROPE_OP(float, rope_f32, rope_i_f32, rope_thd_f32)
 ROPE_OP(double, rope_f64, rope_i_f64, rope_thd_f64)

@@ -68,6 +68,50 @@ METAL_FUNC void im2col(
     }
 }

+template <typename T>
+METAL_FUNC void col2im1d(
+    constant size_t &dst_el,
+    constant size_t &l_out,
+    constant size_t &l_in,
+    constant size_t &c_out,
+    constant size_t &k_size,
+    constant size_t &stride,
+    device const T *src,
+    device T *dst,
+    uint dst_i [[ thread_position_in_grid ]]
+) {
+  // src: (b_size, l_in, c_out, l_k)
+  // dst: (b_size, c_out, l_out)
+  if (dst_i >= dst_el) {
+    return;
+  }
+
+  const size_t dst_s0 = c_out * l_out;
+  const size_t dst_s1 = l_out;
+  const size_t src_s0 = c_out * k_size * l_in;
+  const size_t src_s1 = c_out * k_size;
+  const size_t src_s2 = k_size;
+
+  size_t tmp_dst_i = dst_i;
+  const size_t b_idx = tmp_dst_i / dst_s0;
+  tmp_dst_i -= b_idx * dst_s0;
+  const size_t c_idx = tmp_dst_i / dst_s1;
+  tmp_dst_i -= c_idx * dst_s1;
+  const int l_out_idx = tmp_dst_i;
+
+  dst[dst_i] = static_cast<T>(0);
+
+  int l_in_idx = l_out_idx / stride;
+  int k0 = l_out_idx - l_in_idx * stride;
+  // l_out_idx = l_in_idx * stride + k0
+  for (; k0 < k_size && l_in_idx >= 0; k0 += stride, --l_in_idx) {
+    if (l_in_idx < l_in) {
+      const size_t src_i = b_idx * src_s0 + l_in_idx * src_s1 + c_idx * src_s2 + k0;
+      dst[dst_i] += src[src_i];
+    }
+  }
+}
+
 template <typename T>
 METAL_FUNC void im2col1d(
     constant size_t &dst_numel,
@@ -191,6 +235,21 @@ kernel void FN_NAME( \
     im2col1d<T>(dst_numel, l_out, l_k, stride, padding, dilation, src_dims, src_strides, src, dst, tid); \
 } \

+#define COL2IM1D_OP(T, FN_NAME) \
+kernel void FN_NAME( \
+    constant size_t &dst_el, \
+    constant size_t &l_out, \
+    constant size_t &l_in, \
+    constant size_t &c_out, \
+    constant size_t &k_size, \
+    constant size_t &stride, \
+    device const T *src, \
+    device T *dst, \
+    uint tid [[ thread_position_in_grid ]] \
+) { \
+  col2im1d<T>(dst_el, l_out, l_in, c_out, k_size, stride, src, dst, tid); \
+} \
+
 #define UPSAMPLE_NEAREST2D_OP(TYPENAME, FN_NAME) \
 kernel void FN_NAME( \
     constant size_t &w_out, \
@@ -493,6 +552,10 @@ IM2COL_OP(uint32_t, im2col_u32)
 IM2COL_OP(bfloat, im2col_bf16)
 #endif

+COL2IM1D_OP(float, col2im1d_f32)
+COL2IM1D_OP(uint8_t, col2im1d_u8)
+COL2IM1D_OP(uint32_t, col2im1d_u32)
+
 IM2COL1D_OP(float, im2col1d_f32)
 IM2COL1D_OP(uint8_t, im2col1d_u8)
 IM2COL1D_OP(uint32_t, im2col1d_u32)
@@ -739,6 +739,69 @@ pub fn call_rms_norm(

     encoder.use_resource(input, metal::MTLResourceUsage::Read);
     encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.set_threadgroup_memory_length(0, (width * 4).max(16) as u64);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+    Ok(())
+}
+
+#[allow(clippy::too_many_arguments)]
+pub fn call_layer_norm(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    kernel_name: &'static str,
+    length: usize,
+    elements_to_sum: usize,
+    eps: f32,
+    input: &Buffer,
+    input_offset: usize,
+    alpha: &Buffer,
+    alpha_offset: usize,
+    beta: &Buffer,
+    beta_offset: usize,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Reduce, kernel_name)?;
+    let encoder = command_buffer.new_compute_command_encoder();
+    encoder.set_compute_pipeline_state(&pipeline);
+
+    set_params!(
+        encoder,
+        (
+            length,
+            elements_to_sum,
+            (input, input_offset),
+            output,
+            (alpha, alpha_offset),
+            (beta, beta_offset),
+            eps
+        )
+    );
+
+    let out_length = length / elements_to_sum;
+
+    let thread_group_count = MTLSize {
+        width: out_length as u64,
+        height: 1,
+        depth: 1,
+    };
+
+    let width = std::cmp::min(
+        pipeline.max_total_threads_per_threadgroup(),
+        elements_to_sum as u64,
+    )
+    .next_power_of_two();
+
+    let thread_group_size = MTLSize {
+        width,
+        height: 1,
+        depth: 1,
+    };
+
+    encoder.use_resource(input, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.set_threadgroup_memory_length(0, (width * 8).max(32) as u64);
     encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
     encoder.end_encoding();
     Ok(())
@@ -1588,6 +1651,39 @@ pub fn call_im2col1d_strided(
     Ok(())
 }

+#[allow(clippy::too_many_arguments)]
+pub fn call_col2im1d(
+    device: &Device,
+    command_buffer: &CommandBufferRef,
+    kernels: &Kernels,
+    name: &'static str,
+    shape: &[usize],
+    k_size: usize,
+    stride: usize,
+    input: BufferOffset,
+    output: &Buffer,
+) -> Result<(), MetalKernelError> {
+    let pipeline = kernels.load_pipeline(device, Source::Conv, name)?;
+    let l_in = shape[1];
+    let c_out = shape[2];
+    let l_out = (l_in - 1) * stride + k_size;
+    let dst_el = shape[0] * c_out * l_out;
+
+    let encoder = command_buffer.new_compute_command_encoder();
+    let (thread_group_count, thread_group_size) = linear_split(&pipeline, dst_el);
+    encoder.set_compute_pipeline_state(&pipeline);
+    set_params!(
+        encoder,
+        (dst_el, l_out, l_in, c_out, k_size, stride, &input, output)
+    );
+    encoder.use_resource(input.buffer, metal::MTLResourceUsage::Read);
+    encoder.use_resource(output, metal::MTLResourceUsage::Write);
+    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
+    encoder.end_encoding();
+
+    Ok(())
+}
+
 #[allow(clippy::too_many_arguments)]
 pub fn call_im2col_strided(
     device: &Device,
@@ -353,6 +353,65 @@ METAL_FUNC void rmsnorm(
     }
 }

+template<typename T>
+METAL_FUNC void layernorm(
+    constant size_t & src_numel,
+    constant size_t & el_to_sum_per_block,
+    device const T * src,
+    device T * dst,
+    device const T * alpha,
+    device const T * beta,
+    constant float & eps,
+    uint id,
+    uint tid,
+    uint dst_id,
+    uint block_dim,
+    threadgroup float * shared_memory
+) {
+  size_t start_idx = dst_id * el_to_sum_per_block;
+  size_t stop_idx = min(start_idx + el_to_sum_per_block, src_numel);
+  size_t idx = start_idx + tid;
+
+  float tmp1 = 0;
+  float tmp2 = 0;
+  while (idx < stop_idx) {
+    tmp1 += float(src[idx]);
+    tmp2 += float(src[idx]) * float(src[idx]);
+    idx += block_dim;
+  }
+  shared_memory[tid] = tmp1;
+  shared_memory[tid + block_dim] = tmp2;
+
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  for (uint s = block_dim / 2; s > 0; s >>= 1) {
+    if (tid < s) {
+      shared_memory[tid] = shared_memory[tid] + shared_memory[tid + s];
+      shared_memory[block_dim + tid] = shared_memory[block_dim + tid] + shared_memory[block_dim + tid + s];
+    }
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+  }
+
+  /* wait for shared_memory[0] to be filled */
+  threadgroup_barrier(mem_flags::mem_threadgroup);
+
+  float mean = shared_memory[0] / float(el_to_sum_per_block);
+  float var = shared_memory[block_dim] / float(el_to_sum_per_block) - mean * mean;
+  float inv_norm = 1.0f / sqrt(var + eps);
+  idx = start_idx + tid;
+  while (idx < stop_idx) {
+    float val = (float(src[idx]) - mean) * inv_norm;
+    if (alpha != nullptr) {
+      val *= float(alpha[idx - start_idx]);
+    }
+    if (beta != nullptr) {
+      val += float(beta[idx - start_idx]);
+    }
+    dst[idx] = T(val);
+    idx += block_dim;
+  }
+}
+
 #define RMSNORM(NAME, T) \
 kernel void NAME( \
     constant size_t &src_numel, \
|
|||||||
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
|
rmsnorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, eps, id, tid, dst_id, block_dim, shared_memory); \
|
||||||
} \
|
} \
|
||||||
|
|
||||||
|
#define LAYERNORM(NAME, T) \
|
||||||
|
kernel void NAME( \
|
||||||
|
constant size_t &src_numel, \
|
||||||
|
constant size_t &el_to_sum_per_block, \
|
||||||
|
device const T *src, \
|
||||||
|
device T *dst, \
|
||||||
|
device const T *alpha, \
|
||||||
|
device const T *beta, \
|
||||||
|
constant float &eps, \
|
||||||
|
uint id [[ thread_position_in_grid ]], \
|
||||||
|
uint tid [[ thread_index_in_threadgroup ]], \
|
||||||
|
uint dst_id [[ threadgroup_position_in_grid ]], \
|
||||||
|
uint block_dim [[ threads_per_threadgroup ]] \
|
||||||
|
) { \
|
||||||
|
threadgroup float shared_memory[THREADGROUP_SIZE]; \
|
||||||
|
shared_memory[tid] = 0; \
|
||||||
|
layernorm<T>(src_numel, el_to_sum_per_block, src, dst, alpha, beta, eps, id, tid, dst_id, block_dim, shared_memory); \
|
||||||
|
} \
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
METAL_FUNC void ropei(
|
METAL_FUNC void ropei(
|
||||||
constant size_t &bh,
|
constant size_t &bh,
|
||||||
@@ -511,6 +589,8 @@ SOFTMAX(softmax_f32, float)
 SOFTMAX(softmax_f16, half)
 RMSNORM(rmsnorm_f32, float)
 RMSNORM(rmsnorm_f16, half)
+LAYERNORM(layernorm_f32, float)
+LAYERNORM(layernorm_f16, half)
 ROPE(rope_f32, rope_i_f32, rope_thd_f32, float)
 ROPE(rope_f16, rope_i_f16, rope_thd_f16, half)

|
|||||||
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
ARGMAX(fast_argmax_bf16, bfloat, -HUGE_VALBF)
|
||||||
SOFTMAX(softmax_bf16, bfloat)
|
SOFTMAX(softmax_bf16, bfloat)
|
||||||
RMSNORM(rmsnorm_bf16, bfloat)
|
RMSNORM(rmsnorm_bf16, bfloat)
|
||||||
|
LAYERNORM(layernorm_bf16, bfloat)
|
||||||
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
|
ROPE(rope_bf16, rope_i_bf16, rope_thd_bf16, bfloat)
|
||||||
#endif
|
#endif
|
||||||
|
@@ -1,5 +1,4 @@
 #include <metal_stdlib>
-#
 using namespace metal;

 METAL_FUNC uint get_strided_index(
@@ -57,27 +56,31 @@ kernel void FN_NAME(
     where_cond<T, ID>(numel, num_dims, dims, strides, strides_t, strides_f, ids, t, f, out, i); \
 } \

-// WHERE_OP(float, int64_t, where_i64_f32)
-// WHERE_OP(double, int64_t, where_i64_f64)
-// WHERE_OP(uint8_t, int64_t, where_i64_u8)
-// WHERE_OP(uint32_t, int64_t, where_i64_u32)
-// WHERE_OP(int64_t, int64_t, where_i64_i64)
-//
-// WHERE_OP(float, uint32_t, where_u32_f32)
-// WHERE_OP(double, uint32_t, where_u32_f64)
-// WHERE_OP(uint8_t, uint32_t, where_u32_u8)
-// WHERE_OP(uint32_t, uint32_t, where_u32_u32)
-// WHERE_OP(int64_t, uint32_t, where_u32_i64)
+WHERE_OP(half, uint32_t, where_u32_f16)
+WHERE_OP(float, uint32_t, where_u32_f32)
+WHERE_OP(uint8_t, uint32_t, where_u32_u8)
+WHERE_OP(uint32_t, uint32_t, where_u32_u32)

-WHERE_OP(float, uint8_t, where_u8_f32)
 WHERE_OP(half, uint8_t, where_u8_f16)
+WHERE_OP(float, uint8_t, where_u8_f32)
 WHERE_OP(uint8_t, uint8_t, where_u8_u8)
 WHERE_OP(uint32_t, uint8_t, where_u8_u32)

 #if __METAL_VERSION__ >= 220
 WHERE_OP(int64_t, uint8_t, where_u8_i64)
+WHERE_OP(int64_t, uint32_t, where_u32_i64)
+
+WHERE_OP(half, int64_t, where_i64_f16)
+WHERE_OP(float, int64_t, where_i64_f32)
+WHERE_OP(uint8_t, int64_t, where_i64_u8)
+WHERE_OP(uint32_t, int64_t, where_i64_u32)
+WHERE_OP(int64_t, int64_t, where_i64_i64)
+#if defined(__HAVE_BFLOAT__)
+WHERE_OP(bfloat, int64_t, where_i64_bf16)
+#endif
 #endif

 #if defined(__HAVE_BFLOAT__)
 WHERE_OP(bfloat, uint8_t, where_u8_bf16)
+WHERE_OP(bfloat, uint32_t, where_u32_bf16)
 #endif
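These instantiations wire up where_cond for i64 (and additional u32) condition tensors on Metal, which previously only had the u8 and a partial u32 set. A hedged usage sketch; the same call already works on the CPU backend, and with the kernels above a Metal device should accept the i64 mask as well:

use candle::{Device, Result, Tensor};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    let cond = Tensor::new(&[1i64, 0, 1], &dev)?;
    let on_true = Tensor::new(&[1f32, 2., 3.], &dev)?;
    let on_false = Tensor::new(&[-1f32, -2., -3.], &dev)?;
    // Picks from on_true where the condition is non-zero, from on_false otherwise.
    let out = cond.where_cond(&on_true, &on_false)?;
    assert_eq!(out.to_vec1::<f32>()?, [1., -2., 3.]);
    Ok(())
}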
|
@ -1,30 +1,25 @@
|
|||||||
use candle::{DType, Device, Result, Shape, Tensor};
|
use candle::{Result, Tensor};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct Cache {
|
pub struct Cache {
|
||||||
all_data: Tensor,
|
// all_data is an option on a Tensor, this makes it possible to only create the actual tensor
|
||||||
|
// on the first call where the batch size is easily known.
|
||||||
|
// Also this makes it safe to clone a KvCache that has been reseted (as in it will not share
|
||||||
|
// its internal state with the cloned instance).
|
||||||
|
all_data: Option<Tensor>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
current_seq_len: usize,
|
current_seq_len: usize,
|
||||||
max_seq_len: usize,
|
max_seq_len: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Cache {
|
impl Cache {
|
||||||
pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
|
pub fn new(dim: usize, max_seq_len: usize) -> Self {
|
||||||
dim: D,
|
Self {
|
||||||
shape: S,
|
all_data: None,
|
||||||
dtype: DType,
|
|
||||||
dev: &Device,
|
|
||||||
) -> Result<Self> {
|
|
||||||
let shape = shape.into();
|
|
||||||
let dim = dim.to_index(&shape, "kv-cache")?;
|
|
||||||
let max_seq_len = shape.dims()[dim];
|
|
||||||
let all_data = Tensor::zeros(shape, dtype, dev)?;
|
|
||||||
Ok(Self {
|
|
||||||
all_data,
|
|
||||||
dim,
|
dim,
|
||||||
current_seq_len: 0,
|
current_seq_len: 0,
|
||||||
max_seq_len,
|
max_seq_len,
|
||||||
})
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn dim(&self) -> usize {
|
pub fn dim(&self) -> usize {
|
||||||
@@ -39,16 +34,34 @@ impl Cache {
         self.max_seq_len
     }

-    pub fn all_data(&self) -> &Tensor {
+    pub fn all_data(&self) -> &Option<Tensor> {
         &self.all_data
     }

-    pub fn current_data(&self) -> Result<Tensor> {
-        self.all_data.narrow(self.dim, 0, self.current_seq_len)
+    pub fn current_data(&self) -> Result<Option<Tensor>> {
+        let data = match self.all_data.as_ref() {
+            None => None,
+            Some(d) => Some(d.narrow(self.dim, 0, self.current_seq_len)?),
+        };
+        Ok(data)
+    }
+
+    pub fn reset(&mut self) {
+        self.current_seq_len = 0;
+        self.all_data = None;
     }

     pub fn append(&mut self, src: &Tensor) -> Result<()> {
         let seq_len = src.dim(self.dim)?;
+        // This doesn't seem very idiomatic but because the creation can fail, it's tricky to use
+        // self.all_data.get_or_insert_with.
+        if self.all_data.is_none() {
+            let mut shape = src.dims().to_vec();
+            shape[self.dim] = self.max_seq_len;
+            let ad = Tensor::zeros(shape, src.dtype(), src.device())?;
+            self.all_data = Some(ad)
+        };
+        let ad = self.all_data.as_mut().unwrap();
         if self.current_seq_len + seq_len > self.max_seq_len {
             candle::bail!(
                 "kv-cache: above max-seq-len {}+{seq_len}>{}",
@@ -56,8 +69,7 @@ impl Cache {
                 self.max_seq_len
             )
         }
-        self.all_data
-            .slice_set(src, self.dim, self.current_seq_len)?;
+        ad.slice_set(src, self.dim, self.current_seq_len)?;
         self.current_seq_len += seq_len;
         Ok(())
     }
@@ -70,32 +82,66 @@ pub struct KvCache {
 }

 impl KvCache {
-    pub fn new<S: Into<Shape>, D: candle::shape::Dim>(
-        dim: D,
-        shape: S,
-        dtype: DType,
-        dev: &Device,
-    ) -> Result<Self> {
-        let shape = shape.into();
-        let dim = dim.to_index(&shape, "kv-cache")?;
-        let k = Cache::new(dim, &shape, dtype, dev)?;
-        let v = Cache::new(dim, &shape, dtype, dev)?;
-        Ok(Self { k, v })
+    pub fn new(dim: usize, max_seq_len: usize) -> Self {
+        let k = Cache::new(dim, max_seq_len);
+        let v = Cache::new(dim, max_seq_len);
+        Self { k, v }
     }

-    pub fn k(&self) -> Result<Tensor> {
+    pub fn k_cache(&self) -> &Cache {
+        &self.k
+    }
+
+    pub fn v_cache(&self) -> &Cache {
+        &self.v
+    }
+
+    pub fn k_cache_mut(&mut self) -> &mut Cache {
+        &mut self.k
+    }
+
+    pub fn v_cache_mut(&mut self) -> &mut Cache {
+        &mut self.v
+    }
+
+    pub fn k(&self) -> Result<Option<Tensor>> {
         self.k.current_data()
     }

-    pub fn v(&self) -> Result<Tensor> {
+    pub fn v(&self) -> Result<Option<Tensor>> {
         self.v.current_data()
     }

     pub fn append(&mut self, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
         self.k.append(k)?;
         self.v.append(v)?;
-        let k = self.k.current_data()?;
-        let v = self.v.current_data()?;
+        let out_k = self.k.current_data()?;
+        let out_v = self.v.current_data()?;
+        let k = match out_k {
+            None => {
+                let mut shape = k.dims().to_vec();
+                shape[self.k.dim] = 0;
+                Tensor::zeros(shape, k.dtype(), k.device())?
+            }
+            Some(k) => k,
+        };
+        let v = match out_v {
+            None => {
+                let mut shape = v.dims().to_vec();
+                shape[self.k.dim] = 0;
+                Tensor::zeros(shape, v.dtype(), v.device())?
+            }
+            Some(v) => v,
+        };
         Ok((k, v))
     }
+
+    pub fn current_seq_len(&self) -> usize {
+        self.k.current_seq_len()
+    }
+
+    pub fn reset(&mut self) {
+        self.k.reset();
+        self.v.reset();
+    }
 }
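With the lazy allocation the cache can be built before the batch size or dtype is known; the backing tensors only appear on the first append, and reset() drops them so a cloned or reused cache starts fresh. A usage sketch under the new API (assuming the module is exported as candle_nn::kv_cache):

use candle::{DType, Device, Result, Tensor};
use candle_nn::kv_cache::KvCache;

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    // Cache along dim 2 of (batch, heads, seq, head_dim), for at most 512 positions.
    let mut cache = KvCache::new(2, 512);
    let k = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
    let v = Tensor::zeros((1, 8, 1, 64), DType::F32, &dev)?;
    // The backing tensors are created here, sized and typed from the first key/value pair.
    let (k_all, v_all) = cache.append(&k, &v)?;
    assert_eq!(k_all.dims(), &[1, 8, 1, 64]);
    assert_eq!(v_all.dims(), &[1, 8, 1, 64]);
    assert_eq!(cache.current_seq_len(), 1);
    // Dropping the state; the next append re-allocates the buffers.
    cache.reset();
    Ok(())
}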
@@ -11,8 +11,8 @@
 //! use candle_nn::{LayerNorm, Module};
 //! # fn main() -> candle::Result<()> {
 //!
-//! let w = Tensor::new(1f32, &Cpu)?;
-//! let b = Tensor::new(0f32, &Cpu)?;
+//! let w = Tensor::new(&[1f32, 1f32, 1f32], &Cpu)?;
+//! let b = Tensor::new(&[0f32, 0f32, 0f32], &Cpu)?;
 //! let layer = LayerNorm::new(w, b, 1e-5);
 //!
 //! let xs = Tensor::new(
@@ -107,6 +107,11 @@ impl LayerNorm {

 impl Module for LayerNorm {
     fn forward(&self, x: &Tensor) -> Result<Tensor> {
+        if x.is_contiguous() && self.remove_mean {
+            if let Some(bias) = self.bias.as_ref() {
+                return crate::ops::layer_norm(x, &self.weight, bias, self.eps as f32);
+            }
+        }
         let x_dtype = x.dtype();
         let internal_dtype = match x_dtype {
             DType::F16 | DType::BF16 => DType::F32,
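The module now routes contiguous inputs that have both a weight and a bias through the fused candle_nn::ops::layer_norm op added below; anything else still takes the original dtype-upcasting path. A short usage sketch along the lines of the updated doc example:

use candle::{Device, Result, Tensor};
use candle_nn::{LayerNorm, Module};

fn demo() -> Result<()> {
    let dev = Device::Cpu;
    let w = Tensor::new(&[1f32, 1f32, 1f32], &dev)?;
    let b = Tensor::new(&[0f32, 0f32, 0f32], &dev)?;
    let layer = LayerNorm::new(w, b, 1e-5);
    // Contiguous input with a bias present: this should hit the fused kernel path.
    let xs = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.]]], &dev)?;
    let ys = layer.forward(&xs)?;
    println!("{ys}");
    Ok(())
}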
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor};
+use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
 use rayon::prelude::*;

 /// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -39,7 +39,7 @@ pub fn silu(xs: &Tensor) -> Result<Tensor> {
 }

 pub fn swiglu(xs: &Tensor) -> Result<Tensor> {
-    let xs = xs.chunk(2, candle::D::Minus1)?;
+    let xs = xs.chunk(2, D::Minus1)?;
     &xs[0].silu()? * &xs[1]
 }

|
|||||||
DType::F16 | DType::BF16 => DType::F32,
|
DType::F16 | DType::BF16 => DType::F32,
|
||||||
d => d,
|
d => d,
|
||||||
};
|
};
|
||||||
let hidden_size = x.dim(candle::D::Minus1)?;
|
let hidden_size = x.dim(D::Minus1)?;
|
||||||
let x = x.to_dtype(internal_dtype)?;
|
let x = x.to_dtype(internal_dtype)?;
|
||||||
let norm_x = (x.sqr()?.sum_keepdim(candle::D::Minus1)? / hidden_size as f64)?;
|
let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
|
||||||
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
|
let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
|
||||||
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
|
x_normed.to_dtype(x_dtype)?.broadcast_mul(alpha)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
|
pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
|
||||||
let hidden_size_xs = xs.dim(candle::D::Minus1)?;
|
let hidden_size_xs = xs.dim(D::Minus1)?;
|
||||||
let hidden_size_alpha = alpha.dims1()?;
|
let hidden_size_alpha = alpha.dims1()?;
|
||||||
if hidden_size_xs != hidden_size_alpha {
|
if hidden_size_xs != hidden_size_alpha {
|
||||||
candle::bail!(
|
candle::bail!(
|
||||||
@ -640,6 +640,254 @@ pub fn rms_norm(xs: &Tensor, alpha: &Tensor, eps: f32) -> Result<Tensor> {
|
|||||||
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
|
xs.apply_op2_no_bwd(alpha, &RmsNorm { eps })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct LayerNorm {
|
||||||
|
eps: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl candle::CustomOp3 for LayerNorm {
|
||||||
|
fn name(&self) -> &'static str {
|
||||||
|
"layer-norm"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn cpu_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &CpuStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &CpuStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
s3: &CpuStorage,
|
||||||
|
l3: &Layout,
|
||||||
|
) -> Result<(CpuStorage, Shape)> {
|
||||||
|
use candle::backend::BackendStorage;
|
||||||
|
|
||||||
|
let eps = self.eps;
|
||||||
|
fn inner<
|
||||||
|
T: candle::WithDType
|
||||||
|
+ num_traits::Float
|
||||||
|
+ num_traits::AsPrimitive<f32>
|
||||||
|
+ num_traits::FromPrimitive,
|
||||||
|
>(
|
||||||
|
src: &[T],
|
||||||
|
layout: &Layout,
|
||||||
|
alpha: &[T],
|
||||||
|
alpha_layout: &Layout,
|
||||||
|
beta: &[T],
|
||||||
|
beta_layout: &Layout,
|
||||||
|
eps: f32,
|
||||||
|
) -> Result<(CpuStorage, Shape)> {
|
||||||
|
let src = match layout.contiguous_offsets() {
|
||||||
|
None => candle::bail!("input has to be contiguous"),
|
||||||
|
Some((o1, o2)) => &src[o1..o2],
|
||||||
|
};
|
||||||
|
let alpha = match alpha_layout.contiguous_offsets() {
|
||||||
|
None => candle::bail!("alpha has to be contiguous"),
|
||||||
|
Some((o1, o2)) => &alpha[o1..o2],
|
||||||
|
};
|
||||||
|
let beta = match beta_layout.contiguous_offsets() {
|
||||||
|
None => candle::bail!("beta has to be contiguous"),
|
||||||
|
Some((o1, o2)) => &beta[o1..o2],
|
||||||
|
};
|
||||||
|
let el_count = layout.shape().elem_count();
|
||||||
|
let dims = layout.shape().dims();
|
||||||
|
let dim_m1 = dims[dims.len() - 1];
|
||||||
|
let mut dst = vec![T::zero(); el_count];
|
||||||
|
src.par_chunks(dim_m1)
|
||||||
|
.zip(dst.par_chunks_mut(dim_m1))
|
||||||
|
.for_each(|(src, dst)| {
|
||||||
|
let mut sum = 0f32;
|
||||||
|
let mut sum2 = 0f32;
|
||||||
|
for v in src {
|
||||||
|
let v = v.as_();
|
||||||
|
sum += v;
|
||||||
|
sum2 += v * v;
|
||||||
|
}
|
||||||
|
let mean = sum / dim_m1 as f32;
|
||||||
|
let var = sum2 / dim_m1 as f32 - mean * mean;
|
||||||
|
let inv_std = (var + eps).sqrt().recip();
|
||||||
|
for ((d, s), (alpha, beta)) in
|
||||||
|
dst.iter_mut().zip(src.iter()).zip(alpha.iter().zip(beta))
|
||||||
|
{
|
||||||
|
let alpha = alpha.as_();
|
||||||
|
let beta = beta.as_();
|
||||||
|
let d_ = (s.as_() - mean) * inv_std * alpha + beta;
|
||||||
|
*d = T::from_f32(d_).unwrap_or_else(T::nan);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
let storage = candle::WithDType::to_cpu_storage_owned(dst);
|
||||||
|
Ok((storage, Shape::from_dims(dims)))
|
||||||
|
}
|
||||||
|
|
||||||
|
use CpuStorage as C;
|
||||||
|
match (s1, s2, s3) {
|
||||||
|
(C::BF16(s1), C::BF16(s2), C::BF16(s3)) => {
|
||||||
|
inner::<half::bf16>(s1, l1, s2, l2, s3, l3, eps)
|
||||||
|
}
|
||||||
|
(C::F16(s1), C::F16(s2), C::F16(s3)) => inner::<half::f16>(s1, l1, s2, l2, s3, l3, eps),
|
||||||
|
(C::F32(s1), C::F32(s2), C::F32(s3)) => inner::<f32>(s1, l1, s2, l2, s3, l3, eps),
|
||||||
|
_ => candle::bail!("unsupported dtype for rmsnorm {:?}", s1.dtype()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "cuda")]
|
||||||
|
fn cuda_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &candle::CudaStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &candle::CudaStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
s3: &candle::CudaStorage,
|
||||||
|
l3: &Layout,
|
||||||
|
) -> Result<(candle::CudaStorage, Shape)> {
|
||||||
|
use candle::cuda_backend::cudarc::driver::{
|
||||||
|
CudaSlice, DeviceRepr, LaunchAsync, LaunchConfig,
|
||||||
|
};
|
||||||
|
use candle::cuda_backend::{kernel_name, kernels, Map3, WrapErr};
|
||||||
|
use candle::{CudaDevice, WithDType};
|
||||||
|
|
||||||
|
struct S {
|
||||||
|
eps: f32,
|
||||||
|
}
|
||||||
|
impl Map3 for S {
|
||||||
|
fn f<T: DeviceRepr + WithDType>(
|
||||||
|
&self,
|
||||||
|
src: &CudaSlice<T>,
|
||||||
|
layout: &Layout,
|
||||||
|
alpha: &CudaSlice<T>,
|
||||||
|
alpha_layout: &Layout,
|
||||||
|
beta: &CudaSlice<T>,
|
||||||
|
beta_layout: &Layout,
|
||||||
|
dev: &CudaDevice,
|
||||||
|
) -> Result<CudaSlice<T>> {
|
||||||
|
let src = match layout.contiguous_offsets() {
|
||||||
|
None => candle::bail!("input has to be contiguous"),
|
||||||
|
Some((o1, o2)) => src.slice(o1..o2),
|
||||||
|
};
|
||||||
|
let alpha = match alpha_layout.contiguous_offsets() {
|
||||||
|
None => candle::bail!("alpha has to be contiguous"),
|
||||||
|
Some((o1, o2)) => alpha.slice(o1..o2),
|
||||||
|
};
|
||||||
|
let beta = match beta_layout.contiguous_offsets() {
|
||||||
|
None => candle::bail!("beta has to be contiguous"),
|
||||||
|
Some((o1, o2)) => beta.slice(o1..o2),
|
||||||
|
};
|
||||||
|
let el = layout.shape().elem_count();
|
||||||
|
let dims = layout.shape().dims();
|
||||||
|
let dim_m1 = dims[dims.len() - 1];
|
||||||
|
let (n_rows, n_cols) = (el / dim_m1, dim_m1);
|
||||||
|
|
||||||
|
let cfg = LaunchConfig {
|
||||||
|
grid_dim: (n_rows as u32, 1, 1),
|
||||||
|
block_dim: (1024, 1, 1),
|
||||||
|
shared_mem_bytes: 0,
|
||||||
|
};
|
||||||
|
let func = dev.get_or_load_func(&kernel_name::<T>("layernorm"), kernels::REDUCE)?;
|
||||||
|
// SAFETY: Set later by running the kernel.
|
||||||
|
let dst = unsafe { dev.alloc::<T>(el) }.w()?;
|
||||||
|
let params = (&src, &dst, &alpha, &beta, n_cols as i32, self.eps);
|
||||||
|
// SAFETY: ffi.
|
||||||
|
unsafe { func.launch(cfg, params) }.w()?;
|
||||||
|
Ok(dst)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
use candle::backend::BackendStorage;
|
||||||
|
let dev = s1.device();
|
||||||
|
let slice = S { eps: self.eps }.map(&s1.slice, l1, &s2.slice, l2, &s3.slice, l3, dev)?;
|
||||||
|
let dst = candle::cuda_backend::CudaStorage {
|
||||||
|
slice,
|
||||||
|
device: dev.clone(),
|
||||||
|
};
|
||||||
|
Ok((dst, l1.shape().clone()))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(feature = "metal")]
|
||||||
|
fn metal_fwd(
|
||||||
|
&self,
|
||||||
|
s1: &candle::MetalStorage,
|
||||||
|
l1: &Layout,
|
||||||
|
s2: &candle::MetalStorage,
|
||||||
|
l2: &Layout,
|
||||||
|
s3: &candle::MetalStorage,
|
||||||
|
l3: &Layout,
|
||||||
|
) -> Result<(candle::MetalStorage, Shape)> {
|
||||||
|
use candle::backend::BackendStorage;
|
||||||
|
let device = s1.device();
|
||||||
|
let command_buffer = device.command_buffer()?;
|
||||||
|
let kernels = device.kernels();
|
||||||
|
let name = match (s1.dtype(), s2.dtype(), s3.dtype()) {
|
||||||
|
(DType::F32, DType::F32, DType::F32) => "layernorm_f32",
|
||||||
|
(DType::F16, DType::F16, DType::F16) => "layernorm_f16",
|
||||||
|
(DType::BF16, DType::BF16, DType::BF16) => "layernorm_bf16",
|
||||||
|
(dt1, dt2, dt3) => {
|
||||||
|
candle::bail!("layernorm is not implemented for {dt1:?} {dt2:?} {dt3:?}")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !(l1.is_contiguous() && l2.is_contiguous() && l3.is_contiguous()) {
|
||||||
|
candle::bail!("Non contiguous layernorm is not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
let last_dim = l1.dims()[l1.shape().rank() - 1];
|
||||||
|
let elem_count = l1.shape().elem_count();
|
||||||
|
let output = device.new_buffer(elem_count, s1.dtype(), "layernorm")?;
|
||||||
|
candle_metal_kernels::call_layer_norm(
|
||||||
|
device.metal_device(),
|
||||||
|
&command_buffer,
|
||||||
|
kernels,
|
||||||
|
name,
|
||||||
|
elem_count,
|
||||||
|
last_dim,
|
||||||
|
self.eps,
|
||||||
|
s1.buffer(),
|
||||||
|
l1.start_offset() * s1.dtype().size_in_bytes(),
|
||||||
|
s2.buffer(),
|
||||||
|
l2.start_offset() * s2.dtype().size_in_bytes(),
|
||||||
|
s3.buffer(),
|
||||||
|
l3.start_offset() * s3.dtype().size_in_bytes(),
|
||||||
|
&output,
|
||||||
|
)
|
||||||
|
.map_err(candle::Error::wrap)?;
|
||||||
|
let newstorage = candle::MetalStorage::new(output, device.clone(), elem_count, s1.dtype());
|
||||||
|
Ok((newstorage, l1.shape().clone()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
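All three backends above implement the same per-row computation: subtract the row mean, scale by the reciprocal square root of the variance plus eps, then apply the learned scale and shift. A minimal standalone sketch of that arithmetic on a plain f32 row (the helper name is illustrative, not part of the candle API):

fn layer_norm_row(row: &[f32], alpha: &[f32], beta: &[f32], eps: f32) -> Vec<f32> {
    let n = row.len() as f32;
    // One pass to accumulate the sum and the sum of squares.
    let (sum, sum2) = row.iter().fold((0f32, 0f32), |(s, s2), &v| (s + v, s2 + v * v));
    let mean = sum / n;
    let var = sum2 / n - mean * mean;
    let inv_std = (var + eps).sqrt().recip();
    // Normalize, then apply the affine parameters element-wise.
    row.iter()
        .zip(alpha.iter().zip(beta.iter()))
        .map(|(&v, (&a, &b))| (v - mean) * inv_std * a + b)
        .collect()
}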
+pub fn layer_norm_slow(x: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+    let x_dtype = x.dtype();
+    let internal_dtype = match x_dtype {
+        DType::F16 | DType::BF16 => DType::F32,
+        d => d,
+    };
+    let hidden_size = x.dim(D::Minus1)?;
+    let x = x.to_dtype(internal_dtype)?;
+    let x = {
+        let mean_x = (x.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+        x.broadcast_sub(&mean_x)?
+    };
+    let norm_x = (x.sqr()?.sum_keepdim(D::Minus1)? / hidden_size as f64)?;
+    let x_normed = x.broadcast_div(&(norm_x + eps as f64)?.sqrt()?)?;
+    x_normed
+        .to_dtype(x_dtype)?
+        .broadcast_mul(alpha)?
+        .broadcast_add(beta)
+}
+
+pub fn layer_norm(xs: &Tensor, alpha: &Tensor, beta: &Tensor, eps: f32) -> Result<Tensor> {
+    let hidden_size_xs = xs.dim(D::Minus1)?;
+    let hidden_size_alpha = alpha.dims1()?;
+    let hidden_size_beta = beta.dims1()?;
+    if hidden_size_xs != hidden_size_alpha || hidden_size_xs != hidden_size_beta {
+        candle::bail!(
+            "shape mismatch in layer-norm src: {:?} alpha: {:?} beta: {:?}",
+            xs.shape(),
+            alpha.shape(),
+            beta.shape()
+        )
+    }
+    xs.apply_op3_no_bwd(alpha, beta, &LayerNorm { eps })
+}
+
 // https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html
 pub fn pixel_shuffle(xs: &Tensor, upscale_factor: usize) -> Result<Tensor> {
     let (b_size, c, h, w) = xs.dims4()?;
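Both entry points expect the last dimension of the input to match the length of alpha and beta; layer_norm_slow is the pure tensor-op reference used to cross-check the fused custom op. A short usage sketch under those assumptions (illustrative values, CPU device):

use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    let xs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &device)?;
    let alpha = Tensor::new(&[1f32, 1., 1.], &device)?;
    let beta = Tensor::new(&[0f32, 0., 0.], &device)?;
    // The fused custom-op path and the slow reference should agree closely.
    let fast = candle_nn::ops::layer_norm(&xs, &alpha, &beta, 1e-5)?;
    let slow = candle_nn::ops::layer_norm_slow(&xs, &alpha, &beta, 1e-5)?;
    let diff = (fast - slow)?.abs()?.sum_all()?.to_vec0::<f32>()?;
    assert!(diff < 1e-5);
    Ok(())
}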
@ -13,6 +13,12 @@ fn layer_norm() -> Result<()> {
     let device = &Device::Cpu;
     let w = Tensor::new(&[3f32], device)?;
     let b = Tensor::new(&[0.5f32], device)?;
+    let ln2 = LayerNorm::new(Tensor::cat(&[&w, &w], 0)?, Tensor::cat(&[&b, &b], 0)?, 1e-8);
+    let ln3 = LayerNorm::new(
+        Tensor::cat(&[&w, &w, &w], 0)?,
+        Tensor::cat(&[&b, &b, &b], 0)?,
+        1e-8,
+    );
     let ln = LayerNorm::new(w, b, 1e-8);

     let two = Tensor::new(&[[[2f32]]], device)?;
@ -20,11 +26,11 @@ fn layer_norm() -> Result<()> {
     assert_eq!(res.to_vec1::<f32>()?, [0.5f32]);

     let inp = Tensor::new(&[[[4f32, 0f32]]], device)?;
-    let res = ln.forward(&inp)?;
+    let res = ln2.forward(&inp)?;
     assert_eq!(res.to_vec3::<f32>()?, [[[3.5f32, -2.5]]]);

     let inp = Tensor::new(&[[[1f32, 2., 3.], [4., 5., 6.], [9., 8., 7.]]], device)?;
-    let res = ln.forward(&inp)?;
+    let res = ln3.forward(&inp)?;
     assert_eq!(
         test_utils::to_vec3_round(&res, 4)?,
         [[
@ -35,7 +41,10 @@ fn layer_norm() -> Result<()> {
     );
     let mean = (res.sum_keepdim(2)? / 3.0)?;
     // The average value should be `b`.
-    assert_eq!(mean.to_vec3::<f32>()?, [[[0.5], [0.5], [0.5]]]);
+    assert_eq!(
+        test_utils::to_vec3_round(&mean, 4)?,
+        [[[0.5], [0.5], [0.5]]]
+    );
     let std = (res.broadcast_sub(&mean)?.sqr()?.sum_keepdim(2)?.sqrt()? / 3.0)?;
     // The standard deviation should be sqrt(`w`).
     assert_eq!(
@ -77,6 +77,32 @@ fn rms_norm(device: &Device) -> Result<()> {
     Ok(())
 }

+fn layer_norm(device: &Device) -> Result<()> {
+    let data = &[[[3f32, 1., 4.], [1., 5., 9.]], [[2., 1., 7.], [8., 2., 8.]]];
+    let tensor = Tensor::new(data, device)?;
+    let alpha = Tensor::new(&[1f32, 2f32, 3f32], device)?;
+    let beta = Tensor::new(&[0.5f32, 0f32, -0.2f32], device)?;
+    let t = candle_nn::ops::layer_norm(&tensor, &alpha, &beta, 1e-5)?;
+    assert_eq!(
+        to_vec3_round(&t, 4)?,
+        &[
+            [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
+            [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
+        ]
+    );
+    let t2 = candle_nn::ops::layer_norm_slow(&tensor, &alpha, &beta, 1e-5)?;
+    assert_eq!(
+        to_vec3_round(&t2, 4)?,
+        &[
+            [[0.7673, -2.6726, 3.0071], [-0.7247, 0.0, 3.4742]],
+            [[-0.008, -1.778, 3.991], [1.2071, -2.8284, 1.9213]]
+        ]
+    );
+    let diff = (t - t2)?.abs()?.sum_all()?.to_vec0::<f32>()?;
+    assert!(diff < 1e-5);
+    Ok(())
+}
+
 #[test]
 fn softmax_numerical_stability() -> Result<()> {
     let dev = &Device::Cpu;
@ -185,4 +211,5 @@ test_device!(rope, rope_cpu, rope_gpu, rope_metal);
 test_device!(rope_thd, rope_thd_cpu, rope_thd_gpu, rope_thd_metal);
 test_device!(softmax, softmax_cpu, softmax_gpu, softmax_metal);
 test_device!(rms_norm, rms_norm_cpu, rms_norm_gpu, rms_norm_metal);
+test_device!(layer_norm, ln_cpu, ln_gpu, ln_metal);
 test_device!(sigmoid, sigmoid_cpu, sigmoid_gpu, sigmoid_metal);
@ -3,6 +3,7 @@ use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
 use candle::{IndexOp, Module, Result, Tensor, D};
 use candle_nn::VarBuilder;

+#[derive(Debug, Clone, serde::Deserialize)]
 pub struct Config {
     pub phi_config: PhiConfig,
     pub vision_config: VisionConfig,
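Deriving serde::Deserialize means the combined config can be parsed directly from the model repository's JSON file. A minimal sketch of such a loader, assuming serde_json is available (the helper name and path handling are illustrative):

fn load_config(path: &std::path::Path) -> Result<Config, Box<dyn std::error::Error>> {
    // Read e.g. the repo's config.json and deserialize it into the typed Config struct.
    let bytes = std::fs::read(path)?;
    let config: Config = serde_json::from_slice(&bytes)?;
    Ok(config)
}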
@ -56,24 +56,20 @@ impl RotaryEmbedding {
             .to_dtype(DType::F32)?
             .reshape((cfg.max_position_embeddings, 1))?;
         let freqs = t.matmul(&inv_freq)?;
-        let emb = Tensor::cat(&[&freqs, &freqs], D::Minus1)?;
         Ok(Self {
             dim,
-            sin: emb.sin()?,
-            cos: emb.cos()?,
+            sin: freqs.sin()?,
+            cos: freqs.cos()?,
         })
     }

     fn apply_rotary_emb(&self, xs: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
         let (_b_size, _num_heads, seq_len, _headdim) = xs.dims4()?;
-        let xs_rot = xs.i((.., .., .., ..self.dim))?;
+        let xs_rot = xs.i((.., .., .., ..self.dim))?.contiguous()?;
         let xs_pass = xs.i((.., .., .., self.dim..))?;
-        let xs12 = xs_rot.chunk(2, D::Minus1)?;
-        let (xs1, xs2) = (&xs12[0], &xs12[1]);
         let c = self.cos.narrow(0, seqlen_offset, seq_len)?;
         let s = self.sin.narrow(0, seqlen_offset, seq_len)?;
-        let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
-        let xs_rot = (xs_rot.broadcast_mul(&c)? + rotate_half.broadcast_mul(&s)?)?;
+        let xs_rot = candle_nn::rotary_emb::rope(&xs_rot, &c, &s)?;
         Tensor::cat(&[&xs_rot, &xs_pass], D::Minus1)
     }
 }
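The removed lines spell out the classic rotate-half formulation that the fused candle_nn::rotary_emb::rope call now handles in one kernel. For reference, a sketch of that older path is below; note it assumes cos/sin tables duplicated across both halves, whereas the fused helper consumes the half-width freqs-based tables directly:

use candle::{Result, Tensor, D};

// Tensor-level rotate-half reference: split the rotated slice in two along the
// last dim, build [-x2, x1], then combine with the cos/sin tables.
fn rope_reference(xs_rot: &Tensor, cos: &Tensor, sin: &Tensor) -> Result<Tensor> {
    let xs12 = xs_rot.chunk(2, D::Minus1)?;
    let (xs1, xs2) = (&xs12[0], &xs12[1]);
    let rotate_half = Tensor::cat(&[&xs2.neg()?, xs1], D::Minus1)?;
    xs_rot.broadcast_mul(cos)? + rotate_half.broadcast_mul(sin)?
}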
@ -146,7 +146,7 @@ impl LayerWeights {
             };
             let att = candle_nn::ops::softmax_last_dim(&att)?;
             // Convert to contiguous as matmul doesn't support strided vs for now.
-            att.matmul(&v.contiguous()?)?
+            att.matmul(&v)?
         };
         let y = y.transpose(1, 2)?.reshape(&[b_sz, seq_len, n_embd])?;
         let y = self.attn_output.forward(&y)?;
@ -203,7 +203,6 @@ fn precomput_freqs_cis(

 impl ModelWeights {
     pub fn from_gguf<R: std::io::Seek + std::io::Read>(
-        batch_size: usize,
         use_flash_attn: bool,
         ct: gguf_file::Content,
         reader: &mut R,
@ -252,12 +251,7 @@ impl ModelWeights {
             )?;
             let span_attn = tracing::span!(tracing::Level::TRACE, "attn");
             let span_rot = tracing::span!(tracing::Level::TRACE, "attn-rot");
-            let kv_cache = KvCache::new(
-                2,
-                (batch_size, head_count_kv, max_seq_len, head_dim),
-                DType::F32,
-                device,
-            )?;
+            let kv_cache = KvCache::new(2, max_seq_len);
             layers.push(LayerWeights {
                 attn_qkv: QLinear::new(&ct, reader, &format!("{prefix}.attn_qkv"), device)?,
                 attn_output: QLinear::new(&ct, reader, &format!("{prefix}.attn_output"), device)?,
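The cache construction switches from a buffer pre-allocated for a fixed (batch, heads, seq, head_dim) shape, which is why from_gguf previously took a batch_size argument, to candle-nn's growable KvCache that only needs the concatenation dimension and a maximum sequence length. A hedged sketch of how such a cache is typically driven from an attention layer; the append signature is assumed from candle_nn::kv_cache and should be checked against the version in use:

use candle::{Result, Tensor};
use candle_nn::kv_cache::KvCache;

fn attend_with_cache(cache: &mut KvCache, k: &Tensor, v: &Tensor) -> Result<(Tensor, Tensor)> {
    // append concatenates along the dim given to KvCache::new (here dim 2, the sequence axis)
    // and returns all keys/values seen so far, so no fixed batch size is needed up front.
    cache.append(k, v)
}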