Mirror of https://github.com/huggingface/candle.git (synced 2025-06-17 19:18:50 +00:00)

Compare commits: cuda-conv-… to spkemb (5 commits)
Commits:
- 9dc53ec8ad
- 577316bc4e
- b5ee026cea
- 52ed77c16f
- dae32d13d6
Cargo.toml (18 changed lines)

@@ -19,7 +19,7 @@ exclude = [
 resolver = "2"
 
 [workspace.package]
-version = "0.4.2"
+version = "0.4.1"
 edition = "2021"
 description = "Minimalist ML framework."
 repository = "https://github.com/huggingface/candle"
@@ -31,14 +31,14 @@ license = "MIT OR Apache-2.0"
 accelerate-src = { version = "0.3.2" }
 anyhow = { version = "1", features = ["backtrace"] }
 byteorder = "1.4.3"
-candle = { path = "./candle-core", package = "candle-core", version = "0.4.2" }
-candle-datasets = { path = "./candle-datasets", version = "0.4.2" }
-candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.2" }
-candle-kernels = { path = "./candle-kernels", version = "0.4.2" }
-candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.2" }
-candle-nn = { path = "./candle-nn", version = "0.4.2" }
-candle-onnx = { path = "./candle-onnx", version = "0.4.2" }
-candle-transformers = { path = "./candle-transformers", version = "0.4.2" }
+candle = { path = "./candle-core", package = "candle-core", version = "0.4.1" }
+candle-datasets = { path = "./candle-datasets", version = "0.4.1" }
+candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.1" }
+candle-kernels = { path = "./candle-kernels", version = "0.4.1" }
+candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.1" }
+candle-nn = { path = "./candle-nn", version = "0.4.1" }
+candle-onnx = { path = "./candle-onnx", version = "0.4.1" }
+candle-transformers = { path = "./candle-transformers", version = "0.4.1" }
 clap = { version = "4.2.4", features = ["derive"] }
 criterion = { version = "0.5.1", default-features=false }
 cudarc = { version = "0.10.0", features = ["f16"] }
README.md

@@ -175,7 +175,6 @@ And then head over to
 - [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
 - [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
 - [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
-- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.
 
 If you have an addition to this list, please submit a pull request.
candle-core/src/backend.rs

@@ -98,19 +98,6 @@ pub trait BackendStorage: Sized {
     ) -> Result<Self>;
 
     fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;
-
-    #[allow(clippy::too_many_arguments)]
-    // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
-    fn copy2d(
-        &self,
-        _: &mut Self,
-        _d1: usize,
-        _d2: usize,
-        _src_stride1: usize,
-        _dst_stride1: usize,
-        _src_offset: usize,
-        _dst_offset: usize,
-    ) -> Result<()>;
 }
 
 pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
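The `copy2d` hook above is documented as a `cudaMemcpy2D` analogue with offsets and strides in elements rather than bytes: it copies `d1` rows of `d2` contiguous elements each, with independent row strides on source and destination. A minimal standalone sketch of that access pattern on plain slices (all sizes and offsets here are made up for illustration):

```rust
// Copy d1 = 2 rows of d2 = 3 elements from a 2x5 source (row stride 5)
// into a 2x8 destination (row stride 8), at element offsets 1 and 4.
fn main() {
    let src: Vec<i32> = (0..10).collect(); // 2x5, row-major
    let mut dst = vec![0i32; 16]; // 2x8, row-major
    let (d1, d2, src_s, dst_s, src_o, dst_o) = (2, 3, 5, 8, 1, 4);
    for i in 0..d1 {
        let s = i * src_s + src_o;
        let d = i * dst_s + dst_o;
        dst[d..d + d2].copy_from_slice(&src[s..s + d2]);
    }
    assert_eq!(&dst[4..7], &[1, 2, 3]); // row 0: source elements 1..4
    assert_eq!(&dst[12..15], &[6, 7, 8]); // row 1: source elements 6..9
}
```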
candle-core/src/cpu_backend.rs

@@ -5,7 +5,6 @@ use half::{bf16, f16};
 use rayon::prelude::*;
 
 const USE_IM2COL_CONV1D: bool = true;
-const USE_IM2COL_CONV1D_TR: bool = true;
 const USE_IM2COL_CONV2D: bool = true;
 
 // TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -1023,26 +1022,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
     }
 }
 
-#[allow(clippy::too_many_arguments)]
-fn copy2d_<T: Copy>(
-    src: &[T],
-    dst: &mut [T],
-    d1: usize,
-    d2: usize,
-    src_stride1: usize,
-    dst_stride1: usize,
-    src_offset: usize,
-    dst_offset: usize,
-) {
-    for i1 in 0..d1 {
-        let dst_idx = i1 * dst_stride1 + dst_offset;
-        let src_idx = i1 * src_stride1 + src_offset;
-        let dst = &mut dst[dst_idx..dst_idx + d2];
-        let src = &src[src_idx..src_idx + d2];
-        dst.copy_from_slice(src)
-    }
-}
-
 fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
     match src_l.strided_blocks() {
         crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -1277,34 +1256,6 @@ impl Map1 for Im2Col {
     }
 }
 
-struct Col2Im1D {
-    stride: usize,
-}
-
-impl Map1 for Col2Im1D {
-    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
-        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
-        let stride = self.stride;
-        let l_out = (l_in - 1) * stride + k_size;
-        let mut im = vec![T::zero(); b_size * c_out * l_out];
-        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
-        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
-        for l_in_i in 0..l_in {
-            for k_i in 0..k_size {
-                let l_out_i = l_in_i * stride + k_i;
-                for b_i in 0..b_size {
-                    for c_i in 0..c_out {
-                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
-                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
-                        im[dst_idx] += col[src_idx]
-                    }
-                }
-            }
-        }
-        Ok(im)
-    }
-}
-
 struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);
 
 impl<'a> Map2 for ConvTranspose1D<'a> {
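The `Col2Im1D` kernel in the hunk above scatter-adds columns back into a 1-D signal, with output length `(l_in - 1) * stride + k_size`; positions where adjacent kernel windows overlap accumulate contributions. A toy standalone run of the same inner loop, for a single batch and channel (all values are illustrative):

```rust
// col2im for l_in = 4 input positions, kernel size 3, stride 2:
// l_out = (4 - 1) * 2 + 3 = 9, windows start at 0, 2, 4, 6.
fn main() {
    let (l_in, k_size, stride) = (4usize, 3usize, 2usize);
    let l_out = (l_in - 1) * stride + k_size;
    let col = vec![1f32; l_in * k_size]; // all-ones columns
    let mut im = vec![0f32; l_out];
    for l_in_i in 0..l_in {
        for k_i in 0..k_size {
            im[l_in_i * stride + k_i] += col[l_in_i * k_size + k_i];
        }
    }
    // Overlapping window positions accumulate to 2, the rest stay at 1.
    assert_eq!(im, [1., 1., 2., 1., 2., 1., 2., 1., 1.]);
}
```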
@@ -2472,48 +2423,6 @@ impl BackendStorage for CpuStorage {
         }
     }
 
-    fn copy2d(
-        &self,
-        dst: &mut Self,
-        d1: usize,
-        d2: usize,
-        src_s: usize,
-        dst_s: usize,
-        src_o: usize,
-        dst_o: usize,
-    ) -> Result<()> {
-        match (self, dst) {
-            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
-            (Self::U32(src), Self::U32(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::I64(src), Self::I64(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::BF16(src), Self::BF16(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::F16(src), Self::F16(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::F32(src), Self::F32(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (Self::F64(src), Self::F64(dst)) => {
-                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
-            }
-            (_, dst) => {
-                return Err(Error::DTypeMismatchBinaryOp {
-                    lhs: self.dtype(),
-                    rhs: dst.dtype(),
-                    op: "copy2d",
-                }
-                .bt());
-            }
-        }
-        Ok(())
-    }
-
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
         match (self, dst) {
             (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -2602,53 +2511,8 @@ impl BackendStorage for CpuStorage {
         kernel_l: &Layout,
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
-        let can_use_col2im = kernel_l.is_contiguous()
-            && params.dilation == 1
-            && params.padding == 0
-            && params.output_padding == 0;
-        if USE_IM2COL_CONV1D_TR && can_use_col2im {
-            let (b_size, c_in, l_in) = l.shape().dims3()?;
-            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
-            if !kernel_l.is_contiguous() {
-                crate::bail!(
-                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
-                )
-            }
-            if c_in != c_in2 {
-                crate::bail!(
-                    "convtr1d: shape mismatch on c_in {:?} {:?}",
-                    l.shape(),
-                    kernel_l.shape()
-                )
-            }
-            let col = {
-                // This merges the last two dimensions of the kernel together.
-                let kernel_l_mm = Layout::new(
-                    (b_size, c_in, k_size * c_out).into(),
-                    vec![0, k_size * c_out, 1],
-                    kernel_l.start_offset(),
-                );
-                self.matmul(
-                    kernel,
-                    (
-                        b_size,
-                        /* m */ l_in,
-                        /* n */ c_out * k_size,
-                        /* k */ c_in,
-                    ),
-                    &l.transpose(1, 2)?,
-                    &kernel_l_mm,
-                )?
-            };
-            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
-            Col2Im1D {
-                stride: params.stride,
-            }
-            .map(&col, &col_l)
-        } else {
-            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
-        }
+        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
     }
 
     fn conv2d(
         &self,
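The col2im code path above expresses the transposed convolution as one batched matmul followed by the fold: the input is transposed to `(b, l_in, c_in)` and multiplied by the kernel viewed as `(c_in, c_out * k_size)`, with the zero batch stride in `kernel_l_mm` broadcasting a single kernel across the whole batch. A small shape-bookkeeping sketch with made-up sizes:

```rust
// Shapes for the matmul + col2im path (illustrative values only).
fn main() {
    let (b, c_in, l_in) = (2usize, 5usize, 7usize);
    let (c_out, k_size, stride) = (3usize, 4usize, 2usize);
    // Batched matmul dimensions: (b, m, k) x (b, k, n) -> (b, m, n).
    let (m, n, k) = (l_in, c_out * k_size, c_in);
    assert_eq!((m, n, k), (7, 12, 5));
    // The (b, l_in, c_out * k_size) result is reinterpreted as
    // (b, l_in, c_out, k_size) and folded by col2im into the output.
    let l_out = (l_in - 1) * stride + k_size;
    assert_eq!((b, c_out, l_out), (2, 3, 16));
}
```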
candle-core/src/cuda_backend.rs

@@ -608,34 +608,6 @@ impl Map1 for Elu {
     }
 }
 
-struct Col2Im1D {
-    stride: usize,
-}
-
-impl Map1 for Col2Im1D {
-    fn f<T: DeviceRepr + WithDType>(
-        &self,
-        src: &CudaSlice<T>,
-        dev: &CudaDevice,
-        layout: &Layout,
-    ) -> Result<CudaSlice<T>> {
-        let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
-        let stride = self.stride;
-        let l_out = (l_in - 1) * stride + k_size;
-
-        let dst_el = b_size * c_out * l_out;
-        let cfg = LaunchConfig::for_num_elems(dst_el as u32);
-        let src = &src.slice(layout.start_offset()..);
-        let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
-        // SAFETY: Set later by running the kernel.
-        let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
-        let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(dst)
-    }
-}
-
 struct Im2Col1D {
     l_k: usize,
     stride: usize,
@@ -1893,54 +1865,8 @@ impl BackendStorage for CudaStorage {
         params: &crate::conv::ParamsConvTranspose1D,
     ) -> Result<Self> {
         let device = self.device().clone();
-        const USE_COL2IM_CONV1D_TR: bool = true;
-
-        let can_use_col2im = kernel_l.is_contiguous()
-            && params.dilation == 1
-            && params.padding == 0
-            && params.output_padding == 0;
-        if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
-            let slice =
-                ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
-            return Ok(Self { slice, device });
-        }
-
-        let (b_size, c_in, l_in) = l.shape().dims3()?;
-        let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
-        if !kernel_l.is_contiguous() {
-            crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
-        }
-        if c_in != c_in2 {
-            crate::bail!(
-                "convtr1d: shape mismatch on c_in {:?} {:?}",
-                l.shape(),
-                kernel_l.shape()
-            )
-        }
-        let col = {
-            // This merges the last two dimensions of the kernel together.
-            let kernel_l_mm = Layout::new(
-                (b_size, c_in, k_size * c_out).into(),
-                vec![0, k_size * c_out, 1],
-                kernel_l.start_offset(),
-            );
-            self.matmul(
-                kernel,
-                (
-                    b_size,
-                    /* m */ l_in,
-                    /* n */ c_out * k_size,
-                    /* k */ c_in,
-                ),
-                &l.transpose(1, 2)?,
-                &kernel_l_mm,
-            )?
-        };
-        let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
-        let slice = Col2Im1D {
-            stride: params.stride,
-        }
-        .map(&col.slice, &device, &col_l)?;
+        let slice = ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
         Ok(Self { slice, device })
     }
 
@@ -2219,67 +2145,6 @@ impl BackendStorage for CudaStorage {
         Ok(Self { slice, device })
     }
 
-    fn copy2d(
-        &self,
-        dst: &mut Self,
-        d1: usize,
-        d2: usize,
-        src_s: usize,
-        dst_s: usize,
-        src_o: usize,
-        dst_o: usize,
-    ) -> Result<()> {
-        let dev = &self.device;
-        let d1 = d1 as u32;
-        let d2 = d2 as u32;
-        let dst_s = dst_s as u32;
-        let src_s = src_s as u32;
-        let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
-            (S::U8(s), S::U8(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_u8",
-            ),
-            (S::U32(s), S::U32(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_u32",
-            ),
-            (S::I64(s), S::I64(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_i64",
-            ),
-            (S::BF16(s), S::BF16(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_bf16",
-            ),
-            (S::F16(s), S::F16(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_f16",
-            ),
-            (S::F32(s), S::F32(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_f32",
-            ),
-            (S::F64(s), S::F64(d)) => (
-                *s.slice(src_o..).device_ptr(),
-                *d.slice(dst_o..).device_ptr(),
-                "copy2d_f64",
-            ),
-            _ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
-        };
-        let func = dev.get_or_load_func(kname, kernels::FILL)?;
-        let cfg = LaunchConfig::for_num_elems(d1 * d2);
-        let params = (src, dst, d1, d2, src_s, dst_s);
-        // SAFETY: ffi.
-        unsafe { func.launch(cfg, params) }.w()?;
-        Ok(())
-    }
-
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
         let src_shape = src_l.shape();
         let dims = src_shape.dims();
candle-core/src/display.rs

@@ -65,13 +65,12 @@ impl std::fmt::Debug for Tensor {
 }
 
 /// Options for Tensor pretty printing
 #[derive(Debug, Clone)]
 pub struct PrinterOptions {
-    pub precision: usize,
-    pub threshold: usize,
-    pub edge_items: usize,
-    pub line_width: usize,
-    pub sci_mode: Option<bool>,
+    precision: usize,
+    threshold: usize,
+    edge_items: usize,
+    line_width: usize,
+    sci_mode: Option<bool>,
 }
 
 static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
@@ -90,10 +89,6 @@ impl PrinterOptions {
     }
 }
 
-pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
-    &PRINT_OPTS
-}
-
 pub fn set_print_options(options: PrinterOptions) {
     *PRINT_OPTS.lock().unwrap() = options
 }
@@ -122,26 +117,6 @@ pub fn set_print_options_full() {
     }
 }
 
-pub fn set_line_width(line_width: usize) {
-    PRINT_OPTS.lock().unwrap().line_width = line_width
-}
-
-pub fn set_precision(precision: usize) {
-    PRINT_OPTS.lock().unwrap().precision = precision
-}
-
-pub fn set_edge_items(edge_items: usize) {
-    PRINT_OPTS.lock().unwrap().edge_items = edge_items
-}
-
-pub fn set_threshold(threshold: usize) {
-    PRINT_OPTS.lock().unwrap().threshold = threshold
-}
-
-pub fn set_sci_mode(sci_mode: Option<bool>) {
-    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
-}
-
 struct FmtSize {
     current_size: usize,
 }
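All of the setters being removed here funnel through the single `PRINT_OPTS` mutex. A minimal standalone sketch of that global-options pattern (the names below are illustrative, not candle's):

```rust
use std::sync::Mutex;

// A static Mutex holds the configuration; small setters mutate it in place.
struct Options {
    precision: usize,
}

static OPTS: Mutex<Options> = Mutex::new(Options { precision: 4 });

fn set_precision(precision: usize) {
    OPTS.lock().unwrap().precision = precision;
}

fn main() {
    set_precision(7);
    assert_eq!(OPTS.lock().unwrap().precision, 7);
}
```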
candle-core/src/dtype.rs

@@ -23,15 +23,7 @@ pub enum DType {
 }
 
-#[derive(Debug, PartialEq, Eq)]
-pub struct DTypeParseError(String);
-
-impl std::fmt::Display for DTypeParseError {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
-        write!(f, "cannot parse '{}' as a dtype", self.0)
-    }
-}
-
-impl std::error::Error for DTypeParseError {}
+pub struct DTypeParseError;
 
 impl std::str::FromStr for DType {
     type Err = DTypeParseError;
@@ -44,7 +36,7 @@ impl std::str::FromStr for DType {
             "f16" => Ok(Self::F16),
             "f32" => Ok(Self::F32),
             "f64" => Ok(Self::F64),
-            _ => Err(DTypeParseError(s.to_string())),
+            _ => Err(DTypeParseError),
         }
     }
 }
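For context, here is the `FromStr` pattern in play, reduced to a self-contained sketch with the unit error type used on this branch (the enum below is a stand-in, not candle's full `DType`):

```rust
#[derive(Debug, PartialEq)]
enum DType {
    F16,
    F32,
    F64,
}

#[derive(Debug)]
struct DTypeParseError;

impl std::str::FromStr for DType {
    type Err = DTypeParseError;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "f16" => Ok(Self::F16),
            "f32" => Ok(Self::F32),
            "f64" => Ok(Self::F64),
            _ => Err(DTypeParseError), // unit error: the offending string is lost
        }
    }
}

fn main() {
    assert_eq!("f32".parse::<DType>().unwrap(), DType::F32);
    assert!("q4k".parse::<DType>().is_err());
}
```

The richer tuple-struct variant on the other side of this diff keeps the offending string, so parse failures are self-describing in error messages.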
candle-core/src/dummy_cuda_backend.rs

@@ -154,19 +154,6 @@ impl crate::backend::BackendStorage for CudaStorage {
         Err(Error::NotCompiledWithCudaSupport)
     }
 
-    fn copy2d(
-        &self,
-        _: &mut Self,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-    ) -> Result<()> {
-        Err(Error::NotCompiledWithCudaSupport)
-    }
-
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
         Err(Error::NotCompiledWithCudaSupport)
     }
candle-core/src/dummy_metal_backend.rs

@@ -166,19 +166,6 @@ impl crate::backend::BackendStorage for MetalStorage {
         Err(Error::NotCompiledWithMetalSupport)
     }
 
-    fn copy2d(
-        &self,
-        _: &mut Self,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-        _: usize,
-    ) -> Result<()> {
-        Err(Error::NotCompiledWithMetalSupport)
-    }
-
     fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
         Err(Error::NotCompiledWithMetalSupport)
     }
candle-core/src/layout.rs

@@ -70,7 +70,7 @@ impl Layout {
         self.shape.is_fortran_contiguous(&self.stride)
     }
 
-    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
+    pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
         let dims = self.shape().dims();
         if dim >= dims.len() {
             Err(Error::DimOutOfRange {
@@ -99,7 +99,7 @@ impl Layout {
         })
     }
 
-    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
+    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
         let rank = self.shape.rank();
         if rank <= dim1 || rank <= dim2 {
             Err(Error::UnexpectedNumberOfDims {
@@ -120,7 +120,7 @@ impl Layout {
         })
     }
 
-    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
+    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
         let is_permutation =
             idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
         if !is_permutation {
candle-core/src/lib.rs

@@ -67,7 +67,6 @@ pub mod shape;
 mod storage;
 mod strided_index;
 mod tensor;
-mod tensor_cat;
 pub mod test_utils;
 pub mod utils;
 mod variable;
candle-core/src/metal_backend.rs

@@ -9,7 +9,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger};
 use std::collections::HashMap;
 use std::ffi::c_void;
 use std::path::Path;
-use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
+use std::sync::{Arc, Mutex, RwLock, TryLockError};
 
 /// Simple way to catch lock error without
 /// depending on T
@@ -60,8 +60,7 @@ impl From<String> for MetalError {
     }
 }
 
-type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
-type AllocatedBuffers = Arc<RwLock<BufferMap>>;
+type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
 
 #[derive(Clone)]
 pub struct MetalDevice {
@@ -69,7 +68,7 @@ pub struct MetalDevice {
     device: metal::Device,
 
     /// Single command queue for the entire device.
-    command_queue: CommandQueue,
+    command_queue: metal::CommandQueue,
     /// One command buffer at a time.
     /// The scheduler works by allowing multiple
     /// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
@@ -79,7 +78,7 @@ pub struct MetalDevice {
     /// Despite what the documentation says, command buffers are NOT ordered. They are ordered
     /// for their START time, but there's no guarantee that command buffer1 will finish before
     /// command buffer2 starts (or there are metal bugs there)
-    command_buffer: Arc<RwLock<CommandBuffer>>,
+    command_buffer: Arc<RwLock<metal::CommandBuffer>>,
     /// Keeps track of the current amount of compute command encoders on the current
     /// command buffer
     /// Arc, RwLock because of the interior mutability.
@@ -88,7 +87,7 @@ pub struct MetalDevice {
     compute_per_buffer: usize,
     /// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
     /// Heavily used by [`candle_metal_kernels`]
-    kernels: Arc<Kernels>,
+    kernels: Arc<candle_metal_kernels::Kernels>,
     /// Simple allocator struct.
     /// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
     /// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
@@ -100,7 +99,7 @@ pub struct MetalDevice {
     /// operation, so that this buffer is not being used by another kernel at the same time.
     /// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
     ///
-    /// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
+    /// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
     /// (strong_count = 1).
     buffers: AllocatedBuffers,
     /// Seed for random number generation.
@@ -146,8 +145,6 @@ impl MetalDevice {
             command_buffer = self.command_queue.new_command_buffer().to_owned();
             *command_buffer_lock = command_buffer.clone();
             *index = 0;
-
-            self.drop_unused_buffers()?;
         }
         *index += 1;
         Ok(command_buffer)
@@ -166,7 +163,6 @@ impl MetalDevice {
         command_buffer.commit();
         command_buffer.wait_until_completed();
         *command_buffer = self.command_queue.new_command_buffer().to_owned();
-
         Ok(())
     }
 
@@ -203,25 +199,39 @@ impl MetalDevice {
     }
 
     /// Creates a new buffer from data.
-    /// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
+    /// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
     ///
-    /// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
-    /// allocates the buffer and copies over the existing data before returning the MTLBuffer.
+    /// This method will block the computation because of the
+    /// lack of lifetime management through the GPU.
+    /// Internal comment for technical details.
     pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
         let size = core::mem::size_of_val(data) as NSUInteger;
-        let new_buffer = self.device.new_buffer_with_data(
-            data.as_ptr() as *const c_void,
+        let tmp = self.device.new_buffer_with_data(
+            data.as_ptr() as *const core::ffi::c_void,
             size,
-            MTLResourceOptions::StorageModeManaged,
+            metal::MTLResourceOptions::StorageModeManaged,
         );
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
-        let subbuffers = buffers
-            .entry((size, MTLResourceOptions::StorageModeManaged))
-            .or_insert(vec![]);
+        let real = self.allocate_buffer(
+            size,
+            metal::MTLResourceOptions::StorageModePrivate,
+            "with_data",
+        )?;
+        let command_buffer = self.command_buffer()?;
+        command_buffer.set_label("with_data");
+        let blit = command_buffer.new_blit_command_encoder();
+        blit.set_label("with_data_blit");
+        blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
+        blit.end_encoding();
 
-        let new_buffer = Arc::new(new_buffer);
-        subbuffers.push(new_buffer.clone());
-        Ok(new_buffer)
+        // This is necessary, for mmaped safetensors
+        // Because of the unsafe slice cast we're doing.
+        // The slice might not live long enough for metal
+        // To actually fill the GPU buffer.
+        // Putting this wait forces the GPU buffer to be filled
+        // with the actual data allowing the CPU storage to do
+        // deallocate properly.
+        self.wait_until_completed()?;
+        Ok(real)
     }
 
     pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
@@ -245,40 +255,6 @@ impl MetalDevice {
         Ok(buffer)
     }
 
-    fn find_available_buffer(
-        &self,
-        size: NSUInteger,
-        option: MTLResourceOptions,
-        buffers: &RwLockWriteGuard<BufferMap>,
-    ) -> Option<Arc<Buffer>> {
-        let mut best_buffer: Option<&Arc<Buffer>> = None;
-        let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
-        for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
-            if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
-                for sub in subbuffers {
-                    if Arc::strong_count(sub) == 1 {
-                        best_buffer = Some(sub);
-                        best_buffer_size = *buffer_size;
-                    }
-                }
-            }
-        }
-        return best_buffer.map(|b| b.clone());
-    }
-
-    fn drop_unused_buffers(&self) -> Result<()> {
-        let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
-        for subbuffers in buffers.values_mut() {
-            let newbuffers = subbuffers
-                .iter()
-                .filter(|s| Arc::strong_count(*s) > 1)
-                .map(Arc::clone)
-                .collect();
-            *subbuffers = newbuffers;
-        }
-        Ok(())
-    }
-
     /// The critical allocator algorithm
     fn allocate_buffer(
         &self,
@@ -287,18 +263,24 @@ impl MetalDevice {
         _name: &str,
     ) -> Result<Arc<Buffer>> {
         let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
-        if let Some(b) = self.find_available_buffer(size, option, &buffers) {
-            // Cloning also ensures we increment the strong count
-            return Ok(b.clone());
-        }
-
-        let size = buf_size(size);
         let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
 
+        for sub in &mut *subbuffers {
+            if Arc::strong_count(sub) == 1 {
+                return Ok(sub.clone());
+            }
+        }
         let new_buffer = self.device.new_buffer(size as NSUInteger, option);
         let new_buffer = Arc::new(new_buffer);
         subbuffers.push(new_buffer.clone());
+
+        for subbuffers in buffers.values_mut() {
+            let newbuffers = subbuffers
+                .iter()
+                .filter(|s| Arc::strong_count(s) > 1)
+                .map(Arc::clone)
+                .collect();
+            *subbuffers = newbuffers;
+        }
         Ok(new_buffer)
     }
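Both variants of the allocator pool buffers by `(size, option)` and hand a pooled buffer out again exactly when the pool holds the only `Arc` to it (`strong_count == 1`). A standalone sketch of that reuse rule, with `Vec<u8>` standing in for a Metal buffer:

```rust
use std::collections::HashMap;
use std::sync::Arc;

// Buffers are bucketed by size; a buffer is free for reuse when the pool
// holds the only strong reference to it.
struct Pool {
    buckets: HashMap<usize, Vec<Arc<Vec<u8>>>>,
}

impl Pool {
    fn allocate(&mut self, size: usize) -> Arc<Vec<u8>> {
        let subbuffers = self.buckets.entry(size).or_insert_with(Vec::new);
        for sub in subbuffers.iter() {
            if Arc::strong_count(sub) == 1 {
                return sub.clone(); // nothing else holds this buffer: reuse it
            }
        }
        let fresh = Arc::new(vec![0u8; size]);
        subbuffers.push(fresh.clone());
        fresh
    }
}

fn main() {
    let mut pool = Pool { buckets: HashMap::new() };
    let a = pool.allocate(1024);
    let b = pool.allocate(1024); // `a` is still live, so a second buffer is created
    drop(a);
    let c = pool.allocate(1024); // reuses the buffer that `a` released
    assert_eq!(pool.buckets[&1024].len(), 2);
    drop((b, c));
}
```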
@@ -323,8 +305,6 @@ pub struct MetalStorage {
     buffer: Arc<metal::Buffer>,
     /// a reference to the device owning this buffer
     device: MetalDevice,
-    /// The count of allocated elements in the buffer
-    count: usize,
     /// The dtype is kept since buffers are untyped.
     dtype: DType,
 }
@@ -406,7 +386,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), el, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
@@ -422,7 +402,6 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "powf_f32",
                 DType::F16 => "powf_f16",
-                DType::BF16 => "powf_bf16",
                 dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_powf(
@@ -440,7 +419,6 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "powf_f32_strided",
                 DType::F16 => "powf_f16_strided",
-                DType::BF16 => "powf_bf16_strided",
                 dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_powf_strided(
@@ -457,7 +435,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), el, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
@@ -473,7 +451,6 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "elu_f32",
                 DType::F16 => "elu_f16",
-                DType::BF16 => "elu_bf16",
                 dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_elu(
@@ -491,7 +468,6 @@ impl BackendStorage for MetalStorage {
             let name = match self.dtype {
                 DType::F32 => "elu_f32_strided",
                 DType::F16 => "elu_f16_strided",
-                DType::BF16 => "elu_bf16_strided",
                 dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
             };
             candle_metal_kernels::call_elu_strided(
@@ -508,7 +484,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), el, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
@@ -586,7 +562,7 @@ impl BackendStorage for MetalStorage {
         )
         .map_err(MetalError::from)?;
 
-        Ok(Self::new(buffer, device, dst_el, dtype))
+        Ok(Self::new(buffer, device, dtype))
     }
 
     fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
@@ -609,41 +585,28 @@ impl BackendStorage for MetalStorage {
         let command_buffer = device.command_buffer()?;
         if layout.is_contiguous() && layout.start_offset() == 0 {
             let kernel_name = match (self.dtype, dtype) {
-                (DType::U32, DType::BF16) => "cast_u32_bf16",
-                (DType::U32, DType::F16) => "cast_u32_f16",
                 (DType::U32, DType::F32) => "cast_u32_f32",
-                (DType::U32, DType::I64) => "cast_u32_i64",
                 (DType::U32, DType::U8) => "cast_u32_u8",
+                (DType::U32, DType::I64) => "cast_u32_i64",
+                (DType::U32, DType::BF16) => "cast_u32_bf16",
 
-                (DType::U8, DType::BF16) => "cast_u8_bf16",
-                (DType::U8, DType::F16) => "cast_u8_f16",
-                (DType::U8, DType::U32) => "cast_u8_u32",
                 (DType::U8, DType::F32) => "cast_u8_f32",
                 (DType::U8, DType::I64) => "cast_u8_i64",
+                (DType::U8, DType::U32) => "cast_u8_u32",
+                (DType::U8, DType::BF16) => "cast_u8_bf16",
 
-                (DType::F32, DType::BF16) => "cast_f32_bf16",
                 (DType::F32, DType::F16) => "cast_f32_f16",
-                (DType::F32, DType::I64) => "cast_f32_i64",
                 (DType::F32, DType::U32) => "cast_f32_u32",
                 (DType::F32, DType::U8) => "cast_f32_u8",
+                (DType::F32, DType::BF16) => "cast_f32_bf16",
 
-                (DType::I64, DType::BF16) => "cast_i64_bf16",
-                (DType::I64, DType::F16) => "cast_i64_f16",
-                (DType::I64, DType::F32) => "cast_i64_f32",
-                (DType::I64, DType::U32) => "cast_i64_u32",
-                (DType::I64, DType::U8) => "cast_i64_u8",
-
-                (DType::F16, DType::BF16) => "cast_f16_bf16",
                 (DType::F16, DType::F32) => "cast_f16_f32",
-                (DType::F16, DType::I64) => "cast_f16_i64",
-                (DType::F16, DType::U32) => "cast_f16_u32",
-                (DType::F16, DType::U8) => "cast_f16_u8",
 
-                (DType::BF16, DType::U8) => "cast_bf16_u8",
-                (DType::BF16, DType::U32) => "cast_bf16_u32",
-                (DType::BF16, DType::F16) => "cast_bf16_f16",
                 (DType::BF16, DType::F32) => "cast_bf16_f32",
-                (DType::BF16, DType::I64) => "cast_bf16_i64",
+                (DType::BF16, DType::U32) => "cast_bf16_u32",
+                (DType::BF16, DType::U8) => "cast_bf16_u8",
 
                 (left, right) => {
                     crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
@@ -691,7 +654,7 @@ impl BackendStorage for MetalStorage {
             .map_err(MetalError::from)?;
         }
         command_buffer.set_label("to_dtype");
-        Ok(Self::new(buffer, device.clone(), el_count, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
@@ -811,7 +774,7 @@ impl BackendStorage for MetalStorage {
             )
             .map_err(MetalError::from)?;
         }
-        Ok(Self::new(buffer, device.clone(), el_count, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn binary_impl<B: BinaryOpT>(
@@ -872,7 +835,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device, el, dtype))
+        Ok(Self::new(buffer, device, dtype))
     }
 
     fn conv1d(
@@ -917,7 +880,6 @@ impl BackendStorage for MetalStorage {
         let col = Self {
             buffer: dst,
             device,
-            count: dst_el,
             dtype: self.dtype,
         };
         let l_out = params.l_out();
@@ -1002,7 +964,6 @@ impl BackendStorage for MetalStorage {
         let col = Self {
             buffer: dst,
             device,
-            count: dst_el,
             dtype: self.dtype,
         };
         let h_out = params.out_h();
@@ -1088,7 +1049,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
+        Ok(Self::new(buffer, self.device.clone(), self.dtype))
     }
 
     fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
@@ -1122,7 +1083,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
        .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device.clone(), dst_el, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn scatter_add(
@@ -1145,15 +1106,7 @@ impl BackendStorage for MetalStorage {
             None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
         };
         let name = match (ids.dtype, self.dtype) {
-            (DType::U8, DType::F32) => "sa_u8_f32",
-            (DType::U8, DType::F16) => "sa_u8_f16",
-            (DType::U8, DType::BF16) => "sa_u8_bf16",
             (DType::U32, DType::F32) => "sa_u32_f32",
-            (DType::U32, DType::F16) => "sa_u32_f16",
-            (DType::U32, DType::BF16) => "sa_u32_bf16",
-            (DType::I64, DType::F32) => "sa_i64_f32",
-            (DType::I64, DType::F16) => "sa_i64_f16",
-            (DType::I64, DType::BF16) => "sa_i64_bf16",
             _ => Err(MetalError::UnexpectedDType {
                 msg: "scatter-add ids should be u8/u32/i64",
                 expected: DType::U32,
@@ -1219,7 +1172,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(buffer, device.clone(), dst_el, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     fn index_add(
@@ -1301,73 +1254,7 @@ impl BackendStorage for MetalStorage {
             &buffer,
         )
         .map_err(MetalError::from)?;
-        Ok(Self::new(
-            buffer,
-            self.device.clone(),
-            b * m * n,
-            self.dtype(),
-        ))
-    }
-
-    fn copy2d(
-        &self,
-        dst: &mut Self,
-        d1: usize,
-        d2: usize,
-        src_s: usize,
-        dst_s: usize,
-        src_o: usize,
-        dst_o: usize,
-    ) -> Result<()> {
-        if self.dtype() != dst.dtype() {
-            crate::bail!(
-                "copy2d with inconsistent dtypes {:?} {:?}",
-                self.dtype(),
-                dst.dtype()
-            )
-        }
-        let command_buffer = self.device.command_buffer()?;
-        if src_s == d2 && dst_s == d2 {
-            command_buffer.set_label("copy2d_contiguous");
-            let blit = command_buffer.new_blit_command_encoder();
-            blit.set_label("copy2d_contiguous");
-            let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger;
-            let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger;
-            let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger;
-            blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
-            blit.end_encoding();
-        } else {
-            let el_count = d1 * d2;
-            if el_count == 0 {
-                return Ok(());
-            }
-            let kernel_name = match self.dtype {
-                DType::F32 => candle_metal_kernels::copy2d::FLOAT,
-                DType::F16 => candle_metal_kernels::copy2d::HALF,
-                DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
-                DType::I64 => candle_metal_kernels::copy2d::I64,
-                DType::U32 => candle_metal_kernels::copy2d::U32,
-                DType::U8 => candle_metal_kernels::copy2d::U8,
-                dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"),
-            };
-            candle_metal_kernels::call_copy2d(
-                &self.device.device,
-                &command_buffer,
-                &self.device.kernels,
-                kernel_name,
-                &self.buffer,
-                &dst.buffer,
-                d1,
-                d2,
-                src_s,
-                dst_s,
-                src_o * self.dtype.size_in_bytes(),
-                dst_o * self.dtype.size_in_bytes(),
-            )
-            .map_err(MetalError::from)?;
-            command_buffer.set_label("copy2d");
-        }
-        Ok(())
-    }
+        Ok(Self::new(buffer, self.device.clone(), self.dtype()))
     }
 
     fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
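The Metal `copy2d` removed above picks a blit fast path when both row strides equal the row length, i.e. when the 2-D copy degenerates to one contiguous copy of `d1 * d2` elements; otherwise it falls back to the strided kernel. The predicate, as a small sketch with illustrative values:

```rust
// The 2-D copy is a single contiguous memcpy exactly when neither side
// skips elements between rows.
fn is_contiguous_copy(d2: usize, src_s: usize, dst_s: usize) -> bool {
    src_s == d2 && dst_s == d2
}

fn main() {
    // Concatenation along dim 0 copies whole tensors: contiguous.
    assert!(is_contiguous_copy(12, 12, 12));
    // Concatenation along an inner dim leaves a larger destination stride.
    assert!(!is_contiguous_copy(4, 4, 12));
}
```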
@@ -1416,11 +1303,10 @@ impl BackendStorage for MetalStorage {
 }
 
 impl MetalStorage {
-    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, count: usize, dtype: DType) -> Self {
+    pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
         Self {
             buffer,
             device,
-            count,
             dtype,
         }
     }
@@ -1635,23 +1521,29 @@ impl MetalStorage {
             (buffer, dtype)
         };
         command_buffer.set_label("binary");
-        Ok(Self::new(buffer, device.clone(), el_count, dtype))
+        Ok(Self::new(buffer, device.clone(), dtype))
     }
 
     pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
-        let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
-
-        let buffer = self.device.new_buffer_managed(size)?;
+        let length = self.buffer.length() as usize;
+        let size = self.dtype.size_in_bytes();
+        if length % size != 0 {
+            crate::bail!(
+                "The Metal buffer length is not aligned with dtype {:?}",
+                self.dtype
+            );
+        }
+        let buffer = self.device.new_buffer_managed(self.buffer.length())?;
         {
             let command_buffer = self.device.command_buffer()?;
             command_buffer.set_label("to_cpu");
             let blit = command_buffer.new_blit_command_encoder();
             blit.set_label("blit_to_cpu");
-            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
+            blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
             blit.end_encoding();
         }
         self.device.wait_until_completed()?;
-        Ok(read_to_vec(&buffer, self.count))
+        Ok(read_to_vec(&buffer, length / size))
     }
 }
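With the `count` field gone, this side of the diff derives the element count from the buffer's byte length, and the alignment check guards against a length that is not a whole number of elements. Note that this only works when buffers are allocated at exact sizes; with the other branch's power-of-two bucketing, the byte length can overshoot the logical element count, which is one reason to store `count` explicitly. A sketch of the derivation:

```rust
// Derive an element count from a byte length, rejecting misaligned buffers.
fn element_count(byte_len: usize, dtype_size: usize) -> Option<usize> {
    if byte_len % dtype_size != 0 {
        None // not a whole number of elements
    } else {
        Some(byte_len / dtype_size)
    }
}

fn main() {
    assert_eq!(element_count(4096, 4), Some(1024)); // f32 buffer
    assert_eq!(element_count(4098, 4), None);
}
```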
@@ -1669,7 +1561,7 @@ impl BackendDevice for MetalDevice {
         let buffers = Arc::new(RwLock::new(HashMap::new()));
         let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
             Ok(val) => val.parse()?,
-            _ => 50,
+            _ => 10,
         };
         let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
             [299792458].as_ptr() as *const c_void,
@@ -1701,12 +1593,7 @@ impl BackendDevice for MetalDevice {
     fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
         let size = shape.elem_count() * dtype.size_in_bytes();
         let buffer = self.allocate_zeros(size)?;
-        Ok(MetalStorage::new(
-            buffer,
-            self.clone(),
-            shape.elem_count(),
-            dtype,
-        ))
+        Ok(MetalStorage::new(buffer, self.clone(), dtype))
     }
 
     fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
@@ -1716,21 +1603,16 @@ impl BackendDevice for MetalDevice {
     }
 
     fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
-        let (count, buffer) = match storage {
-            CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
-            CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
-            CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
-            CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
-            CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
-            CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
-            CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
-        };
-        Ok(Self::Storage::new(
-            buffer?,
-            self.clone(),
-            count,
-            storage.dtype(),
-        ))
+        let buffer = match storage {
+            CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
+            CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
+        }?;
+        Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
     }
 
@@ -1761,12 +1643,7 @@ impl BackendDevice for MetalDevice {
         )
         .map_err(MetalError::from)?;
 
-        Ok(Self::Storage::new(
-            buffer,
-            self.clone(),
-            shape.elem_count(),
-            dtype,
-        ))
+        Ok(Self::Storage::new(buffer, self.clone(), dtype))
     }
 
     fn rand_normal(
@@ -1797,12 +1674,7 @@ impl BackendDevice for MetalDevice {
         )
         .map_err(MetalError::from)?;
 
-        Ok(Self::Storage::new(
-            buffer,
-            self.clone(),
-            shape.elem_count(),
-            dtype,
-        ))
+        Ok(Self::Storage::new(buffer, self.clone(), dtype))
     }
 
@@ -1813,7 +1685,7 @@ impl BackendDevice for MetalDevice {
         let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
         let contents = seed_buffer.contents();
         unsafe {
-            std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1);
+            std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
         }
         seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
 
@@ -1821,10 +1693,6 @@ impl BackendDevice for MetalDevice {
     }
 }
 
-fn buf_size(size: NSUInteger) -> NSUInteger {
-    (size - 1).next_power_of_two() as NSUInteger
-}
-
 fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
     let ptr = buffer.contents() as *const T;
     assert!(!ptr.is_null());
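The removed `buf_size` helper buckets allocation sizes by powers of two, so many distinct request sizes collapse into one pool bucket and the allocator's map stays small. A few worked values, using `u64` in place of `NSUInteger`:

```rust
// Same arithmetic as the removed helper: round a request size up to a
// power-of-two bucket via (size - 1).next_power_of_two().
fn buf_size(size: u64) -> u64 {
    (size - 1).next_power_of_two()
}

fn main() {
    assert_eq!(buf_size(600), 1024);
    assert_eq!(buf_size(1000), 1024);
    assert_eq!(buf_size(1024), 1024); // an exact power of two stays put
}
```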
candle-core/src/quantized/cuda.rs

@@ -313,7 +313,7 @@ impl QCudaStorage {
         }
 
         let data_f32 = self.dequantize(n * k)?;
-        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
+        let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0);
         let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
         let mut out_shape = layout.shape().dims().to_vec();
         out_shape.pop();
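The `broadcast_as` on the removed side turns the `(k, n)` weight layout into a `(b, k, n)` one by prepending a zero stride, so every batch element reads the same dequantized weights without duplicating them. The indexing trick, as a tiny sketch with illustrative sizes:

```rust
// A zero stride on the leading dim makes all batch indices alias the
// same storage elements.
fn main() {
    let k = 3usize;
    let stride = [0usize, 1, k]; // (b, k, n) strides after the broadcast
    let index =
        |bi: usize, ki: usize, ni: usize| bi * stride[0] + ki * stride[1] + ni * stride[2];
    assert_eq!(index(0, 2, 1), index(5, 2, 1)); // every batch reads the same element
    assert_eq!(index(0, 2, 1), 2 + k);
}
```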
candle-core/src/quantized/metal.rs

@@ -106,12 +106,7 @@ impl QMetalStorage {
         }
 
         let buffer = self.device.new_buffer_with_data(&out)?;
-        Ok(MetalStorage::new(
-            buffer,
-            self.device.clone(),
-            elem_count,
-            DType::F32,
-        ))
+        Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
     }
 
     pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
@@ -175,7 +170,7 @@ impl QMetalStorage {
             &dst,
         )
         .map_err(MetalError::from)?;
-        let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
+        let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
         Ok((dst_storage, dst_shape))
     }
 }
candle-core/src/quantized/mod.rs

@@ -398,7 +398,7 @@ impl QMatMul {
             _ => DEQUANTIZE_ALL.with(|b| *b),
         };
         let t = if dequantize {
-            let tensor = qtensor.dequantize(&qtensor.device())?;
+            let tensor = qtensor.dequantize(&Device::Cpu)?;
             Self::Tensor(tensor)
         } else {
             Self::QTensor(qtensor)
candle-core/src/storage.rs

@@ -701,32 +701,4 @@ impl Storage {
             .bt()),
         }
     }
-
-    #[allow(clippy::too_many_arguments)]
-    pub(crate) fn copy2d(
-        &self,
-        dst: &mut Self,
-        d1: usize,
-        d2: usize,
-        src_s: usize,
-        dst_s: usize,
-        src_o: usize,
-        dst_o: usize,
-    ) -> Result<()> {
-        match (self, dst) {
-            (Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
-            (Self::Cuda(src), Self::Cuda(dst)) => {
-                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
-            }
-            (Self::Metal(src), Self::Metal(dst)) => {
-                Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
-            }
-            (lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
-                lhs: lhs.device().location(),
-                rhs: rhs.device().location(),
-                op: "copy2d",
-            }
-            .bt()),
-        }
-    }
 }
candle-core/src/tensor.rs

@@ -666,7 +666,7 @@ impl Tensor {
         Ok(from_storage(storage, self.shape(), op, false))
     }
 
-    pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
+    fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
         if dim >= self.dims().len() {
             Err(Error::DimOutOfRange {
                 shape: self.shape().clone(),
@@ -2149,6 +2149,152 @@ impl Tensor {
         Self::cat(&args, dim)
     }
 
+    /// Concatenates two or more tensors along a particular dimension.
+    ///
+    /// All tensors must of the same rank, and the output will have
+    /// the same rank
+    ///
+    /// ```rust
+    /// # use candle_core::{Tensor, DType, Device};
+    /// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    /// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 0)?;
+    /// assert_eq!(c.shape().dims(), &[4, 3]);
+    ///
+    /// let c = Tensor::cat(&[&a, &b], 1)?;
+    /// assert_eq!(c.shape().dims(), &[2, 6]);
+    /// # Ok::<(), candle_core::Error>(())
+    /// ```
+    pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
+        if args.is_empty() {
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
+        }
+        let arg0 = args[0].as_ref();
+        if args.len() == 1 {
+            return Ok(arg0.clone());
+        }
+        let dim = dim.to_index(arg0.shape(), "cat")?;
+        for arg in args {
+            arg.as_ref().check_dim(dim, "cat")?;
+        }
+        for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
+            if arg0.rank() != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: arg0.rank(),
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
+            for (dim_idx, (v1, v2)) in arg0
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim_idx != dim && v1 != v2 {
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
+                }
+            }
+        }
+        if dim == 0 {
+            Self::cat0(args)
+        } else {
+            // TODO: Avoid these transpositions and have an implementation that works
+            // for dim != 0...
+            let args: Vec<Tensor> = args
+                .iter()
+                .map(|a| a.as_ref().transpose(0, dim))
+                .collect::<Result<Vec<_>>>()?;
+            let cat = Self::cat0(&args)?;
+            cat.transpose(0, dim)
+        }
+    }
+
+    fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
+        if args.is_empty() {
+            Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
+        }
+        let arg0 = args[0].as_ref();
+        if args.len() == 1 {
+            return Ok(arg0.clone());
+        }
+        let rank = arg0.rank();
+        let device = arg0.device();
+        let dtype = arg0.dtype();
+        let first_dims = arg0.shape().dims();
+        let mut cat_dims = first_dims.to_vec();
+        cat_dims[0] = 0;
+        let mut offsets = vec![0usize];
+        for (arg_idx, arg) in args.iter().enumerate() {
+            let arg = arg.as_ref();
+            if arg.dtype() != dtype {
+                Err(Error::DTypeMismatchBinaryOp {
+                    lhs: dtype,
+                    rhs: arg.dtype(),
+                    op: "cat",
+                }
+                .bt())?
+            }
+            if arg.device().location() != device.location() {
+                Err(Error::DeviceMismatchBinaryOp {
+                    lhs: device.location(),
+                    rhs: arg.device().location(),
+                    op: "cat",
+                }
+                .bt())?
+            }
+            if rank != arg.rank() {
+                Err(Error::UnexpectedNumberOfDims {
+                    expected: rank,
+                    got: arg.rank(),
+                    shape: arg.shape().clone(),
+                }
+                .bt())?
+            }
+            for (dim_idx, (v1, v2)) in arg0
+                .shape()
+                .dims()
+                .iter()
+                .zip(arg.shape().dims().iter())
+                .enumerate()
+            {
+                if dim_idx == 0 {
+                    cat_dims[0] += v2;
+                }
+                if dim_idx != 0 && v1 != v2 {
+                    Err(Error::ShapeMismatchCat {
+                        dim: dim_idx,
+                        first_shape: arg0.shape().clone(),
+                        n: arg_idx + 1,
+                        nth_shape: arg.shape().clone(),
+                    }
+                    .bt())?
+                }
+            }
+            let next_offset = offsets.last().unwrap() + arg.elem_count();
+            offsets.push(next_offset);
+        }
+        let shape = Shape::from(cat_dims);
+        let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
+        let mut storage = device.zeros(&shape, dtype)?;
+        for (arg, &offset) in args.iter().zip(offsets.iter()) {
+            let arg = arg.as_ref();
+            arg.storage()
+                .copy_strided_src(&mut storage, offset, arg.layout())?;
+        }
+        Ok(from_storage(storage, shape, op, false))
+    }
+
     /// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
     /// input tensor values and `right` elements after.
     pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
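`cat0` precomputes, per input, the element offset at which `copy_strided_src` writes that input into the shared output storage. A worked run of the offset arithmetic with made-up shapes:

```rust
// Three inputs of shapes (2, 3), (1, 3), (4, 3) concatenated along dim 0:
// each offset is the running sum of the previous inputs' element counts.
fn main() {
    let elem_counts = [2 * 3, 1 * 3, 4 * 3];
    let mut offsets = vec![0usize];
    for n in elem_counts {
        offsets.push(offsets.last().unwrap() + n);
    }
    assert_eq!(offsets, [0, 6, 9, 21]);
    // Output shape: (2 + 1 + 4, 3) = (7, 3), 21 elements in total.
}
```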
@ -1,240 +0,0 @@
|
||||
use crate::{shape::Dim, Error, Result, Shape, Tensor};
|
||||
|
||||
impl Tensor {
|
||||
/// Concatenates two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must of the same rank, and the output will have
|
||||
/// the same rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg0.rank() != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: arg0.rank(),
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
if dim == 0 {
|
||||
Self::cat0(args)
|
||||
} else {
|
||||
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
|
||||
if all_contiguous {
|
||||
Self::cat_contiguous(args, dim)
|
||||
} else {
|
||||
let args: Vec<Tensor> = args
|
||||
.iter()
|
||||
.map(|a| a.as_ref().transpose(0, dim))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cat = Self::cat0(&args)?;
|
||||
cat.transpose(0, dim)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
    if args.is_empty() {
        Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
    }
    let arg0 = args[0].as_ref();
    if args.len() == 1 {
        return Ok(arg0.clone());
    }
    let rank = arg0.rank();
    let device = arg0.device();
    let dtype = arg0.dtype();
    let first_dims = arg0.shape().dims();
    let mut cat_dims = first_dims.to_vec();
    cat_dims[0] = 0;
    let mut offsets = vec![0usize];
    for (arg_idx, arg) in args.iter().enumerate() {
        let arg = arg.as_ref();
        if arg.dtype() != dtype {
            Err(Error::DTypeMismatchBinaryOp {
                lhs: dtype,
                rhs: arg.dtype(),
                op: "cat",
            }
            .bt())?
        }
        if arg.device().location() != device.location() {
            Err(Error::DeviceMismatchBinaryOp {
                lhs: device.location(),
                rhs: arg.device().location(),
                op: "cat",
            }
            .bt())?
        }
        if rank != arg.rank() {
            Err(Error::UnexpectedNumberOfDims {
                expected: rank,
                got: arg.rank(),
                shape: arg.shape().clone(),
            }
            .bt())?
        }
        for (dim_idx, (v1, v2)) in arg0
            .shape()
            .dims()
            .iter()
            .zip(arg.shape().dims().iter())
            .enumerate()
        {
            if dim_idx == 0 {
                cat_dims[0] += v2;
            }
            if dim_idx != 0 && v1 != v2 {
                Err(Error::ShapeMismatchCat {
                    dim: dim_idx,
                    first_shape: arg0.shape().clone(),
                    n: arg_idx + 1,
                    nth_shape: arg.shape().clone(),
                }
                .bt())?
            }
        }
        let next_offset = offsets.last().unwrap() + arg.elem_count();
        offsets.push(next_offset);
    }
    let shape = Shape::from(cat_dims);
    let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
    let mut storage = device.zeros(&shape, dtype)?;
    for (arg, &offset) in args.iter().zip(offsets.iter()) {
        let arg = arg.as_ref();
        arg.storage()
            .copy_strided_src(&mut storage, offset, arg.layout())?;
    }
    Ok(crate::tensor::from_storage(storage, shape, op, false))
}

fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
    if args.is_empty() {
        Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
    }
    let arg0 = args[0].as_ref();
    if args.len() == 1 {
        return Ok(arg0.clone());
    }
    let rank = arg0.rank();
    let device = arg0.device();
    let dtype = arg0.dtype();
    let first_dims = arg0.shape().dims();
    let mut cat_dims = first_dims.to_vec();
    cat_dims[dim] = 0;
    for (arg_idx, arg) in args.iter().enumerate() {
        let arg = arg.as_ref();
        if arg.dtype() != dtype {
            Err(Error::DTypeMismatchBinaryOp {
                lhs: dtype,
                rhs: arg.dtype(),
                op: "cat",
            }
            .bt())?
        }
        if arg.device().location() != device.location() {
            Err(Error::DeviceMismatchBinaryOp {
                lhs: device.location(),
                rhs: arg.device().location(),
                op: "cat",
            }
            .bt())?
        }
        if rank != arg.rank() {
            Err(Error::UnexpectedNumberOfDims {
                expected: rank,
                got: arg.rank(),
                shape: arg.shape().clone(),
            }
            .bt())?
        }
        for (dim_idx, (v1, v2)) in arg0
            .shape()
            .dims()
            .iter()
            .zip(arg.shape().dims().iter())
            .enumerate()
        {
            if dim_idx == dim {
                cat_dims[dim] += v2;
            }
            if dim_idx != dim && v1 != v2 {
                Err(Error::ShapeMismatchCat {
                    dim: dim_idx,
                    first_shape: arg0.shape().clone(),
                    n: arg_idx + 1,
                    nth_shape: arg.shape().clone(),
                }
                .bt())?
            }
        }
    }
    let cat_target_dim_len = cat_dims[dim];
    let block_size: usize = cat_dims.iter().skip(1 + dim).product();
    let shape = Shape::from(cat_dims);
    let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
    let mut storage = device.zeros(&shape, dtype)?;
    let mut dst_o = 0;
    for arg in args.iter() {
        let arg = arg.as_ref();
        let arg_dims = arg.shape().dims();
        let d1: usize = arg_dims.iter().take(dim).product();
        let d2 = block_size * arg_dims[dim];
        let dst_s = block_size * cat_target_dim_len;
        let src_o = arg.layout().start_offset();
        arg.storage().copy2d(
            &mut storage,
            d1,
            d2,
            /* src_s */ d2,
            dst_s,
            src_o,
            dst_o,
        )?;
        dst_o += d2;
    }
    Ok(crate::tensor::from_storage(storage, shape, op, false))
}
}
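The `cat_contiguous` path above reduces concatenation to one strided 2-D copy per input: `d1` rows of `d2` contiguous elements, read with stride `src_s` and written with stride `dst_s`. As a sanity check, here is a minimal standalone sketch of that mapping on plain `Vec<f32>` buffers with a hand-rolled `copy2d` — not the real candle storage API, just the index arithmetic:

```rust
// Editorial sketch: plain row-major buffers, not the real backend API.
fn copy2d(src: &[f32], dst: &mut [f32], d1: usize, d2: usize,
          src_s: usize, dst_s: usize, src_o: usize, dst_o: usize) {
    for i in 0..d1 {
        let s = src_o + i * src_s;
        let d = dst_o + i * dst_s;
        dst[d..d + d2].copy_from_slice(&src[s..s + d2]);
    }
}

fn main() {
    // Concatenate a (2, 2) and a (2, 1) matrix along dim 1 into a (2, 3) one.
    let a = vec![1., 2., 3., 4.]; // rows [1 2] and [3 4]
    let b = vec![5., 6.]; // rows [5] and [6]
    let mut out = vec![0f32; 6];
    // For each arg: d1 = rows before the cat dim, d2 = elements copied per row,
    // dst stride = total row length of the result, dst offset advances by d2.
    copy2d(&a, &mut out, 2, 2, 2, 3, 0, 0);
    copy2d(&b, &mut out, 2, 1, 1, 3, 0, 2);
    assert_eq!(out, [1., 2., 5., 3., 4., 6.]);
}
```

The second call lands each row of `b` right after the corresponding row-slice of `a`, which is exactly how `dst_o += d2` advances in the loop above.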
@@ -53,16 +53,7 @@ fn conv1d(dev: &Device) -> Result<()> {
        test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
        [2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
    );

    // conv-transposes are not implemented for metal.
    if dev.is_metal() {
        return Ok(());
    }

    let w = w.transpose(0, 1)?;
    // The CPU kernels applied in the contiguous and non contiguous cases are different.
    for w in [w.clone(), w.contiguous()?] {
        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
        let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 1)?;
        assert_eq!(res.dims(), [1, 2, 7]);
        assert_eq!(
            test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
@@ -71,7 +62,7 @@ fn conv1d(dev: &Device) -> Result<()> {
                4.7076, -5.9745, -0.8276, 1.621
            ],
        );
        let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
        let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1, 2)?;
        assert_eq!(res.dims(), [1, 4, 7]);
        assert_eq!(
            test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
@@ -82,7 +73,6 @@ fn conv1d(dev: &Device) -> Result<()> {
            [1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
        ]
    );
    }
    Ok(())
}
@@ -168,7 +158,6 @@ fn conv2d(dev: &Device) -> Result<()> {
            10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
        ]
    );
    if !dev.is_metal() {
        let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
        assert_eq!(res.dims(), [1, 2, 7, 7]);
        assert_eq!(
@@ -194,7 +183,6 @@ fn conv2d(dev: &Device) -> Result<()> {
            ]
        ]
    );
    }
    // Dilations.
    let res = t.conv2d(&w, 0, 1, 2, 1)?;
    assert_eq!(res.dims(), [1, 2, 1, 1]);
@@ -203,7 +191,6 @@ fn conv2d(dev: &Device) -> Result<()> {
        [2.45, -2.3504],
    );

    if !dev.is_metal() {
        // Transpose and dilations.
        let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
        assert_eq!(res.dims(), [1, 2, 9, 9]);
@@ -217,10 +204,7 @@ fn conv2d(dev: &Device) -> Result<()> {
                [-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
                [-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
                [0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
                [
                    -2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51,
                    -3.5024
                ],
                [-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
                [4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
                [5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
            ],
@@ -233,14 +217,10 @@ fn conv2d(dev: &Device) -> Result<()> {
                [3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
                [5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
                [-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
                [
                    -5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
                    1.0171
                ]
                [-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
            ]
        ]
    );
    }
    Ok(())
}

@@ -294,12 +274,6 @@ fn conv2d_small(dev: &Device) -> Result<()> {
            0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
        ]
    );

    // conv-transposes are not implemented for metal
    if dev.is_metal() {
        return Ok(());
    }

    let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
    assert_eq!(res.dims(), [1, 1, 3, 3]);
    assert_eq!(
@@ -401,10 +375,6 @@ print(w.grad.shape)
print(w.grad[0])
*/
fn conv2d_grad(dev: &Device) -> Result<()> {
    // conv-transposes are not implemented for metal
    if dev.is_metal() {
        return Ok(());
    }
    use candle_core::Var;
    let t = Var::from_slice(
        &[
@@ -1,4 +1,3 @@
#![allow(clippy::approx_constant)]
use anyhow::{Context, Result};
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};

@@ -97,24 +96,24 @@ fn unary_grad(device: &Device) -> Result<()> {
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [20.0855, 2.7183, 54.5982, 1.1618]
        y.to_vec1::<f32>()?,
        [20.085537, 2.7182817, 54.59815, 1.1618342]
    );
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 4)?,
        [20.0855, 2.7183, 54.5982, 1.1618]
        grad_x.to_vec1::<f32>()?,
        [20.085537, 2.7182817, 54.59815, 1.1618342]
    );
    let y = x.exp()?.sqr()?;
    let grads = y.backward()?;
    let grad_x = grads.get(x).context("no grad for x")?;
    assert_eq!(
        test_utils::to_vec1_round(&y, 3)?,
        [403.429, 7.389, 2980.958, 1.35]
        y.to_vec1::<f32>()?,
        [403.4288, 7.3890557, 2980.9578, 1.3498588]
    );
    // exp(x)^2 = exp(2*x)
    assert_eq!(
        test_utils::to_vec1_round(grad_x, 2)?,
        [806.86, 14.78, 5961.92, 2.7]
        grad_x.to_vec1::<f32>()?,
        [806.8576, 14.778111, 5961.9155, 2.6997175]
    );
    let y = x.sin()?;
    let grads = y.backward()?;
@@ -262,7 +261,6 @@ fn unary_grad(device: &Device) -> Result<()> {
    let y = elu_x.elu(2.)?;
    let grads = y.backward()?;
    let grad_x = grads.get(&elu_x).context("no grad for x")?;

    assert_eq!(
        test_utils::to_vec1_round(&y, 4)?,
        [-1.2642, 0.0000, -1.7293, 3.0000]
@@ -2,9 +2,6 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};

// https://github.com/huggingface/candle/issues/364
fn avg_pool2d(dev: &Device) -> Result<()> {
    if dev.is_metal() {
        return Ok(());
    }
    let data: Vec<f32> = vec![
        1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
    ];
@@ -22,9 +19,6 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
}

fn max_pool2d(dev: &Device) -> Result<()> {
    if dev.is_metal() {
        return Ok(());
    }
    let data: Vec<f32> = vec![
        1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
    ];
@@ -49,9 +43,6 @@ res = torch.nn.functional.avg_pool2d(t, 2)
print(res)
*/
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
    if dev.is_metal() {
        return Ok(());
    }
    let t = Tensor::new(
        &[
            0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
@@ -672,31 +672,6 @@ fn cat(device: &Device) -> Result<()> {
            [2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
        ]
    );

    // 3D
    let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
    let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
    let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;

    let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;

    let t1 = t1.t()?.contiguous()?.t()?;
    let t2 = t2.t()?.contiguous()?.t()?;
    let t3 = t3.t()?.contiguous()?.t()?;
    let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;

    let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
    assert_eq!(diff.to_vec0::<f32>()?, 104.0);
    assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
    assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
    assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
    assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
    assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
    assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
    assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
    assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
    assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
    assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
    Ok(())
}

@@ -1105,33 +1080,8 @@ fn broadcasting(device: &Device) -> Result<()> {
fn randn(device: &Device) -> Result<()> {
    let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
    assert_eq!(tensor.dims(), [5, 3]);
    // Check that the seed gets updated by checking that
    // a new series of numbers is generated each time
    let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
    assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
    let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
    assert_eq!(tensor.dims(), [5, 3]);
    // Check that the seed gets updated by checking that
    // a new series of numbers is generated each time
    let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
    assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
    // We do not expect deterministic elements at any index.
    // There once was a bug that had a deterministic zero element in evenly sized tensors.
    const N: usize = 2;
    let v = (0..100)
        .map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
        .collect::<Result<Vec<_>>>()?;
    assert!(
        (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
        "There are deterministic values in the randn tensors"
    );
    let v = (0..100)
        .map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
        .collect::<Result<Vec<_>>>()?;
    assert!(
        (0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
        "There are deterministic values in the rand tensors"
    );
    Ok(())
}
@@ -100,4 +100,6 @@ required-features = ["candle-datasets"]
name = "encodec"
required-features = ["symphonia"]

[[example]]
name = "metavoice"
required-features = ["symphonia"]
@@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@@ -93,7 +93,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@@ -66,7 +66,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@@ -109,7 +109,8 @@ fn main() -> Result<()> {
    let codes = match args.action {
        Action::CodeToAudio => {
            let codes = candle::safetensors::load(args.in_file, &device)?;
            codes.get("codes").expect("no codes in input file").clone()
            let codes = codes.get("codes").expect("no codes in input file").i(0)?;
            codes
        }
        Action::AudioToCode | Action::AudioToAudio => {
            let (pcm, sample_rate) = pcm_decode(args.in_file)?;
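One detail of the encodec change above: `.i(0)` replaces `.clone()`, keeping only the first entry along the leading dimension. A minimal sketch of that indexing behaviour on a toy tensor, using candle's `IndexOp` trait:

```rust
// Editorial sketch of `.i(0)` semantics, not part of the diff.
use candle_core::{Device, IndexOp, Tensor};

fn main() -> candle_core::Result<()> {
    let t = Tensor::arange(0u32, 6, &Device::Cpu)?.reshape((2, 3))?;
    let first = t.i(0)?; // selects index 0 along dim 0: (2, 3) -> (3,)
    assert_eq!(first.dims(), [3]);
    assert_eq!(first.to_vec1::<u32>()?, [0, 1, 2]);
    Ok(())
}
```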
@@ -1,4 +1,4 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind
# candle-mistral: 2b and 7b LLMs from Google DeepMind

[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google DeepMind with a 2b and a 7b variant.
@@ -10,8 +10,9 @@ use std::io::Write;

use candle_transformers::generation::LogitsProcessor;
use candle_transformers::models::encodec;
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
use candle_transformers::models::metavoice::{
    adapters, gpt, speaker_encoder, tokenizers, transformer,
};

use candle::{DType, IndexOp, Tensor};
use candle_nn::VarBuilder;
@@ -20,6 +21,60 @@ use rand::{distributions::Distribution, SeedableRng};

pub const ENCODEC_NTOKENS: u32 = 1024;

fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
where
    T: symphonia::core::sample::Sample,
    f32: symphonia::core::conv::FromSample<T>,
{
    use symphonia::core::audio::Signal;
    use symphonia::core::conv::FromSample;
    samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
}

fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
    use symphonia::core::audio::{AudioBufferRef, Signal};

    let src = std::fs::File::open(path)?;
    let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
    let hint = symphonia::core::probe::Hint::new();
    let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
    let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
    let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
    let mut format = probed.format;
    let track = format
        .tracks()
        .iter()
        .find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
        .expect("no supported audio tracks");
    let mut decoder = symphonia::default::get_codecs()
        .make(&track.codec_params, &Default::default())
        .expect("unsupported codec");
    let track_id = track.id;
    let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
    let mut pcm_data = Vec::new();
    while let Ok(packet) = format.next_packet() {
        while !format.metadata().is_latest() {
            format.metadata().pop();
        }
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet)? {
            AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
            AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
            AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
            AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
        }
    }
    Ok((pcm_data, sample_rate))
}
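A minimal usage sketch for the decoder above; `sample.wav` is a placeholder path, and the 16 kHz check mirrors what the speaker-encoder path further down expects:

```rust
// Editorial sketch; `example` is a hypothetical helper, not part of the diff.
fn example() -> anyhow::Result<()> {
    let (pcm, sample_rate) = pcm_decode("sample.wav")?;
    println!("decoded {} samples at {sample_rate}Hz", pcm.len());
    if sample_rate != 16_000 {
        eprintln!("WARNING: expected a 16kHz sample rate");
    }
    Ok(())
}
```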

#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
enum ArgDType {
    F32,
@@ -27,11 +82,6 @@ enum ArgDType {
    Bf16,
}

enum Transformer {
    Normal(transformer::Model),
    Quantized(qtransformer::Model),
}

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
@@ -46,10 +96,6 @@ struct Args {
    #[arg(long)]
    prompt: String,

    /// Use the quantized version of the model.
    #[arg(long)]
    quantized: bool,

    /// The guidance scale.
    #[arg(long, default_value_t = 3.0)]
    guidance_scale: f64,
@@ -79,9 +125,14 @@ struct Args {
    #[arg(long)]
    second_stage_weights: Option<String>,

    #[arg(long)]
    speaker_encoder_weights: Option<String>,

    #[arg(long)]
    encodec_weights: Option<String>,

    /// The speaker embeddings, either an audio file, in which case they are extracted, or a
    /// safetensors file with the embeddings already extracted.
    #[arg(long)]
    spk_emb: Option<String>,

@@ -89,6 +140,13 @@ struct Args {
    dtype: ArgDType,
}

fn mel_filters() -> Result<Vec<f32>> {
    let mel_bytes = include_bytes!("melfilters40.bytes").as_slice();
    let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
    <byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
    Ok(mel_filters)
}

fn main() -> Result<()> {
    use tracing_chrome::ChromeLayerBuilder;
    use tracing_subscriber::prelude::*;
@@ -126,6 +184,10 @@ fn main() -> Result<()> {
    };
    let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;

    let first_stage_weights = match &args.first_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("first_stage.safetensors")?,
    };
    let second_stage_weights = match &args.second_stage_weights {
        Some(w) => std::path::PathBuf::from(w),
        None => repo.get("second_stage.safetensors")?,
@@ -141,27 +203,10 @@ fn main() -> Result<()> {
        ArgDType::F16 => DType::F16,
        ArgDType::Bf16 => DType::BF16,
    };

    let first_stage_config = transformer::Config::cfg1b_v0_1();
    let mut first_stage_model = if args.quantized {
        let filename = match &args.first_stage_weights {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("first_stage_q4k.gguf")?,
        };
        let vb =
            candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
        let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
        Transformer::Quantized(first_stage_model)
    } else {
        let first_stage_weights = match &args.first_stage_weights {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("first_stage.safetensors")?,
        };
        let first_stage_vb =
            unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
        let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
        Transformer::Normal(first_stage_model)
    };
    let first_stage_config = transformer::Config::cfg1b_v0_1();
    let mut first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;

    let second_stage_vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
@@ -182,16 +227,41 @@ fn main() -> Result<()> {
    let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
    let mut tokens = prompt_tokens.clone();
    println!("{tokens:?}");
    let safetensors_embeddings = args
        .spk_emb
        .as_ref()
        .map_or(true, |v| v.ends_with("safetensors"));
    let spk_emb = if safetensors_embeddings {
        let spk_emb_file = match &args.spk_emb {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("spk_emb.safetensors")?,
        };
        let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
        let spk_emb = match spk_emb.get("spk_emb") {
        match spk_emb.get("spk_emb") {
            None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
            Some(spk_emb) => spk_emb.to_dtype(dtype)?,
            Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
        }
    } else {
        let weights = match &args.speaker_encoder_weights {
            Some(w) => std::path::PathBuf::from(w),
            None => repo.get("speaker_encoder.safetensors")?,
        };
        let mel_filters = mel_filters()?;
        let config = speaker_encoder::Config::cfg();
        let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
        let model = speaker_encoder::Model::new(config, vb)?;
        let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
        if sample_rate != 16_000 {
            eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
        }
        model.embed_utterance(
            &pcm,
            &mel_filters,
            /* rate */ 1.3,
            /* min_c */ 0.75,
            &device,
        )?
    };
    let spk_emb = spk_emb.to_device(&device)?;
    let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));

    // First stage generation.
@@ -201,12 +271,7 @@ fn main() -> Result<()> {
        let ctxt = &tokens[start_pos..];
        let input = Tensor::new(ctxt, &device)?;
        let input = Tensor::stack(&[&input, &input], 0)?;
        let logits = match &mut first_stage_model {
            Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
            Transformer::Quantized(m) => {
                m.forward(&input, &spk_emb, tokens.len() - context_size)?
            }
        };
        let logits = first_stage_model.forward(&input, &spk_emb, tokens.len() - context_size)?;
        let logits0 = logits.i((0, 0))?;
        let logits1 = logits.i((1, 0))?;
        let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
BIN candle-examples/examples/metavoice/melfilters40.bytes (new file)
Binary file not shown.
@@ -63,7 +63,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@@ -78,7 +78,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {

@@ -45,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
@@ -141,7 +141,7 @@ impl std::fmt::Display for Which {
impl Which {
    fn model_id(&self) -> &'static str {
        match self {
            Self::Eagle7b => "RWKV/v5-Eagle-7B-HF",
            Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
            Self::World1b5 => "RWKV/rwkv-5-world-1b5",
            Self::World3b => "RWKV/rwkv-5-world-3b",
            Self::World6_1b6 => "paperfun/rwkv",
@@ -96,10 +96,6 @@ struct Args {
    /// information.
    #[arg(long, default_value_t = 0.8)]
    img2img_strength: f64,

    /// The seed to use when generating random samples.
    #[arg(long)]
    seed: Option<u64>,
}

#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
@@ -378,7 +374,6 @@ fn run(args: Args) -> Result<()> {
        use_flash_attn,
        img2img,
        img2img_strength,
        seed,
        ..
    } = args;

@@ -432,9 +427,6 @@ fn run(args: Args) -> Result<()> {

    let scheduler = sd_config.build_scheduler(n_steps)?;
    let device = candle_examples::device(cpu)?;
    if let Some(seed) = seed {
        device.set_seed(seed)?;
    }
    let use_guide_scale = guidance_scale > 1.0;

    let which = match sd_version {
@@ -10,6 +10,11 @@ order to be able to use it.

Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.

StableLM-2 uses a Tiktoken-based GPT-3.5/GPT-4 tokenizer that is not supported
by Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.

## Running some examples

```bash
@@ -239,7 +239,14 @@ fn main() -> Result<()> {
    ));
    let tokenizer_filename = match args.tokenizer_file {
        Some(file) => std::path::PathBuf::from(file),
        None => repo.get("tokenizer.json")?,
        None => match args.which {
            Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
                repo.get("tokenizer.json")?
            }
            Which::V2 | Which::V2Zephyr => api
                .model("lmz/candle-stablelm".to_string())
                .get("tokenizer-gpt4.json")?,
        },
    };
    let filenames = match args.weight_files {
        Some(files) => files
@@ -33,7 +33,7 @@ struct Args {
pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;
    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;

    println!("loaded image {image:?}");

@@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {

    let device = candle_examples::device(args.cpu)?;

    let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
    let image = candle_examples::imagenet::load_image224(args.image)?;
    println!("loaded image {image:?}");

    let model_file = match args.model {
@@ -1,6 +1,6 @@
[package]
name = "candle-flash-attn"
version = "0.4.2"
version = "0.4.1"
edition = "2021"

description = "Flash attention layer for the candle ML framework."
@@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
readme = "README.md"

[dependencies]
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.2" }
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.1" }
half = { version = "2.3.1", features = ["num-traits"] }

[build-dependencies]
@@ -1,6 +1,6 @@
[package]
name = "candle-kernels"
version = "0.4.2"
version = "0.4.1"
edition = "2021"

description = "CUDA kernels for Candle"
@@ -51,48 +51,6 @@ __device__ void conv1d(
    dst[dst_i] = static_cast<T>(d);
}

template <typename T>
__device__ void col2im1d(
    const size_t l_in,
    const size_t l_out,
    const size_t c_out,
    const size_t k_size,
    const size_t b_size,
    const size_t stride,
    const T *src,
    T *dst
) {
    const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
    // src: (b_size, l_in, c_out, k_size)
    // dst: (b_size, c_out, l_out)
    if (dst_i >= b_size * c_out * l_out) {
        return;
    }
    const size_t dst_s0 = c_out * l_out;
    const size_t dst_s1 = l_out;

    // dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
    const size_t b_i = dst_i / dst_s0;
    const size_t dst_i2 = dst_i - b_i * dst_s0;
    const size_t c_i = dst_i2 / dst_s1;
    const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i

    const size_t src_s0 = c_out * k_size * l_in;
    const size_t src_s1 = c_out * k_size;
    const size_t src_s2 = k_size;

    T d = 0;
    for (size_t k_i = 0; k_i < min(dst_i3 + 1, k_size); ++k_i) {
        const size_t l_in_i_times_stride = dst_i3 - k_i;
        const size_t l_in_i = l_in_i_times_stride / stride;
        const size_t src_i = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
        if (l_in_i * stride == l_in_i_times_stride && l_in_i < l_in) {
            d += src[src_i];
        }
    }
    dst[dst_i] = d;
}
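The kernel above gathers over `k_i` for each output element; the same reduction can be written as a scatter on the CPU. A rough Rust sketch under the kernel's layout comments (src `(b_size, l_in, c_out, k_size)`, dst `(b_size, c_out, l_out)`), assuming `l_out = (l_in - 1) * stride + k_size`, i.e. no padding and dilation 1:

```rust
// Editorial reference sketch, scatter formulation of the gather kernel above.
fn col2im1d(
    src: &[f32], dst: &mut [f32],
    b_size: usize, l_in: usize, c_out: usize, k_size: usize, stride: usize,
) {
    let l_out = (l_in - 1) * stride + k_size;
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                    let dst_i = (b * c_out + c) * l_out + l * stride + k;
                    dst[dst_i] += src[src_i];
                }
            }
        }
    }
}

fn main() {
    // One batch, one channel, kernel size 2, stride 1: the two input
    // positions overlap at one output position.
    let src = [1., 2., 3., 4.]; // (b=1, l_in=2, c_out=1, k=2)
    let mut dst = [0f32; 3]; // l_out = (2 - 1) * 1 + 2 = 3
    col2im1d(&src, &mut dst, 1, 2, 1, 2, 1);
    assert_eq!(dst, [1., 5., 4.]); // middle position accumulates 2 + 3
}
```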

template <typename T>
__device__ void im2col1d(
    const size_t dst_numel,
@@ -569,7 +527,7 @@ extern "C" __global__ void FN_NAME( \
    conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
} \

#define IM2COL1D_OP(TYPENAME, FN_NAME, FN_NAME2) \
#define IM2COL1D_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
    const size_t dst_numel, \
    const size_t l_out, \
@@ -583,18 +541,6 @@ extern "C" __global__ void FN_NAME( \
) { \
    im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
} \
extern "C" __global__ void FN_NAME2( \
    const size_t l_in, \
    const size_t l_out, \
    const size_t c_out, \
    const size_t k_size, \
    const size_t b_size, \
    const size_t stride, \
    const TYPENAME *src, \
    TYPENAME *dst \
) { \
    col2im1d<TYPENAME>(l_in, l_out, c_out, k_size, b_size, stride, src, dst); \
} \

#define IM2COL_OP(TYPENAME, FN_NAME) \
extern "C" __global__ void FN_NAME( \
@@ -696,7 +642,7 @@ AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
IM2COL_OP(__nv_bfloat16, im2col_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16, col2im1d_bf16)
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
#endif

#if __CUDA_ARCH__ >= 530
@@ -708,7 +654,7 @@ AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
MAX_POOL2D_OP(__half, max_pool2d_f16)
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
IM2COL_OP(__half, im2col_f16)
IM2COL1D_OP(__half, im2col1d_f16, col2im1d_f16)
IM2COL1D_OP(__half, im2col1d_f16)
#endif

CONV1D_OP(float, float, conv1d_f32)
@@ -751,7 +697,7 @@ IM2COL_OP(double, im2col_f64)
IM2COL_OP(uint8_t, im2col_u8)
IM2COL_OP(uint32_t, im2col_u32)

IM2COL1D_OP(float, im2col1d_f32, col2im1d_f32)
IM2COL1D_OP(double, im2col1d_f64, col2im1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8, col2im1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32, col2im1d_u32)
IM2COL1D_OP(float, im2col1d_f32)
IM2COL1D_OP(double, im2col1d_f64)
IM2COL1D_OP(uint8_t, im2col1d_u8)
IM2COL1D_OP(uint32_t, im2col1d_u32)
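For contrast with `col2im1d`, here is a rough CPU sketch of the forward `im2col1d` transform, assuming the usual layouts (src `(b, c_in, l_in)`, dst `(b, l_out, c_in, k)`) and zero-filled padding taps — these layout assumptions are mine, inferred from the kernel signature above:

```rust
// Editorial reference sketch of im2col for 1-D convolutions.
fn im2col1d(
    src: &[f32], dst: &mut [f32],
    b_size: usize, c_in: usize, l_in: usize,
    l_out: usize, k: usize, stride: usize, padding: usize, dilation: usize,
) {
    for b in 0..b_size {
        for j in 0..l_out {
            for c in 0..c_in {
                for ki in 0..k {
                    let pos = j * stride + ki * dilation;
                    let dst_i = ((b * l_out + j) * c_in + c) * k + ki;
                    dst[dst_i] = if pos < padding || pos >= l_in + padding {
                        0. // padding tap
                    } else {
                        src[(b * c_in + c) * l_in + (pos - padding)]
                    };
                }
            }
        }
    }
}

fn main() {
    let src = [1., 2., 3., 4.]; // (b=1, c_in=1, l_in=4)
    let mut dst = [0f32; 6]; // l_out = 3 with k = 2, stride = 1, no padding
    im2col1d(&src, &mut dst, 1, 1, 4, 3, 2, 1, 0, 1);
    assert_eq!(dst, [1., 2., 2., 3., 3., 4.]); // sliding windows of size 2
}
```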
@@ -10,39 +10,11 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }

template<typename T>
__device__ void copy2d(const T *src, T *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) {
    uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= d1 * d2) {
        return;
    }
    uint32_t idx1 = idx / d2;
    uint32_t idx2 = idx - d2 * idx1;
    dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2];
}

#define COPY2D_OP(TYPENAME, FNNAME) \
extern "C" __global__ \
void FNNAME(const TYPENAME *src, TYPENAME *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { \
    copy2d(src, dst, d1, d2, src_s, dst_s); \
} \

COPY2D_OP(float, copy2d_f32)
COPY2D_OP(double, copy2d_f64)
COPY2D_OP(uint8_t, copy2d_u8)
COPY2D_OP(uint32_t, copy2d_u32)
COPY2D_OP(int64_t, copy2d_i64)

#if __CUDA_ARCH__ >= 530
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__half, copy2d_f16)
#endif

#if __CUDA_ARCH__ >= 800
#include <cuda_bf16.h>
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
#endif
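Both the CUDA kernel above and the Metal launch later in this diff treat the 2-D copy as a flat problem of `d1 * d2` threads, relying on the in-kernel bounds check (`if (idx >= d1 * d2) return;`) to handle the tail. A sketch of the round-up split; the block size of 256 is an arbitrary assumption, not taken from the source:

```rust
// Editorial sketch of grid sizing for a bounds-checked elementwise kernel.
fn linear_split(n_elements: usize, block: usize) -> (usize, usize) {
    let groups = (n_elements + block - 1) / block; // round up to cover every element
    (groups, block)
}

fn main() {
    let (d1, d2) = (3, 1000);
    let (groups, block) = linear_split(d1 * d2, 256);
    assert_eq!(groups, 12); // 12 * 256 = 3072 >= 3000; extra threads exit early
    println!("{groups} groups x {block} threads");
}
```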

@@ -1,6 +1,6 @@
[package]
name = "candle-metal-kernels"
version = "0.4.2"
version = "0.4.1"
edition = "2021"

description = "Metal kernels for Candle"
@@ -89,7 +89,7 @@ kernel void FN_NAME( \
        return; \
    } \
    const TYPENAME x = input[id]; \
    output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
    output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
} \
kernel void FN_NAME##_strided( \
    constant size_t &dim, \
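The one-character difference in the hunk above changes the math: ELU is `alpha * (exp(x) - 1)` for negative inputs, not `alpha * exp(x - 1)`. A quick numeric check in Rust:

```rust
// Editorial check of the ELU parenthesization, not part of the diff.
fn elu(x: f32, alpha: f32) -> f32 {
    if x > 0. { x } else { alpha * (x.exp() - 1.) }
}

fn main() {
    // Correct form: elu(-1) = exp(-1) - 1 ~ -0.6321.
    assert!((elu(-1.0, 1.0) + 0.6321).abs() < 1e-4);
    // The buggy exp(x - 1) form would give exp(-2) ~ +0.1353 instead.
    assert_eq!(elu(2.0, 1.0), 2.0); // positive inputs pass through
}
```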
@@ -72,60 +72,27 @@ kernel void FN_NAME_STRIDED( \
    output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
} \

// u32
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
#if __METAL_VERSION__ >= 220
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
#endif

// u8
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)

#if __METAL_VERSION__ >= 220
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
#endif
#if defined(__HAVE_BFLOAT__)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
#endif

// f16
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif

// i64
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
#if defined(__HAVE_BFLOAT__)
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
#endif

// f32
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
#if defined(__HAVE_BFLOAT__)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
#endif

// bf16
#if defined(__HAVE_BFLOAT__)
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)

CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
#endif
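The `CAST_THROUGH` entries route conversions that have no direct Metal path through an intermediate type. A host-side sketch of the bf16 -> u8 case via f32, using the `half` crate:

```rust
// Editorial sketch of the cast-through idea, not the Metal code itself.
use half::bf16;

fn main() {
    let x = bf16::from_f32(3.0);
    let through: f32 = x.to_f32(); // intermediate type (IR_TYPENAME)
    let y = through as u8; // final type (RIGHT_TYPENAME)
    assert_eq!(y, 3u8);
    println!("bf16 {x} -> f32 {through} -> u8 {y}");
}
```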
@@ -167,16 +167,11 @@ kernel void NAME( \

INDEX_OP(is_u32_f32, uint, float)
INDEX_OP(is_u32_f16, uint, half)

GATHER_OP(gather_u32_f32, uint, float)
GATHER_OP(gather_u32_f16, uint, half)
SCATTER_ADD_OP(sa_u32_f32, uint, float)
SCATTER_ADD_OP(sa_u32_f16, uint, half)

SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)

#if defined(__HAVE_BFLOAT__)
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
@@ -185,10 +180,6 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)

SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
#endif

INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
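The `sa_*` kernels instantiated above implement scatter-add: each input element is accumulated into the output position named by its id. A CPU sketch of the flat (rank-1, dim 0) case, using the same data as the `scatter_add` test near the end of this diff:

```rust
// Editorial sketch of scatter-add semantics: out[ids[i]] += input[i].
fn scatter_add(input: &[f32], ids: &[usize], out: &mut [f32]) {
    for (i, &id) in ids.iter().enumerate() {
        out[id] += input[i];
    }
}

fn main() {
    let input = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
    let ids = [0usize, 0, 1, 0, 2, 2, 3, 3];
    let mut out = [0f32; 8];
    scatter_add(&input, &ids, &mut out);
    // Matches output_dim1_f32 in the test below: 5+1+2, 7, 3+2, 1+3, then zeros.
    assert_eq!(out, [8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0]);
}
```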
@@ -127,16 +127,6 @@ pub enum Source {
    Quantized,
}

pub mod copy2d {
    pub struct Kernel(pub &'static str);
    pub const FLOAT: Kernel = Kernel("copy2d_f32");
    pub const HALF: Kernel = Kernel("copy2d_f16");
    pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
    pub const I64: Kernel = Kernel("copy2d_i64");
    pub const U32: Kernel = Kernel("copy2d_u32");
    pub const U8: Kernel = Kernel("copy2d_u8");
}

macro_rules! ops{
    ($($name:ident),+) => {

@@ -375,46 +365,6 @@ pub fn call_unary_contiguous(
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_copy2d(
    device: &Device,
    command_buffer: &CommandBufferRef,
    kernels: &Kernels,
    name: copy2d::Kernel,
    input: &Buffer,
    output: &Buffer,
    d1: usize,
    d2: usize,
    src_s: usize,
    dst_s: usize,
    src_o_in_bytes: usize,
    dst_o_in_bytes: usize,
) -> Result<(), MetalKernelError> {
    let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
    let encoder = command_buffer.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipeline);
    set_params!(
        encoder,
        (
            d1,
            d2,
            src_s,
            dst_s,
            (input, src_o_in_bytes),
            (output, dst_o_in_bytes)
        )
    );

    let width: usize = d1 * d2;
    let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);

    encoder.use_resource(input, metal::MTLResourceUsage::Read);
    encoder.use_resource(output, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
    Ok(())
}

#[allow(clippy::too_many_arguments)]
pub fn call_unary_strided(
    device: &Device,
@@ -1608,10 +1558,8 @@ pub fn call_random_uniform(

    set_params!(encoder, (length, min, max, seed, buffer));

    encoder.use_resource(
        seed,
        metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
    );
    encoder.use_resource(seed, metal::MTLResourceUsage::Read);
    encoder.use_resource(seed, metal::MTLResourceUsage::Write);
    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -1641,10 +1589,8 @@ pub fn call_random_normal(

    set_params!(encoder, (length, mean, stddev, seed, buffer));

    encoder.use_resource(
        seed,
        metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
    );
    encoder.use_resource(seed, metal::MTLResourceUsage::Read);
    encoder.use_resource(seed, metal::MTLResourceUsage::Write);
    encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
    encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
    encoder.end_encoding();
@@ -123,20 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform(
        return;
    }

    // Evenly sized vectors need an offset when writing the mirror element.
    uint off = 1 - size % 2;
    float diff = abs(min - max);
    uint s = atomic_load_explicit(seed, memory_order_relaxed);
    HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
    out[tid] = static_cast<T>(rng.rand() * diff + min);
    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0 && off == 0, otherwise we will write to out[size].
        if (off == 0)
        // Return early if tid == 0, otherwise we will write to out[size].
            return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - off - tid] = static_cast<T>(rng.rand() * diff + min);
    out[size - tid] = static_cast<T>(rng.rand() * diff + min);
}

// Create Gaussian normal distribution using Box-Muller transform:
@@ -152,10 +148,7 @@ template<typename T> METAL_FUNC void normal(
    if (tid >= size) {
        return;
    }
    // Evenly sized vectors need an offset when writing the mirror element.
    uint off = 1 - size % 2;
    uint s = atomic_load_explicit(seed, memory_order_relaxed);
    HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
    HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
    float u1 = rng.rand();
    float u2 = rng.rand();

@@ -169,12 +162,11 @@ template<typename T> METAL_FUNC void normal(

    if (tid == 0) {
        atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
        // Return early if tid == 0 && off == 0, otherwise we will write to out[size].
        if (off == 0)
        // Return early if tid == 0, otherwise we will write to out[size].
            return;
    }
    // Use symmetry to fill the other half of the array.
    out[size - off - tid] = static_cast<T>(z1);
    out[size - tid] = static_cast<T>(z1);
}

#define UNIFORM_OP(NAME, T) \
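The fix above threads the `off = 1 - size % 2` correction through the mirror writes: each thread fills `out[tid]` plus a mirrored element, and without the offset an even-sized buffer leaves one middle element unwritten — the deterministic zero that the `randn` test earlier in this diff now guards against. A sketch of the corrected index coverage, assuming one thread per element pair:

```rust
// Editorial sketch: which indices each "thread" writes under the fixed scheme.
fn written_indices(size: usize) -> Vec<usize> {
    let off = 1 - size % 2;
    let mut written = Vec::new();
    for tid in 0..(size + 1) / 2 {
        written.push(tid); // out[tid]
        // Mirror write, skipped for tid == 0 when off == 0 since it would
        // land on out[size], out of bounds.
        if tid != 0 || off != 0 {
            written.push(size - off - tid);
        }
    }
    written.sort();
    written
}

fn main() {
    // Every index is covered exactly once for both parities.
    assert_eq!(written_indices(5), [0, 1, 2, 3, 4]);
    assert_eq!(written_indices(4), [0, 1, 2, 3]);
}
```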
@ -292,7 +292,7 @@ fn binary_ops_bf16() {
|
||||
binary_op!(max, |x: bf16, y| x.max(y));
|
||||
}
|
||||
|
||||
fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
@ -319,189 +319,107 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_f32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn cast_u32_f32() {
|
||||
let v = vec![1u32, 2, 3];
|
||||
let results = cast(&v, "cast_u32_f32");
|
||||
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
|
||||
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
|
||||
// f32 -> f16
|
||||
let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
|
||||
|
||||
// f32 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// f32 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// f32 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// f32 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
let v = vec![1.0f32; 10_000];
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results.len(), 10_000);
|
||||
assert_eq!(&results[..10], vec![1.0f32; 10]);
|
||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_f16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_bf16_u32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
// f16 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
|
||||
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
// f16 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// f16 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// f16 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// f16 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_bf16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_bf16_f32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
// bf16 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
|
||||
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
// bf16 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// bf16 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// bf16 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// bf16 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_u32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_u8_bf16() {
|
||||
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
|
||||
|
||||
// u32 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
|
||||
let expected: Vec<bf16> = input
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v as f32))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// u32 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// u32 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// u32 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// u32 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_u8() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_u32_bf16() {
|
||||
let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
// u8 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
// u8 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// u8 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// u8 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// u8 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_i64() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_f32_bf16() {
|
||||
let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
// i64 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();
|
||||
|
||||
// i64 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
// i64 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
#[test]
|
||||
fn it_cast_bf16_u8() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
// i64 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
let output: Vec<u8> = cast(&input, "cast_bf16_u8");
|
||||
let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();
|
||||
|
||||
// i64 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_bf16_f16() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<f16> = cast(&input, "cast_bf16_f16");
|
||||
let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn it_cast_f16_bf16() {
|
||||
let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();
|
||||
|
||||
let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
|
||||
let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();
|
||||
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
|
@ -1148,107 +1066,3 @@ fn random() {
validate_random!(f16);
validate_random!(bf16);
}

fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
input: &[T],
ids: &[I],
shape: &[usize],
dim: usize,
name: &'static str,
) -> Vec<T> {
let device = device();
let kernels = Kernels::new();
let command_queue = device.new_command_queue();
let command_buffer = command_queue.new_command_buffer();
let options = MTLResourceOptions::StorageModeManaged;
let input_buffer = new_buffer(&device, input);
let ids_buffer = new_buffer(&device, ids);
let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
call_scatter_add(
&device,
command_buffer,
&kernels,
name,
shape,
shape,
dim,
&input_buffer,
0,
&ids_buffer,
0,
&output,
)
.unwrap();
command_buffer.commit();
command_buffer.wait_until_completed();
read_to_vec(&output, input.len())
}

#[test]
fn scatter_add() {
let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3];
let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3];
let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3];

let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
let input_f16 = input_f32
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let input_bf16 = input_f32
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();

let output_dim1_f32 = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0];
let output_dim1_f16 = output_dim1_f32
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let output_dim1_bf16 = output_dim1_f32
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();

let output_dim2_f32 = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0];
let output_dim2_f16 = output_dim2_f32
.iter()
.map(|v| f16::from_f32(*v))
.collect::<Vec<_>>();
let output_dim2_bf16 = output_dim2_f32
.iter()
.map(|v| bf16::from_f32(*v))
.collect::<Vec<_>>();

for (shape, output_f32, output_f16, output_bf16) in [
(vec![8], output_dim1_f32, output_dim1_f16, output_dim1_bf16),
(
vec![4, 2],
output_dim2_f32,
output_dim2_f16,
output_dim2_bf16,
),
] {
for results in [
run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"),
run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"),
run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"),
] {
assert_eq!(results, output_f32);
}
for results in [
run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"),
run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"),
run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"),
] {
assert_eq!(results, output_f16);
}
for results in [
run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"),
run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"),
run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"),
] {
assert_eq!(results, output_bf16);
}
}
}
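Editor's note: a minimal CPU sketch of the scatter-add semantics the test above exercises, reproducing output_dim1_f32 from input_f32 and the ids arrays (illustrative reference only, not the Metal kernel):

    fn scatter_add_ref(input: &[f32], ids: &[u32], out_len: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; out_len];
        for (v, &i) in input.iter().zip(ids.iter()) {
            // Each destination index accumulates every value scattered to it.
            out[i as usize] += *v;
        }
        out
    }

    fn main() {
        let input = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
        let ids = [0u32, 0, 1, 0, 2, 2, 3, 3];
        assert_eq!(
            scatter_add_ref(&input, &ids, 8),
            vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0]
        );
    }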
@ -102,30 +102,6 @@ UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define COPY2D(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
constant size_t &d1, \
constant size_t &d2, \
constant size_t &src_s, \
constant size_t &dst_s, \
device const TYPENAME *input, \
device TYPENAME *output, \
uint tid [[ thread_position_in_grid ]] \
) { \
if (tid >= d1 * d2) { \
return; \
} \
size_t idx1 = tid / d2; \
size_t idx2 = tid - idx1 * d2; \
size_t src_idx = idx1 * src_s + idx2; \
size_t dst_idx = idx1 * dst_s + idx2; \
output[dst_idx] = input[src_idx]; \
}

COPY2D(copy2d_f32, float)
COPY2D(copy2d_f16, half)
COPY2D(copy2d_u8, uint8_t)
COPY2D(copy2d_u32, uint32_t)

UNARY_OP(cos)
UNARY_OP(sin)
@ -152,7 +128,6 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)

#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
#endif

#if defined(__HAVE_BFLOAT__)
@ -176,6 +151,4 @@ BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)

UNARY(id, bfloat, copy_bf16, copy_bf16_strided)

COPY2D(copy2d_bf16, bfloat)
#endif

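Editor's note: the COPY2D macro above emits one kernel per dtype that copies a d1 x d2 block between two buffers with different row strides, one element per thread. An equivalent plain-Rust rendering of the same indexing, for reference (an assumed equivalent, not generated code):

    fn copy2d_ref<T: Copy>(src: &[T], dst: &mut [T], d1: usize, d2: usize, src_s: usize, dst_s: usize) {
        // One loop iteration plays the role of one GPU thread (tid).
        for tid in 0..d1 * d2 {
            let idx1 = tid / d2; // row
            let idx2 = tid - idx1 * d2; // column within the row
            dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2];
        }
    }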
@ -238,23 +238,6 @@ impl Benchmark for QMatMul {
const ITERS: usize = 100;
}

struct Cat;
impl Benchmark for Cat {
type PreProcessData = (Tensor, Tensor);
type RunResult = Tensor;
fn preprocess() -> Result<Self::PreProcessData> {
let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?;
let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?;
Ok((lhs, rhs))
}

fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
Tensor::cat(&[&d.0, &d.1], 2)
}

const ITERS: usize = 1000;
}

struct Softmax;
impl Benchmark for Softmax {
type PreProcessData = Tensor;
@ -312,7 +295,6 @@ enum Task {
Qmatmul,
Softmax,
SoftmaxLastDim,
Cat,
}

#[derive(Parser, Debug)]
@ -337,7 +319,6 @@ fn main() -> Result<()> {
Task::Softmax => run::<Softmax>(args.iters)?,
Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
Task::Qmatmul => run::<QMatMul>(args.iters)?,
Task::Cat => run::<Cat>(args.iters)?,
}
Ok(())
}
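Editor's note: the Cat benchmark above concatenates a (1, 32, 2000, 128) tensor with a (1, 32, 1, 128) tensor along dim 2, the same append pattern as a growing KV cache. A minimal usage sketch of that call (the helper name is hypothetical):

    use candle::{Result, Tensor};

    fn append_step(cache: &Tensor, step: &Tensor) -> Result<Tensor> {
        // All dims must match except the concatenation axis (here dim 2).
        Tensor::cat(&[cache, step], 2)
    }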
@ -74,7 +74,7 @@ pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
xs * mask
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct Dropout {
drop_p: f32,
}
@ -238,8 +238,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
&output,
)
.unwrap();
let newstorage =
candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
Ok((newstorage, layout.shape().clone()))
}
}
@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.4.2"
version = "0.4.1"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.2" }
candle-nn = { path = "../candle-nn", version = "0.4.2" }
candle = { path = "../candle-core", package = "candle-core", version = "0.4.1" }
candle-nn = { path = "../candle-nn", version = "0.4.1" }
prost = "0.12.1"

[build-dependencies]
@ -2,7 +2,7 @@ use candle::{DType, Device, Error as E, IndexOp, Module, Result, Tensor, D};
use candle_nn::{embedding, linear_b, rms_norm, Embedding, Linear, RmsNorm, VarBuilder};

// Equivalent to torch.repeat_interleave
pub(crate) fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
fn repeat_interleave(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
let img = img.unsqueeze(dim + 1)?;
let mut dims = img.dims().to_vec();
dims[dim + 1] = repeats;
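Editor's note: the hunk above cuts off the body of repeat_interleave; a hedged sketch of the usual unsqueeze/broadcast/flatten formulation of torch.repeat_interleave, consistent with the three lines shown (an assumed completion, not the verbatim source):

    use candle::{Result, Tensor};

    fn repeat_interleave_sketch(img: &Tensor, repeats: usize, dim: usize) -> Result<Tensor> {
        let img = img.unsqueeze(dim + 1)?; // insert a size-1 axis after `dim`
        let mut dims = img.dims().to_vec();
        dims[dim + 1] = repeats; // broadcast the new axis up to `repeats`
        // Materialize the broadcast, then merge the two axes back into one.
        img.broadcast_as(dims)?.contiguous()?.flatten(dim, dim + 1)
    }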
@ -55,12 +55,12 @@ pub mod speaker_encoder {
layer_idx,
..Default::default()
};
let lstm = candle_nn::lstm(
cfg.mel_n_channels,
cfg.model_hidden_size,
c,
vb_l.pp(layer_idx),
)?;
let in_c = if layer_idx == 0 {
cfg.mel_n_channels
} else {
cfg.model_hidden_size
};
let lstm = candle_nn::lstm(in_c, cfg.model_hidden_size, c, vb_l.clone())?;
lstms.push(lstm)
}
let linear = linear_b(
@ -143,7 +143,7 @@ pub mod speaker_encoder {
.iter()
.flat_map(|s| [mel[s.0], mel[s.1]])
.collect::<Vec<_>>();
let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
let mels = Tensor::from_vec(mels, (1, mel_slices.len(), 2), device)?;
let partial_embeds = self.forward(&mels)?;
let raw_embed = partial_embeds.mean(0)?;
let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
@ -181,7 +181,6 @@ pub mod tokenizers {
pub end_of_text: usize,
pub offset: usize,
pub ranks: HashMap<Vec<u8>, Rank>,
span: tracing::Span,
}

impl BPE {
@ -232,7 +231,6 @@ pub mod tokenizers {
end_of_text,
offset,
ranks,
span: tracing::span!(tracing::Level::TRACE, "bpe"),
})
}

@ -312,7 +310,6 @@ pub mod tokenizers {
}

pub fn encode(&self, text: &str) -> Result<Vec<u32>> {
let _enter = self.span.enter();
let mut bpe_tokens: Vec<u32> = Vec::new();
for word in self.re.find_iter(text) {
let word = word.map_err(E::wrap)?;
@ -429,7 +426,6 @@ pub mod gpt {
c_attn: Linear,
c_proj: Linear,
n_head: usize,
span: tracing::Span,
}

impl SelfAttention {
@ -448,14 +444,12 @@ pub mod gpt {
c_attn,
c_proj,
n_head: cfg.n_head,
span: tracing::span!(tracing::Level::TRACE, "self-attn"),
})
}
}

impl Module for SelfAttention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b, t, c) = xs.dims3()?;
let c_x = xs
.apply(&self.c_attn)?
@ -480,13 +474,11 @@ pub mod gpt {
Gelu {
c_fc: Linear,
c_proj: Linear,
span: tracing::Span,
},
Swiglu {
w1: Linear,
w3: Linear,
c_proj: Linear,
span: tracing::Span,
},
}

@ -497,11 +489,7 @@ pub mod gpt {
NonLinearityType::Gelu => {
let c_fc = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("c_fc"))?;
let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
Self::Gelu {
c_fc,
c_proj,
span: tracing::span!(tracing::Level::TRACE, "mlp-gelu"),
}
Self::Gelu { c_fc, c_proj }
}
NonLinearityType::Swiglu => {
let hidden_dim = (2 * hidden_dim) / 3;
@ -514,12 +502,7 @@ pub mod gpt {
let w1 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w1"))?;
let w3 = linear_b(cfg.n_embd, hidden_dim, cfg.bias, vb.pp("w3"))?;
let c_proj = linear_b(hidden_dim, cfg.n_embd, cfg.bias, vb.pp("c_proj"))?;
Self::Swiglu {
w1,
w3,
c_proj,
span: tracing::span!(tracing::Level::TRACE, "mlp-swiglu"),
}
Self::Swiglu { w1, w3, c_proj }
}
};
Ok(slf)
@ -529,17 +512,8 @@ pub mod gpt {
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
match self {
Self::Gelu { c_fc, c_proj, span } => {
let _enter = span.enter();
xs.apply(c_fc)?.gelu()?.apply(c_proj)
}
Self::Swiglu {
w1,
w3,
c_proj,
span,
} => {
let _enter = span.enter();
Self::Gelu { c_fc, c_proj } => xs.apply(c_fc)?.gelu()?.apply(c_proj),
Self::Swiglu { w1, w3, c_proj } => {
let w1 = xs.apply(w1)?;
let w3 = xs.apply(w3)?;
(w1.silu()? * w3)?.apply(c_proj)
@ -554,7 +528,6 @@ pub mod gpt {
ln_2: Norm,
attn: SelfAttention,
mlp: MLP,
span: tracing::Span,
}

impl Block {
@ -568,14 +541,12 @@ pub mod gpt {
ln_2,
attn,
mlp,
span: tracing::span!(tracing::Level::TRACE, "gpt-block"),
})
}
}

impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let xs = (xs + xs.apply(&self.ln_1)?.apply(&self.attn))?;
let xs = (&xs + xs.apply(&self.ln_2)?.apply(&self.mlp))?;
Ok(xs)
@ -592,7 +563,6 @@ pub mod gpt {
lm_heads: Vec<Linear>,
cfg: Config,
dtype: DType,
span: tracing::Span,
}

impl Model {
@ -628,7 +598,6 @@ pub mod gpt {
lm_heads,
cfg,
dtype: vb.dtype(),
span: tracing::span!(tracing::Level::TRACE, "gpt"),
})
}

@ -637,7 +606,6 @@ pub mod gpt {
}

pub fn forward(&self, idx: &Tensor) -> Result<Vec<Tensor>> {
let _enter = self.span.enter();
let device = idx.device();
let (b, _num_hierarchies, t) = idx.dims3()?;
let pos = Tensor::arange(0u32, t as u32, device)?;
@ -696,15 +664,15 @@ pub mod transformer {
}
}

pub(crate) fn n_local_heads(&self) -> usize {
fn n_local_heads(&self) -> usize {
self.n_local_heads.unwrap_or(self.n_head)
}

pub(crate) fn head_dim(&self) -> usize {
fn head_dim(&self) -> usize {
self.dim / self.n_head
}

pub(crate) fn intermediate_size(&self) -> usize {
fn intermediate_size(&self) -> usize {
match self.intermediate_size {
Some(intermediate_size) => intermediate_size,
None => {
@ -721,7 +689,6 @@ pub mod transformer {
w1: Linear,
w2: Linear,
w3: Linear,
span: tracing::Span,
}

impl FeedForward {
@ -730,18 +697,12 @@ pub mod transformer {
let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
Ok(Self {
w1,
w2,
w3,
span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
})
Ok(Self { w1, w2, w3 })
}
}

impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
swiglu.apply(&self.w2)
}
@ -757,7 +718,6 @@ pub mod transformer {
head_dim: usize,
n_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
span: tracing::Span,
}

impl Attention {
@ -776,12 +736,10 @@ pub mod transformer {
head_dim,
n_head: cfg.n_head,
kv_cache: None,
span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
})
}

fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seqlen, _) = xs.dims3()?;

let qkv = xs.apply(&self.wqkv)?;
@ -835,7 +793,6 @@ pub mod transformer {
feed_forward: FeedForward,
ffn_norm: RmsNorm,
attention_norm: RmsNorm,
span: tracing::Span,
}

impl Block {
@ -849,12 +806,10 @@ pub mod transformer {
feed_forward,
ffn_norm,
attention_norm,
span: tracing::span!(tracing::Level::TRACE, "block"),
})
}

fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hs = xs.apply(&self.attention_norm)?;
let hs = (xs + self.attention.forward(&hs, pos, mask))?;
&hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
@ -874,7 +829,6 @@ pub mod transformer {
norm: RmsNorm,
output: Linear,
spk_cond_mask: Tensor,
span: tracing::Span,
}

impl Model {
@ -911,7 +865,6 @@ pub mod transformer {
norm,
output,
spk_cond_mask,
span: tracing::span!(tracing::Level::TRACE, "transformer"),
})
}

@ -922,7 +875,6 @@ pub mod transformer {
}

pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b_sz, seqlen) = xs.dims2()?;
let mask: Vec<_> = (0..seqlen)
.flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
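Editor's note: the mask built at the end of the hunk above is the standard causal mask; for seqlen = 3 the (i, j) grid evaluates to

    0    -inf -inf
    0    0    -inf
    0    0    0

so after it is broadcast-added to the attention logits, softmax assigns zero weight to every future position j > i.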
@ -953,19 +905,14 @@ pub mod adapters {
// https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/tilted_encodec.py
pub struct TiltedEncodec {
end_of_audio_token: u32,
span: tracing::Span,
}

impl TiltedEncodec {
pub fn new(end_of_audio_token: u32) -> Self {
Self {
end_of_audio_token,
span: tracing::span!(tracing::Level::TRACE, "tilted-encodec"),
}
Self { end_of_audio_token }
}

pub fn decode(&self, tokens: &[Vec<u32>]) -> (Vec<u32>, Vec<Vec<u32>>) {
let _enter = self.span.enter();
let mut text_ids = vec![];
let mut extracted_audio_ids = vec![];
let mut min_audio_ids_len = usize::MAX;
@ -994,19 +941,14 @@ pub mod adapters {
// https://github.com/metavoiceio/metavoice-src/blob/9078234c496d76adbec06df789b6b04b1875f129/fam/llm/adapters/flattened_encodec.py#L4
pub struct FlattenedInterleavedEncodec2Codebook {
end_of_audio_token: u32,
span: tracing::Span,
}

impl FlattenedInterleavedEncodec2Codebook {
pub fn new(end_of_audio_token: u32) -> Self {
Self {
end_of_audio_token,
span: tracing::span!(tracing::Level::TRACE, "encodec2codebook"),
}
Self { end_of_audio_token }
}

pub fn decode(&self, tokens: &[u32]) -> (Vec<u32>, Vec<u32>, Vec<u32>) {
let _enter = self.span.enter();
let mut text_ids = vec![];
let mut audio_ids1 = vec![];
let mut audio_ids2 = vec![];
@ -30,7 +30,6 @@ pub mod quantized_blip;
pub mod quantized_blip_text;
pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_metavoice;
pub mod quantized_mistral;
pub mod quantized_mixformer;
pub mod quantized_mpt;
@ -1,242 +0,0 @@
use crate::quantized_nn::{linear_b, Embedding, Linear, RmsNorm};
pub use crate::quantized_var_builder::VarBuilder;

use crate::models::metavoice::repeat_interleave;
use candle::{Module, Result, Tensor, D};

pub mod transformer {
use super::*;

type Config = crate::models::metavoice::transformer::Config;

#[derive(Debug, Clone)]
struct FeedForward {
w1: Linear,
w2: Linear,
w3: Linear,
span: tracing::Span,
}

impl FeedForward {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let i_size = cfg.intermediate_size();
let w1 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w1"))?;
let w2 = linear_b(i_size, cfg.dim, false, vb.pp("w2"))?;
let w3 = linear_b(cfg.dim, i_size, false, vb.pp("swiglu.w3"))?;
Ok(Self {
w1,
w2,
w3,
span: tracing::span!(tracing::Level::TRACE, "feed-forward"),
})
}
}

impl Module for FeedForward {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let swiglu = (candle_nn::ops::silu(&xs.apply(&self.w1)?)? * xs.apply(&self.w3))?;
swiglu.apply(&self.w2)
}
}

#[derive(Debug, Clone)]
struct Attention {
wqkv: Linear,
wo: Linear,
dim: usize,
kv_size: usize,
n_local_heads: usize,
head_dim: usize,
n_head: usize,
kv_cache: Option<(Tensor, Tensor)>,
span: tracing::Span,
}

impl Attention {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let n_local_heads = cfg.n_local_heads();
let head_dim = cfg.head_dim();
let total_head_dim = (cfg.n_head + 2 * n_local_heads) * head_dim;
let wqkv = linear_b(cfg.dim, total_head_dim, false, vb.pp("wqkv"))?;
let wo = linear_b(cfg.dim, cfg.dim, false, vb.pp("wo"))?;
Ok(Self {
wqkv,
wo,
dim: cfg.dim,
kv_size: n_local_heads * head_dim,
n_local_heads,
head_dim,
n_head: cfg.n_head,
kv_cache: None,
span: tracing::span!(tracing::Level::TRACE, "attention"),
})
}

fn forward(&mut self, xs: &Tensor, _pos: usize, mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b_sz, seqlen, _) = xs.dims3()?;

let qkv = xs.apply(&self.wqkv)?;
let q = qkv.narrow(D::Minus1, 0, self.dim)?;
let k = qkv.narrow(D::Minus1, self.dim, self.kv_size)?;
let v = qkv.narrow(D::Minus1, self.dim + self.kv_size, self.kv_size)?;
let q = q
.reshape((b_sz, seqlen, self.n_head, self.head_dim))?
.transpose(1, 2)?
.contiguous()?;
let k = k
.reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?;
let v = v
.reshape((b_sz, seqlen, self.n_local_heads, self.head_dim))?
.transpose(1, 2)?;

let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
let k = Tensor::cat(&[prev_k, &k], 2)?;
let v = Tensor::cat(&[prev_v, &v], 2)?;
(k, v)
}
};
self.kv_cache = Some((k.clone(), v.clone()));

let k = repeat_interleave(&k, self.n_head / self.n_local_heads, 1)?;
let v = repeat_interleave(&v, self.n_head / self.n_local_heads, 1)?;

let scale = 1f64 / f64::sqrt(self.head_dim as f64);
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;

let attn_weights = attn_weights.broadcast_add(mask)?;
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
let attn_output = attn_weights.matmul(&v)?;
attn_output
.transpose(1, 2)?
.reshape((b_sz, seqlen, self.dim))?
.apply(&self.wo)
}

fn clear_kv_cache(&mut self) {
self.kv_cache = None
}
}

#[derive(Debug, Clone)]
struct Block {
attention: Attention,
feed_forward: FeedForward,
ffn_norm: RmsNorm,
attention_norm: RmsNorm,
span: tracing::Span,
}

impl Block {
fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let attention = Attention::new(cfg, vb.pp("attention"))?;
let feed_forward = FeedForward::new(cfg, vb.pp("feed_forward"))?;
let ffn_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("ffn_norm"))?;
let attention_norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("attention_norm"))?;
Ok(Self {
attention,
feed_forward,
ffn_norm,
attention_norm,
span: tracing::span!(tracing::Level::TRACE, "block"),
})
}

fn forward(&mut self, xs: &Tensor, pos: usize, mask: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let hs = xs.apply(&self.attention_norm)?;
let hs = (xs + self.attention.forward(&hs, pos, mask))?;
&hs + hs.apply(&self.ffn_norm)?.apply(&self.feed_forward)
}

fn clear_kv_cache(&mut self) {
self.attention.clear_kv_cache()
}
}

#[derive(Debug, Clone)]
pub struct Model {
tok_embeddings: Embedding,
pos_embeddings: Embedding,
speaker_cond_pos: Linear,
layers: Vec<Block>,
norm: RmsNorm,
output: Linear,
spk_cond_mask: Tensor,
span: tracing::Span,
}

impl Model {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let tok_embeddings = Embedding::new(cfg.vocab_size, cfg.dim, vb.pp("tok_embeddings"))?;
let pos_embeddings = Embedding::new(cfg.block_size, cfg.dim, vb.pp("pos_embeddings"))?;
let speaker_cond_pos = linear_b(
cfg.speaker_emb_dim,
cfg.dim,
false,
vb.pp("speaker_cond_pos"),
)?;
let mut layers = Vec::with_capacity(cfg.n_layer);
let vb_l = vb.pp("layers");
for layer_idx in 0..cfg.n_layer {
let layer = Block::new(cfg, vb_l.pp(layer_idx))?;
layers.push(layer)
}
let norm = RmsNorm::new(cfg.dim, cfg.norm_eps, vb.pp("norm"))?;
let output = linear_b(cfg.dim, cfg.vocab_size, false, vb.pp("output"))?;
let spk_cond_mask = Tensor::cat(
&[
Tensor::ones((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
Tensor::zeros((1, 1, cfg.dim), candle::DType::F32, vb.device())?,
],
0,
)?;
Ok(Self {
tok_embeddings,
pos_embeddings,
speaker_cond_pos,
layers,
norm,
output,
spk_cond_mask,
span: tracing::span!(tracing::Level::TRACE, "qtransformer"),
})
}

pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()
}
}

pub fn forward(&mut self, xs: &Tensor, spk_emb: &Tensor, pos: usize) -> Result<Tensor> {
let _enter = self.span.enter();
let (_b_sz, seqlen) = xs.dims2()?;
let mask: Vec<_> = (0..seqlen)
.flat_map(|i| (0..seqlen).map(move |j| if i < j { f32::NEG_INFINITY } else { 0. }))
.collect();
let mask = Tensor::from_slice(&mask, (1, 1, seqlen, seqlen), xs.device())?;
let input_pos = Tensor::arange(pos as u32, (pos + seqlen) as u32, xs.device())?;
let tok_embeddings = xs.apply(&self.tok_embeddings)?;
let pos_embeddings = input_pos.apply(&self.pos_embeddings)?;
let mut xs = tok_embeddings
.broadcast_add(&pos_embeddings)?
.broadcast_add(
&spk_emb
.apply(&self.speaker_cond_pos)?
.broadcast_mul(&self.spk_cond_mask)?,
)?;
let mask = mask.to_dtype(xs.dtype())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, pos, &mask)?
}
xs.narrow(1, seqlen - 1, 1)?
.apply(&self.norm)?
.apply(&self.output)
}
}
}
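Editor's note: a worked example of the fused QKV projection in the deleted Attention above, with assumed config values (dim = 2048, n_head = 16, n_local_heads = 4, hence head_dim = 2048 / 16 = 128 and kv_size = 4 * 128 = 512): wqkv maps dim to (16 + 2 * 4) * 128 = 3072 outputs, which forward then narrows into q (first 2048 columns), k (next 512) and v (last 512); repeat_interleave then repeats each of the 4 K/V heads 16 / 4 = 4 times so they line up with the 16 query heads (grouped-query attention).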
@ -116,12 +116,6 @@ impl QMatMul {
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}

pub fn from_weights(ws: std::sync::Arc<candle::quantized::QTensor>) -> Result<Self> {
let inner = candle::quantized::QMatMul::from_arc(ws)?;
let span = tracing::span!(tracing::Level::TRACE, "qmatmul");
Ok(Self { inner, span })
}
}

impl Module for QMatMul {
@ -35,14 +35,6 @@ pub struct Linear {
}

impl Linear {
pub fn from_arc(
weight: std::sync::Arc<candle::quantized::QTensor>,
bias: Option<Tensor>,
) -> Result<Self> {
let weight = QMatMul::from_weights(weight)?;
Ok(Self { weight, bias })
}

pub fn from_weights(weight: QMatMul, bias: Option<Tensor>) -> Self {
Self { weight, bias }
}
@ -58,16 +50,6 @@ impl Module for Linear {
}
}

pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
let bias = if bias {
Some(vb.get(out_dim, "bias")?.dequantize(vb.device())?)
} else {
None
};
let weight = QMatMul::new(in_dim, out_dim, vb)?;
Ok(Linear { weight, bias })
}

pub fn linear(in_dim: usize, out_dim: usize, vb: VarBuilder) -> Result<Linear> {
let bias = vb.get(out_dim, "bias")?.dequantize(vb.device())?;
let weight = QMatMul::new(in_dim, out_dim, vb)?;

@ -3,7 +3,6 @@ use candle::{Device, Result, Shape};
use std::sync::Arc;

// VarBuilder specialized for QTensors
#[derive(Clone)]
pub struct VarBuilder {
data: Arc<std::collections::HashMap<String, Arc<QTensor>>>,
path: Vec<String>,
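Editor's note: pp ("push prefix") is how this VarBuilder is navigated: each call appends one segment to path, and lookups join the segments with dots, so vb.pp("layers").pp(0) addresses tensors named layers.0.* in the quantized checkpoint, which is exactly the lookup pattern used by Model::new in the deleted file above.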