mirror of https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00

Compare commits: cuda-conv-...precompile (1 commit)

Commit 5ac3302fac

19 Cargo.toml
@@ -19,7 +19,7 @@ exclude = [
resolver = "2"

[workspace.package]
version = "0.4.2"
version = "0.4.0"
edition = "2021"
description = "Minimalist ML framework."
repository = "https://github.com/huggingface/candle"
@@ -31,18 +31,17 @@ license = "MIT OR Apache-2.0"
accelerate-src = { version = "0.3.2" }
anyhow = { version = "1", features = ["backtrace"] }
byteorder = "1.4.3"
candle = { path = "./candle-core", package = "candle-core", version = "0.4.2" }
candle-datasets = { path = "./candle-datasets", version = "0.4.2" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.2" }
candle-kernels = { path = "./candle-kernels", version = "0.4.2" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.2" }
candle-nn = { path = "./candle-nn", version = "0.4.2" }
candle-onnx = { path = "./candle-onnx", version = "0.4.2" }
candle-transformers = { path = "./candle-transformers", version = "0.4.2" }
candle = { path = "./candle-core", package = "candle-core", version = "0.4.0" }
candle-datasets = { path = "./candle-datasets", version = "0.4.0" }
candle-flash-attn = { path = "./candle-flash-attn", version = "0.4.0" }
candle-kernels = { path = "./candle-kernels", version = "0.4.0" }
candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.4.0" }
candle-nn = { path = "./candle-nn", version = "0.4.0" }
candle-onnx = { path = "./candle-onnx", version = "0.4.0" }
candle-transformers = { path = "./candle-transformers", version = "0.4.0" }
clap = { version = "4.2.4", features = ["derive"] }
criterion = { version = "0.5.1", default-features=false }
cudarc = { version = "0.10.0", features = ["f16"] }
fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.3.0"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
26 README.md
@@ -63,8 +63,6 @@ We also provide a some command line based examples using state of the art models
- [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM, includes
  the SOLAR-10.7B variant.
- [Falcon](./candle-examples/examples/falcon/): general LLM.
- [Gemma](./candle-examples/examples/gemma/): 2b and 7b general LLMs from Google
  Deepmind.
- [Phi-1, Phi-1.5, and Phi-2](./candle-examples/examples/phi/): 1.3b and 2.7b general LLMs with performance on par with LLaMA-v2 7b.
- [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
  pre-trained on 1T tokens of English and code datasets. Also supports
@@ -76,10 +74,9 @@ We also provide a some command line based examples using state of the art models
- [Mixtral8x7b-v0.1](./candle-examples/examples/mixtral/): a sparse mixture of
  experts 8x7b general LLM with better performance than a Llama 2 70B model with
  much faster inference.
- [StarCoder](./candle-examples/examples/bigcode/) and
  [StarCoder2](./candle-examples/examples/starcoder2/): LLM specialized to code generation.
- [StarCoder](./candle-examples/examples/bigcode/): LLM specialized to code generation.
- [Qwen1.5](./candle-examples/examples/qwen/): Bilingual (English/Chinese) LLMs.
- [RWKV v5 and v6](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
- [RWKV v5](./candle-examples/examples/rwkv/): An RNN with transformer level LLM
  performance.
- [Replit-code-v1.5](./candle-examples/examples/replit-code/): a 3.3b LLM specialized for code completion.
- [Yi-6B / Yi-34B](./candle-examples/examples/yi/): two bilingual
@@ -110,12 +107,7 @@ We also provide a some command line based examples using state of the art models

<img src="https://github.com/huggingface/candle/raw/main/candle-examples/examples/segment-anything/assets/sam_merged.jpg" width="200">

- [SegFormer](./candle-examples/examples/segformer/): transformer based semantic segmantation model.
- [Whisper](./candle-examples/examples/whisper/): speech recognition model.
- [EnCodec](./candle-examples/examples/encodec/): high-quality audio compression
  model using residual vector quantization.
- [MetaVoice](./candle-examples/examples/metavoice/): foundational model for
  text-to-speech.
- [T5](./candle-examples/examples/t5), [Bert](./candle-examples/examples/bert/),
  [JinaBert](./candle-examples/examples/jina-bert/) : useful for sentence embeddings.
- [DINOv2](./candle-examples/examples/dinov2/): computer vision model trained
@@ -175,7 +167,6 @@ And then head over to
- [`kalosm`](https://github.com/floneum/floneum/tree/master/interfaces/kalosm): A multi-modal meta-framework in Rust for interfacing with local pre-trained models with support for controlled generation, custom samplers, in-memory vector databases, audio transcription, and more.
- [`candle-sampling`](https://github.com/EricLBuehler/candle-sampling): Sampling techniques for Candle.
- [`gpt-from-scratch-rs`](https://github.com/jeroenvlek/gpt-from-scratch-rs): A port of Andrej Karpathy's _Let's build GPT_ tutorial on YouTube showcasing the Candle API on a toy problem.
- [`candle-einops`](https://github.com/tomsanbear/candle-einops): A pure rust implementation of the python [einops](https://github.com/arogozhnikov/einops) library.

If you have an addition to this list, please submit a pull request.

@@ -196,10 +187,9 @@ If you have an addition to this list, please submit a pull request.
- Language Models.
  - LLaMA v1 and v2 with variants such as SOLAR-10.7B.
  - Falcon.
  - StarCoder, StarCoder2.
  - StarCoder.
  - Phi 1, 1.5, and 2.
  - Mamba, Minimal Mamba
  - Gemma 2b and 7b.
  - Mistral 7b v0.1.
  - Mixtral 8x7b v0.1.
  - StableLM-3B-4E1T, StableLM-2-1.6B, Stable-Code-3B.
@@ -207,7 +197,7 @@ If you have an addition to this list, please submit a pull request.
  - Bert.
  - Yi-6B and Yi-34B.
  - Qwen1.5.
  - RWKV v5 and v6.
  - RWKV.
- Quantized LLMs.
  - Llama 7b, 13b, 70b, as well as the chat and code variants.
  - Mistral 7b, and 7b instruct.
@@ -217,22 +207,18 @@ If you have an addition to this list, please submit a pull request.
- Text to text.
  - T5 and its variants: FlanT5, UL2, MADLAD400 (translation), CoEdit (Grammar correction).
  - Marian MT (Machine Translation).
  - Whisper (multi-lingual support).
- Text to image.
  - Stable Diffusion v1.5, v2.1, XL v1.0.
  - Wurstchen v2.
- Image to text.
  - BLIP.
  - TrOCR.
- Audio.
  - Whisper, multi-lingual speech-to-text.
  - EnCodec, audio compression model.
  - MetaVoice-1B, text-to-speech model.
- Computer Vision Models.
  - DINOv2, ConvMixer, EfficientNet, ResNet, ViT, VGG, RepVGG, ConvNeXT,
    ConvNeXTv2, MobileOne, EfficientVit (MSRA).
    ConvNeXTv2.
  - yolo-v3, yolo-v8.
  - Segment-Anything Model (SAM).
  - SegFormer.
- File formats: load models from safetensors, npz, ggml, or PyTorch files.
- Serverless (on CPU), small and fast deployments.
- Quantization support using the llama.cpp quantized types.
@@ -5,32 +5,25 @@ extern crate accelerate_src;
extern crate intel_mkl_src;

use anyhow::Result;
use candle_core::{Device, Module, Tensor};

use candle_core::quantized::{QMatMul, QTensor};
use candle_core::{Device, Tensor};

fn main() -> Result<()> {
    let device = Device::new_cuda(0)?;
    let q = Tensor::randn(0f32, 1.0, (72, 256), &device)?;
    let q_cpu = q.to_device(&Device::Cpu)?;
    let q = QTensor::quantize(&q, candle_core::quantized::GgmlDType::Q8K)?;
    let q = QMatMul::from_qtensor(q)?;
    let x = Tensor::randn(0f32, 1.0, (5, 256), &device)?;
    let res_q_cuda = q.forward(&x)?;
    println!("{res_q_cuda}");

    let q_cpu = QTensor::quantize(&q_cpu, candle_core::quantized::GgmlDType::Q8K)?;
    let q_cpu_tensor = q_cpu.dequantize(&Device::Cpu)?;
    let q_cpu = QMatMul::from_qtensor(q_cpu)?;
    let x_cpu = x.to_device(&Device::Cpu)?;
    let res_q_cpu = q_cpu.forward(&x_cpu)?;
    println!("{res_q_cpu}");

    let res_mm = x_cpu.matmul(&q_cpu_tensor.t()?)?;
    let diff = (res_mm - res_q_cuda.to_device(&Device::Cpu))?
        .abs()?
        .flatten_all()?
        .max(0)?;
    let in_t = Tensor::rand(-1f32, 1f32, (1, 3, 12, 7), &device)?;
    let k_t = Tensor::rand(-1f32, 1f32, (6, 3, 1, 1), &device)?;
    let out_t = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
    println!("{out_t}");
    let in_t = in_t.to_device(&Device::Cpu)?;
    let k_t = k_t.to_device(&Device::Cpu)?;
    let out_t2 = in_t.conv2d(&k_t, 0, 1, 1, 1)?;
    let diff = (out_t.to_device(&Device::Cpu)? - out_t2)?
        .sqr()?
        .sum_all()?;
    println!("{diff}");

    let t = Tensor::randn(0f32, 1f32, (2, 4, 96, 96), &device)?;
    let w = Tensor::randn(0f32, 1f32, (320, 4, 3, 3), &device)?;
    let res = t.conv2d(&w, 1, 1, 1, 1)?;
    println!("{res:?}");
    Ok(())
}
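Note: the example above validates the CUDA results against CPU references by printing the raw tensors and their difference. A minimal sketch of a helper that collapses such a comparison into a single scalar; the helper name l2_diff is not part of this change, it is only an illustration built from Tensor ops used above plus sqr/sum_all/sqrt/to_scalar:

use candle_core::{Device, Result, Tensor};

// Collapse an "op on CUDA" vs "same op on CPU" comparison into one number:
// move both tensors to the CPU and compute the L2 norm of their difference.
fn l2_diff(a: &Tensor, b: &Tensor) -> Result<f32> {
    let a = a.to_device(&Device::Cpu)?;
    let b = b.to_device(&Device::Cpu)?;
    (a - b)?.sqr()?.sum_all()?.sqrt()?.to_scalar::<f32>()
}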
@@ -98,19 +98,6 @@ pub trait BackendStorage: Sized {
    ) -> Result<Self>;

    fn copy_strided_src(&self, _: &mut Self, _: usize, _: &Layout) -> Result<()>;

    #[allow(clippy::too_many_arguments)]
    // Similar to cudaMemcpy2D, though values are in elements and not in bytes.
    fn copy2d(
        &self,
        _: &mut Self,
        _d1: usize,
        _d2: usize,
        _src_stride1: usize,
        _dst_stride1: usize,
        _src_offset: usize,
        _dst_offset: usize,
    ) -> Result<()>;
}

pub trait BackendDevice: Sized + std::fmt::Debug + Clone {
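Note: the comment above compares copy2d to cudaMemcpy2D: it copies d1 rows of d2 elements each, with independent source and destination row strides, and all strides and offsets are element counts rather than bytes. A self-contained sketch of those semantics on plain slices (it mirrors the copy2d_ helper from the CPU backend that appears later in this diff; illustration only):

// Copy a d1 x d2 block of elements, where consecutive rows of `src` and `dst`
// are `src_stride1` / `dst_stride1` elements apart and the block starts at the
// given element offsets.
fn copy2d<T: Copy>(
    src: &[T], dst: &mut [T],
    d1: usize, d2: usize,
    src_stride1: usize, dst_stride1: usize,
    src_offset: usize, dst_offset: usize,
) {
    for i1 in 0..d1 {
        let s = i1 * src_stride1 + src_offset;
        let d = i1 * dst_stride1 + dst_offset;
        dst[d..d + d2].copy_from_slice(&src[s..s + d2]);
    }
}

fn main() {
    let src: Vec<i32> = (0..20).collect(); // 4 x 5 matrix, row stride 5
    let mut dst = vec![0i32; 6]; // 2 x 3 destination, row stride 3
    // Take the 2 x 3 block whose top-left corner is at row 1, column 2.
    copy2d(&src, &mut dst, 2, 3, 5, 3, 1 * 5 + 2, 0);
    assert_eq!(dst, vec![7, 8, 9, 12, 13, 14]);
}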
@@ -113,7 +113,7 @@ impl Tensor {
            | Op::Unary(_node, UnaryOp::Floor)
            | Op::Unary(_node, UnaryOp::Round) => nodes,
            Op::Reshape(node)
            | Op::UpsampleNearest1D { arg: node, .. }
            | Op::UpsampleNearest1D(node)
            | Op::UpsampleNearest2D { arg: node, .. }
            | Op::AvgPool2D { arg: node, .. }
            | Op::MaxPool2D { arg: node, .. }
@@ -250,7 +250,6 @@ impl Tensor {
                        out_padding,
                        *stride,
                        *dilation,
                        /* groups */ 1,
                    )?;
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&grad_arg)?;
@@ -348,18 +347,9 @@ impl Tensor {
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = sum_grad.add(&grad_arg)?;
                }
                Op::UpsampleNearest1D { arg, target_size } => {
                    let (_n, c, size) = arg.dims3()?;
                    if target_size % size != 0 {
                        crate::bail!("backward not supported for non integer upscaling factors")
                    }
                    let scale = target_size / size;

                    let kernel = Tensor::ones((c, 1, scale), arg.dtype(), arg.device())?;
                    let conv_sum = grad.conv1d(&kernel, 0, scale, 1, c)?;
                    let sum_grad = grads.or_insert(arg)?;
                    *sum_grad = conv_sum;
                }
                Op::UpsampleNearest1D { .. } => Err(Error::BackwardNotSupported {
                    op: "upsample-nearest1d",
                })?,
                Op::UpsampleNearest2D {
                    arg,
                    target_h,
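Note: the UpsampleNearest1D backward pass above relies on a simple observation: nearest-neighbor upsampling by an integer factor copies each input element scale times along the length, so the gradient of an input element is the sum of the scale output gradients it produced. A depthwise conv1d (groups = c) with a kernel of ones, stride = scale and no padding computes exactly that per-channel windowed sum, which is what the grad.conv1d(&kernel, 0, scale, 1, c) call does. A tiny single-channel sketch of the same reduction on plain vectors (illustration only):

// Gradient of nearest-neighbor 1d upsampling for a single channel: each input
// position receives the sum of the `scale` output gradients copied from it.
fn upsample1d_grad(dy: &[f32], scale: usize) -> Vec<f32> {
    dy.chunks(scale).map(|w| w.iter().sum()).collect()
}

fn main() {
    // Output gradient of length 6 coming from an input of length 2 upsampled by 3.
    let dy = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(upsample1d_grad(&dy, 3), vec![6.0, 15.0]);
}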
@@ -187,16 +187,36 @@ impl Tensor {
        }
    }

    fn conv_transpose1d_single_group(
    /// Applies a 1D transposed convolution over the input tensor.
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        params: &ParamsConvTranspose1D,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
    ) -> Result<Self> {
        let (b_size, c_in, l_in) = self.dims3()?;
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            c_out,
            c_in,
            padding,
            output_padding,
            stride,
            dilation,
        };
        let storage = self.storage().conv_transpose1d(
            self.layout(),
            &kernel.storage(),
            kernel.layout(),
            params,
            &params,
        )?;
        let op = BackpropOp::new2(self, kernel, |arg, kernel| Op::ConvTranspose1D {
            arg,
@@ -210,49 +230,6 @@ impl Tensor {
        Ok(crate::tensor::from_storage(storage, out_dims, op, false))
    }

    /// Applies a 1D transposed convolution over the input tensor.
    pub fn conv_transpose1d(
        &self,
        kernel: &Self,
        padding: usize,
        output_padding: usize,
        stride: usize,
        dilation: usize,
        groups: usize,
    ) -> Result<Self> {
        let (c_in_k, c_out, k_size) = kernel.dims3()?;
        let (b_size, c_in, l_in) = self.dims3()?;
        if c_in != c_in_k {
            crate::bail!("in_channel mismatch between input ({c_in}) and kernel ({c_in_k})")
        }
        if c_in % groups != 0 {
            crate::bail!("in_channel {c_in} is not divisible by the number of groups")
        }
        let params = ParamsConvTranspose1D {
            b_size,
            l_in,
            k_size,
            c_out,
            c_in: c_in / groups,
            padding,
            output_padding,
            stride,
            dilation,
        };
        if groups == 1 {
            self.conv_transpose1d_single_group(kernel, &params)
        } else {
            let blocks = self.chunk(groups, 1)?;
            let kernel = kernel.chunk(groups, 0)?;
            let blocks = blocks
                .iter()
                .zip(&kernel)
                .map(|(block, kernel)| block.conv_transpose1d_single_group(kernel, &params))
                .collect::<Result<Vec<_>>>()?;
            Tensor::cat(&blocks, 1)
        }
    }

    fn conv2d_single_group(&self, kernel: &Self, params: &ParamsConv2D) -> Result<Self> {
        let storage =
            self.storage()
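Note: the grouped path above splits the input channels into groups blocks (chunk on dim 1), splits the kernel the same way (chunk on dim 0), runs a single-group transposed convolution per pair, and concatenates the results on the channel dimension. A small usage sketch, assuming the seven-argument signature shown above (kernel, padding, output_padding, stride, dilation, groups); the shapes and the expected output are illustrative:

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Input: batch 1, 4 channels, length 5.
    let x = Tensor::randn(0f32, 1.0, (1, 4, 5), &dev)?;
    // Transposed-conv kernel: (c_in, c_out per group, k_size).
    let k = Tensor::randn(0f32, 1.0, (4, 3, 2), &dev)?;
    // padding = 0, output_padding = 0, stride = 2, dilation = 1, groups = 2.
    // The 4 input channels are split into 2 blocks of 2; each block is convolved
    // with its kernel chunk and the per-group outputs are concatenated on dim 1,
    // so the result should have 2 * 3 = 6 channels.
    let y = x.conv_transpose1d(&k, 0, 0, 2, 1, 2)?;
    println!("{:?}", y.dims()); // expected: [1, 6, 10]
    Ok(())
}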
@@ -5,7 +5,6 @@ use half::{bf16, f16};
use rayon::prelude::*;

const USE_IM2COL_CONV1D: bool = true;
const USE_IM2COL_CONV1D_TR: bool = true;
const USE_IM2COL_CONV2D: bool = true;

// TODO: Maybe we should not implement [Clone] here and instead have an explicit allocator +
@@ -1023,26 +1022,6 @@ impl<'a, I: IntDType> Map2 for IndexAdd<'a, I> {
    }
}

#[allow(clippy::too_many_arguments)]
fn copy2d_<T: Copy>(
    src: &[T],
    dst: &mut [T],
    d1: usize,
    d2: usize,
    src_stride1: usize,
    dst_stride1: usize,
    src_offset: usize,
    dst_offset: usize,
) {
    for i1 in 0..d1 {
        let dst_idx = i1 * dst_stride1 + dst_offset;
        let src_idx = i1 * src_stride1 + src_offset;
        let dst = &mut dst[dst_idx..dst_idx + d2];
        let src = &src[src_idx..src_idx + d2];
        dst.copy_from_slice(src)
    }
}

fn copy_strided_src_<T: Copy>(src: &[T], dst: &mut [T], dst_offset: usize, src_l: &Layout) {
    match src_l.strided_blocks() {
        crate::StridedBlocks::SingleBlock { start_offset, len } => {
@@ -1277,34 +1256,6 @@ impl Map1 for Im2Col {
    }
}

struct Col2Im1D {
    stride: usize,
}

impl Map1 for Col2Im1D {
    fn f<T: WithDType>(&self, col: &[T], l: &Layout) -> Result<Vec<T>> {
        let (b_size, l_in, c_out, k_size) = l.shape().dims4()?;
        let stride = self.stride;
        let l_out = (l_in - 1) * stride + k_size;
        let mut im = vec![T::zero(); b_size * c_out * l_out];
        let (dst_s0, dst_s1) = (c_out * l_out, l_out);
        let (src_s0, src_s1, src_s2) = (c_out * k_size * l_in, c_out * k_size, k_size);
        for l_in_i in 0..l_in {
            for k_i in 0..k_size {
                let l_out_i = l_in_i * stride + k_i;
                for b_i in 0..b_size {
                    for c_i in 0..c_out {
                        let dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_out_i;
                        let src_idx = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
                        im[dst_idx] += col[src_idx]
                    }
                }
            }
        }
        Ok(im)
    }
}

struct ConvTranspose1D<'a>(&'a crate::conv::ParamsConvTranspose1D);

impl<'a> Map2 for ConvTranspose1D<'a> {
@@ -1312,7 +1263,6 @@ impl<'a> Map2 for ConvTranspose1D<'a> {
    fn f<T: WithDType>(&self, inp: &[T], inp_l: &Layout, k: &[T], k_l: &Layout) -> Result<Vec<T>> {
        let p = self.0;
        let inp = &inp[inp_l.start_offset()..];
        let k = &k[k_l.start_offset()..];
        let (inp_s0, inp_s1, inp_s2) = crate::shape::dims3(inp_l.stride())?;
        let (k_s0, k_s1, k_s2) = crate::shape::dims3(k_l.stride())?;
        let l_out = p.l_out();
@@ -2472,48 +2422,6 @@ impl BackendStorage for CpuStorage {
        }
    }

    fn copy2d(
        &self,
        dst: &mut Self,
        d1: usize,
        d2: usize,
        src_s: usize,
        dst_s: usize,
        src_o: usize,
        dst_o: usize,
    ) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o),
            (Self::U32(src), Self::U32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::I64(src), Self::I64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::BF16(src), Self::BF16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F16(src), Self::F16(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F32(src), Self::F32(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (Self::F64(src), Self::F64(dst)) => {
                copy2d_(src, dst, d1, d2, src_s, dst_s, src_o, dst_o)
            }
            (_, dst) => {
                return Err(Error::DTypeMismatchBinaryOp {
                    lhs: self.dtype(),
                    rhs: dst.dtype(),
                    op: "copy2d",
                }
                .bt());
            }
        }
        Ok(())
    }

    fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
        match (self, dst) {
            (Self::U8(src), Self::U8(dst)) => copy_strided_src_(src, dst, dst_offset, src_l),
@@ -2602,52 +2510,7 @@ impl BackendStorage for CpuStorage {
        kernel_l: &Layout,
        params: &crate::conv::ParamsConvTranspose1D,
    ) -> Result<Self> {
        let can_use_col2im = kernel_l.is_contiguous()
            && params.dilation == 1
            && params.padding == 0
            && params.output_padding == 0;
        if USE_IM2COL_CONV1D_TR && can_use_col2im {
            let (b_size, c_in, l_in) = l.shape().dims3()?;
            let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
            if !kernel_l.is_contiguous() {
                crate::bail!(
                    "convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}"
                )
            }
            if c_in != c_in2 {
                crate::bail!(
                    "convtr1d: shape mismatch on c_in {:?} {:?}",
                    l.shape(),
                    kernel_l.shape()
                )
            }
            let col = {
                // This merges the last two dimensions of the kernel together.
                let kernel_l_mm = Layout::new(
                    (b_size, c_in, k_size * c_out).into(),
                    vec![0, k_size * c_out, 1],
                    kernel_l.start_offset(),
                );
                self.matmul(
                    kernel,
                    (
                        b_size,
                        /* m */ l_in,
                        /* n */ c_out * k_size,
                        /* k */ c_in,
                    ),
                    &l.transpose(1, 2)?,
                    &kernel_l_mm,
                )?
            };
            let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
            Col2Im1D {
                stride: params.stride,
            }
            .map(&col, &col_l)
        } else {
            ConvTranspose1D(params).map(self, l, kernel, kernel_l)
        }
        ConvTranspose1D(params).map(self, l, kernel, kernel_l)
    }

    fn conv2d(
@@ -2711,7 +2574,7 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => IndexSelect { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "index-select")),
        }
    }

@@ -2720,7 +2583,7 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::U32(ids) => Gather { ids, ids_l, dim }.map(self, l),
            Self::I64(ids) => Gather { ids, ids_l, dim }.map(self, l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "gather")),
        }
    }

@@ -2737,7 +2600,7 @@ impl BackendStorage for CpuStorage {
            Self::U8(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::U32(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            Self::I64(ids) => ScatterAdd { ids, ids_l, dim }.map(self, l, src, src_l),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add").bt()),
            _ => Err(Error::UnsupportedDTypeForOp(self.dtype(), "scatter-add")),
        }
    }
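Note: the CpuStorage::conv_transpose1d path above is the core of this change. When the kernel is contiguous and dilation/padding/output_padding are trivial, the transposed convolution is computed as one batched matmul (producing a (b_size, l_in, c_out, k_size) "column" tensor) followed by Col2Im1D, which scatters each column entry to output position l_in_i * stride + k_i and accumulates overlaps, with l_out = (l_in - 1) * stride + k_size. A standalone sketch of that scatter for a single batch and channel, mirroring the Col2Im1D indexing above (illustration only):

// col has shape (l_in, k_size) flattened row-major: col[l_in_i][k_i] is the
// contribution of input position l_in_i through kernel tap k_i.
fn col2im1d(col: &[f32], l_in: usize, k_size: usize, stride: usize) -> Vec<f32> {
    let l_out = (l_in - 1) * stride + k_size;
    let mut im = vec![0f32; l_out];
    for l_in_i in 0..l_in {
        for k_i in 0..k_size {
            // Overlapping windows accumulate, which is what makes this a
            // transposed (rather than ordinary) convolution.
            im[l_in_i * stride + k_i] += col[l_in_i * k_size + k_i];
        }
    }
    im
}

fn main() {
    // l_in = 3, k_size = 2, stride = 1: output length is (3 - 1) * 1 + 2 = 4.
    let col = [1.0f32, 10.0, 2.0, 20.0, 3.0, 30.0];
    assert_eq!(col2im1d(&col, 3, 2, 1), vec![1.0, 12.0, 23.0, 30.0]);
}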
|
||||
|
@ -608,34 +608,6 @@ impl Map1 for Elu {
|
||||
}
|
||||
}
|
||||
|
||||
struct Col2Im1D {
|
||||
stride: usize,
|
||||
}
|
||||
|
||||
impl Map1 for Col2Im1D {
|
||||
fn f<T: DeviceRepr + WithDType>(
|
||||
&self,
|
||||
src: &CudaSlice<T>,
|
||||
dev: &CudaDevice,
|
||||
layout: &Layout,
|
||||
) -> Result<CudaSlice<T>> {
|
||||
let (b_size, l_in, c_out, k_size) = layout.shape().dims4()?;
|
||||
let stride = self.stride;
|
||||
let l_out = (l_in - 1) * stride + k_size;
|
||||
|
||||
let dst_el = b_size * c_out * l_out;
|
||||
let cfg = LaunchConfig::for_num_elems(dst_el as u32);
|
||||
let src = &src.slice(layout.start_offset()..);
|
||||
let func = dev.get_or_load_func(&kernel_name::<T>("col2im1d"), kernels::CONV)?;
|
||||
// SAFETY: Set later by running the kernel.
|
||||
let dst = unsafe { dev.alloc::<T>(dst_el) }.w()?;
|
||||
let params = (l_in, l_out, c_out, k_size, b_size, stride, src, &dst);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(dst)
|
||||
}
|
||||
}
|
||||
|
||||
struct Im2Col1D {
|
||||
l_k: usize,
|
||||
stride: usize,
|
||||
@ -1893,54 +1865,8 @@ impl BackendStorage for CudaStorage {
|
||||
params: &crate::conv::ParamsConvTranspose1D,
|
||||
) -> Result<Self> {
|
||||
let device = self.device().clone();
|
||||
const USE_COL2IM_CONV1D_TR: bool = true;
|
||||
|
||||
let can_use_col2im = kernel_l.is_contiguous()
|
||||
&& params.dilation == 1
|
||||
&& params.padding == 0
|
||||
&& params.output_padding == 0;
|
||||
if !can_use_col2im || !USE_COL2IM_CONV1D_TR {
|
||||
let slice =
|
||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
return Ok(Self { slice, device });
|
||||
}
|
||||
|
||||
let (b_size, c_in, l_in) = l.shape().dims3()?;
|
||||
let (c_in2, c_out, k_size) = kernel_l.shape().dims3()?;
|
||||
if !kernel_l.is_contiguous() {
|
||||
crate::bail!("convtr1d: the second argument (kernel) has to be contiguous {kernel_l:?}")
|
||||
}
|
||||
if c_in != c_in2 {
|
||||
crate::bail!(
|
||||
"convtr1d: shape mismatch on c_in {:?} {:?}",
|
||||
l.shape(),
|
||||
kernel_l.shape()
|
||||
)
|
||||
}
|
||||
let col = {
|
||||
// This merges the last two dimensions of the kernel together.
|
||||
let kernel_l_mm = Layout::new(
|
||||
(b_size, c_in, k_size * c_out).into(),
|
||||
vec![0, k_size * c_out, 1],
|
||||
kernel_l.start_offset(),
|
||||
);
|
||||
self.matmul(
|
||||
kernel,
|
||||
(
|
||||
b_size,
|
||||
/* m */ l_in,
|
||||
/* n */ c_out * k_size,
|
||||
/* k */ c_in,
|
||||
),
|
||||
&l.transpose(1, 2)?,
|
||||
&kernel_l_mm,
|
||||
)?
|
||||
};
|
||||
let col_l = Layout::contiguous((b_size, l_in, c_out, k_size));
|
||||
let slice = Col2Im1D {
|
||||
stride: params.stride,
|
||||
}
|
||||
.map(&col.slice, &device, &col_l)?;
|
||||
let slice =
|
||||
ConvTranspose1D(params).map(&self.slice, l, &kernel.slice, kernel_l, &device)?;
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
@ -2219,67 +2145,6 @@ impl BackendStorage for CudaStorage {
|
||||
Ok(Self { slice, device })
|
||||
}
|
||||
|
||||
fn copy2d(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o: usize,
|
||||
dst_o: usize,
|
||||
) -> Result<()> {
|
||||
let dev = &self.device;
|
||||
let d1 = d1 as u32;
|
||||
let d2 = d2 as u32;
|
||||
let dst_s = dst_s as u32;
|
||||
let src_s = src_s as u32;
|
||||
let (src, dst, kname) = match (&self.slice, &mut dst.slice) {
|
||||
(S::U8(s), S::U8(d)) => (
|
||||
*s.slice(src_o..).device_ptr(),
|
||||
*d.slice(dst_o..).device_ptr(),
|
||||
"copy2d_u8",
|
||||
),
|
||||
(S::U32(s), S::U32(d)) => (
|
||||
*s.slice(src_o..).device_ptr(),
|
||||
*d.slice(dst_o..).device_ptr(),
|
||||
"copy2d_u32",
|
||||
),
|
||||
(S::I64(s), S::I64(d)) => (
|
||||
*s.slice(src_o..).device_ptr(),
|
||||
*d.slice(dst_o..).device_ptr(),
|
||||
"copy2d_i64",
|
||||
),
|
||||
(S::BF16(s), S::BF16(d)) => (
|
||||
*s.slice(src_o..).device_ptr(),
|
||||
*d.slice(dst_o..).device_ptr(),
|
||||
"copy2d_bf16",
|
||||
),
|
||||
(S::F16(s), S::F16(d)) => (
|
||||
*s.slice(src_o..).device_ptr(),
|
||||
*d.slice(dst_o..).device_ptr(),
|
||||
"copy2d_f16",
|
||||
),
|
||||
(S::F32(s), S::F32(d)) => (
|
||||
*s.slice(src_o..).device_ptr(),
|
||||
*d.slice(dst_o..).device_ptr(),
|
||||
"copy2d_f32",
|
||||
),
|
||||
(S::F64(s), S::F64(d)) => (
|
||||
*s.slice(src_o..).device_ptr(),
|
||||
*d.slice(dst_o..).device_ptr(),
|
||||
"copy2d_f64",
|
||||
),
|
||||
_ => Err(CudaError::InternalError("dtype mismatch in copy2d"))?,
|
||||
};
|
||||
let func = dev.get_or_load_func(kname, kernels::FILL)?;
|
||||
let cfg = LaunchConfig::for_num_elems(d1 * d2);
|
||||
let params = (src, dst, d1, d2, src_s, dst_s);
|
||||
// SAFETY: ffi.
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
let src_shape = src_l.shape();
|
||||
let dims = src_shape.dims();
|
||||
|
@@ -65,13 +65,12 @@ impl std::fmt::Debug for Tensor {
}

/// Options for Tensor pretty printing
#[derive(Debug, Clone)]
pub struct PrinterOptions {
    pub precision: usize,
    pub threshold: usize,
    pub edge_items: usize,
    pub line_width: usize,
    pub sci_mode: Option<bool>,
    precision: usize,
    threshold: usize,
    edge_items: usize,
    line_width: usize,
    sci_mode: Option<bool>,
}

static PRINT_OPTS: std::sync::Mutex<PrinterOptions> =
@@ -90,10 +89,6 @@ impl PrinterOptions {
    }
}

pub fn print_options() -> &'static std::sync::Mutex<PrinterOptions> {
    &PRINT_OPTS
}

pub fn set_print_options(options: PrinterOptions) {
    *PRINT_OPTS.lock().unwrap() = options
}
@@ -122,26 +117,6 @@ pub fn set_print_options_full() {
    }
}

pub fn set_line_width(line_width: usize) {
    PRINT_OPTS.lock().unwrap().line_width = line_width
}

pub fn set_precision(precision: usize) {
    PRINT_OPTS.lock().unwrap().precision = precision
}

pub fn set_edge_items(edge_items: usize) {
    PRINT_OPTS.lock().unwrap().edge_items = edge_items
}

pub fn set_threshold(threshold: usize) {
    PRINT_OPTS.lock().unwrap().threshold = threshold
}

pub fn set_sci_mode(sci_mode: Option<bool>) {
    PRINT_OPTS.lock().unwrap().sci_mode = sci_mode
}

struct FmtSize {
    current_size: usize,
}
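Note: the hunk above exposes the printing knobs: the PrinterOptions fields become public, print_options() returns the global PRINT_OPTS mutex, and dedicated setters (set_line_width, set_precision, set_edge_items, set_threshold, set_sci_mode) adjust individual fields. A self-contained miniature of the same pattern (a global, mutex-protected options value mutated through setter functions); this is not candle's actual API surface, and the default values here are arbitrary:

use std::sync::Mutex;

#[derive(Debug, Clone)]
struct PrinterOptions {
    precision: usize,
    line_width: usize,
}

// Same shape as PRINT_OPTS above: one global options value behind a Mutex.
static PRINT_OPTS: Mutex<PrinterOptions> =
    Mutex::new(PrinterOptions { precision: 4, line_width: 80 });

fn set_precision(precision: usize) {
    PRINT_OPTS.lock().unwrap().precision = precision
}

fn main() {
    set_precision(8);
    println!("{:?}", PRINT_OPTS.lock().unwrap());
}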
@@ -23,15 +23,7 @@ pub enum DType {
}

#[derive(Debug, PartialEq, Eq)]
pub struct DTypeParseError(String);

impl std::fmt::Display for DTypeParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cannot parse '{}' as a dtype", self.0)
    }
}

impl std::error::Error for DTypeParseError {}
pub struct DTypeParseError;

impl std::str::FromStr for DType {
    type Err = DTypeParseError;
@@ -44,7 +36,7 @@ impl std::str::FromStr for DType {
            "f16" => Ok(Self::F16),
            "f32" => Ok(Self::F32),
            "f64" => Ok(Self::F64),
            _ => Err(DTypeParseError(s.to_string())),
            _ => Err(DTypeParseError),
        }
    }
}
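Note: this hunk is about error ergonomics: DTypeParseError gains the offending string, a Display impl, and std::error::Error, so a failed parse can report what it could not parse instead of being an opaque unit struct. Usage follows the standard FromStr pattern; a short sketch assuming the richer error variant shown above:

use std::str::FromStr;
use candle_core::DType;

fn main() {
    // Parsing a known name succeeds...
    assert_eq!(DType::from_str("f32").ok(), Some(DType::F32));
    // ...and an unknown one produces an error that implements Display,
    // e.g. "cannot parse 'f33' as a dtype".
    if let Err(e) = DType::from_str("f33") {
        println!("{e}");
    }
}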
@@ -154,19 +154,6 @@ impl crate::backend::BackendStorage for CudaStorage {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn copy2d(
        &self,
        _: &mut Self,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
    ) -> Result<()> {
        Err(Error::NotCompiledWithCudaSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithCudaSupport)
    }
@@ -166,19 +166,6 @@ impl crate::backend::BackendStorage for MetalStorage {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn copy2d(
        &self,
        _: &mut Self,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
        _: usize,
    ) -> Result<()> {
        Err(Error::NotCompiledWithMetalSupport)
    }

    fn avg_pool2d(&self, _: &Layout, _: (usize, usize), _: (usize, usize)) -> Result<Self> {
        Err(Error::NotCompiledWithMetalSupport)
    }
@@ -70,7 +70,7 @@ impl Layout {
        self.shape.is_fortran_contiguous(&self.stride)
    }

    pub fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
    pub(crate) fn narrow(&self, dim: usize, start: usize, len: usize) -> Result<Self> {
        let dims = self.shape().dims();
        if dim >= dims.len() {
            Err(Error::DimOutOfRange {
@@ -99,7 +99,7 @@ impl Layout {
        })
    }

    pub fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
    pub(crate) fn transpose(&self, dim1: usize, dim2: usize) -> Result<Self> {
        let rank = self.shape.rank();
        if rank <= dim1 || rank <= dim2 {
            Err(Error::UnexpectedNumberOfDims {
@@ -120,7 +120,7 @@ impl Layout {
        })
    }

    pub fn permute(&self, idxs: &[usize]) -> Result<Self> {
    pub(crate) fn permute(&self, idxs: &[usize]) -> Result<Self> {
        let is_permutation =
            idxs.len() == self.shape.rank() && (0..idxs.len()).all(|i| idxs.contains(&i));
        if !is_permutation {
@@ -67,7 +67,6 @@ pub mod shape;
mod storage;
mod strided_index;
mod tensor;
mod tensor_cat;
pub mod test_utils;
pub mod utils;
mod variable;
@@ -130,15 +129,6 @@ impl<T: Fn(&Tensor) -> Result<Tensor>> Module for T {
    }
}

impl<M: Module> Module for Option<&M> {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        match self {
            None => Ok(xs.clone()),
            Some(m) => m.forward(xs),
        }
    }
}

// A trait defining a module with forward method using a single tensor argument and a flag to
// separate the training and evaluation behaviors.
pub trait ModuleT {
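Note: the blanket impl above makes Option<&M> itself a Module: None behaves as the identity and Some(m) forwards to m, so optional sub-layers can be called uniformly without matching at every call site. A small sketch; the "layer" here is a plain function, which is a Module thanks to the impl for Fn(&Tensor) -> Result<Tensor> shown in the same hunk context:

use candle_core::{DType, Device, Module, Result, Tensor};

// A trivial "layer": double the input.
fn double(t: &Tensor) -> Result<Tensor> {
    t + t
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let x = Tensor::ones((2, 3), DType::F32, &dev)?;

    // Function pointers of this shape implement Module, so Option<&M> does too.
    let f: fn(&Tensor) -> Result<Tensor> = double;
    let with_layer: Option<&fn(&Tensor) -> Result<Tensor>> = Some(&f);
    let without_layer: Option<&fn(&Tensor) -> Result<Tensor>> = None;

    println!("{}", with_layer.forward(&x)?); // doubled
    println!("{}", without_layer.forward(&x)?); // unchanged (identity)
    Ok(())
}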
@ -9,7 +9,7 @@ use metal::{Buffer, CommandBuffer, CommandQueue, MTLResourceOptions, NSUInteger}
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, Mutex, RwLock, RwLockWriteGuard, TryLockError};
|
||||
use std::sync::{Arc, Mutex, RwLock, TryLockError};
|
||||
|
||||
/// Simple way to catch lock error without
|
||||
/// depending on T
|
||||
@ -60,8 +60,7 @@ impl From<String> for MetalError {
|
||||
}
|
||||
}
|
||||
|
||||
type BufferMap = HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>;
|
||||
type AllocatedBuffers = Arc<RwLock<BufferMap>>;
|
||||
type AllocatedBuffers = Arc<RwLock<HashMap<(NSUInteger, MTLResourceOptions), Vec<Arc<Buffer>>>>>;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct MetalDevice {
|
||||
@ -69,7 +68,7 @@ pub struct MetalDevice {
|
||||
device: metal::Device,
|
||||
|
||||
/// Single command queue for the entire device.
|
||||
command_queue: CommandQueue,
|
||||
command_queue: metal::CommandQueue,
|
||||
/// One command buffer at a time.
|
||||
/// The scheduler works by allowing multiple
|
||||
/// [ComputeCommandEncoder](https://developer.apple.com/documentation/metal/mtlcomputecommandencoder?language=objc)
|
||||
@ -79,7 +78,7 @@ pub struct MetalDevice {
|
||||
/// Despite what the documentation says, command buffers are NOT ordered. They are ordered
|
||||
/// for their START time, but there's no guarantee that command buffer1 will finish before
|
||||
/// command buffer2 starts (or there are metal bugs there)
|
||||
command_buffer: Arc<RwLock<CommandBuffer>>,
|
||||
command_buffer: Arc<RwLock<metal::CommandBuffer>>,
|
||||
/// Keeps track of the current amount of compute command encoders on the current
|
||||
/// command buffer
|
||||
/// Arc, RwLock because of the interior mutability.
|
||||
@ -88,7 +87,7 @@ pub struct MetalDevice {
|
||||
compute_per_buffer: usize,
|
||||
/// Simple keeper struct to keep track of the already compiled kernels so we can reuse them.
|
||||
/// Heavily used by [`candle_metal_kernels`]
|
||||
kernels: Arc<Kernels>,
|
||||
kernels: Arc<candle_metal_kernels::Kernels>,
|
||||
/// Simple allocator struct.
|
||||
/// The buffers are stored in size buckets since ML tends to use similar shapes over and over.
|
||||
/// We store the buffers in [`Arc`] because it's much faster than Obj-c internal ref counting
|
||||
@ -100,7 +99,7 @@ pub struct MetalDevice {
|
||||
/// operation, so that this buffer is not being used by another kernel at the same time.
|
||||
/// Arc is the CPU reference count, it doesn't mean anything on the GPU side of things.
|
||||
///
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to clean up unused buffers
|
||||
/// Whenever we actually allocate a new buffer, we make a full sweep to cleanup unused buffers
|
||||
/// (strong_count = 1).
|
||||
buffers: AllocatedBuffers,
|
||||
/// Seed for random number generation.
|
||||
@ -146,8 +145,6 @@ impl MetalDevice {
|
||||
command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
*command_buffer_lock = command_buffer.clone();
|
||||
*index = 0;
|
||||
|
||||
self.drop_unused_buffers()?;
|
||||
}
|
||||
*index += 1;
|
||||
Ok(command_buffer)
|
||||
@ -166,7 +163,6 @@ impl MetalDevice {
|
||||
command_buffer.commit();
|
||||
command_buffer.wait_until_completed();
|
||||
*command_buffer = self.command_queue.new_command_buffer().to_owned();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -203,25 +199,39 @@ impl MetalDevice {
|
||||
}
|
||||
|
||||
/// Creates a new buffer from data.
|
||||
/// The buffer is [MTLManaged](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
/// The buffer is [MTLPrivate](https://developer.apple.com/documentation/metal/mtlstoragemode)
|
||||
///
|
||||
/// Does not require synchronization, as [newBufferWithBytes](https://developer.apple.com/documentation/metal/mtldevice/1433429-newbufferwithbytes)
|
||||
/// allocates the buffer and copies over the existing data before returning the MTLBuffer.
|
||||
/// This method will block the computation because of the
|
||||
/// lack of lifetime management through the GPU.
|
||||
/// Internal comment for technical details.
|
||||
pub fn new_buffer_with_data<T>(&self, data: &[T]) -> Result<Arc<Buffer>> {
|
||||
let size = core::mem::size_of_val(data) as NSUInteger;
|
||||
let new_buffer = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const c_void,
|
||||
let tmp = self.device.new_buffer_with_data(
|
||||
data.as_ptr() as *const core::ffi::c_void,
|
||||
size,
|
||||
MTLResourceOptions::StorageModeManaged,
|
||||
metal::MTLResourceOptions::StorageModeManaged,
|
||||
);
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
let subbuffers = buffers
|
||||
.entry((size, MTLResourceOptions::StorageModeManaged))
|
||||
.or_insert(vec![]);
|
||||
let real = self.allocate_buffer(
|
||||
size,
|
||||
metal::MTLResourceOptions::StorageModePrivate,
|
||||
"with_data",
|
||||
)?;
|
||||
let command_buffer = self.command_buffer()?;
|
||||
command_buffer.set_label("with_data");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("with_data_blit");
|
||||
blit.copy_from_buffer(&tmp, 0, &real, 0, tmp.length());
|
||||
blit.end_encoding();
|
||||
|
||||
let new_buffer = Arc::new(new_buffer);
|
||||
subbuffers.push(new_buffer.clone());
|
||||
Ok(new_buffer)
|
||||
// This is necessary, for mmaped safetensors
|
||||
// Because of the unsafe slice cast we're doing.
|
||||
// The slice might not live long enough for metal
|
||||
// To actually fill the GPU buffer.
|
||||
// Putting this wait forces the GPU buffer to be filled
|
||||
// with the actual data allowing the CPU storage to do
|
||||
// deallocate properly.
|
||||
self.wait_until_completed()?;
|
||||
Ok(real)
|
||||
}
|
||||
|
||||
pub fn allocate_zeros(&self, size_in_bytes: usize) -> Result<Arc<Buffer>> {
|
||||
@ -245,40 +255,6 @@ impl MetalDevice {
|
||||
Ok(buffer)
|
||||
}
|
||||
|
||||
fn find_available_buffer(
|
||||
&self,
|
||||
size: NSUInteger,
|
||||
option: MTLResourceOptions,
|
||||
buffers: &RwLockWriteGuard<BufferMap>,
|
||||
) -> Option<Arc<Buffer>> {
|
||||
let mut best_buffer: Option<&Arc<Buffer>> = None;
|
||||
let mut best_buffer_size: NSUInteger = NSUInteger::MAX;
|
||||
for ((buffer_size, buffer_option), subbuffers) in buffers.iter() {
|
||||
if buffer_size >= &size && buffer_size < &best_buffer_size && buffer_option == &option {
|
||||
for sub in subbuffers {
|
||||
if Arc::strong_count(sub) == 1 {
|
||||
best_buffer = Some(sub);
|
||||
best_buffer_size = *buffer_size;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return best_buffer.map(|b| b.clone());
|
||||
}
|
||||
|
||||
fn drop_unused_buffers(&self) -> Result<()> {
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
for subbuffers in buffers.values_mut() {
|
||||
let newbuffers = subbuffers
|
||||
.iter()
|
||||
.filter(|s| Arc::strong_count(*s) > 1)
|
||||
.map(Arc::clone)
|
||||
.collect();
|
||||
*subbuffers = newbuffers;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// The critical allocator algorithm
|
||||
fn allocate_buffer(
|
||||
&self,
|
||||
@ -287,18 +263,24 @@ impl MetalDevice {
|
||||
_name: &str,
|
||||
) -> Result<Arc<Buffer>> {
|
||||
let mut buffers = self.buffers.try_write().map_err(MetalError::from)?;
|
||||
if let Some(b) = self.find_available_buffer(size, option, &buffers) {
|
||||
// Cloning also ensures we increment the strong count
|
||||
return Ok(b.clone());
|
||||
}
|
||||
|
||||
let size = buf_size(size);
|
||||
let subbuffers = buffers.entry((size, option)).or_insert(vec![]);
|
||||
|
||||
for sub in &mut *subbuffers {
|
||||
if Arc::strong_count(sub) == 1 {
|
||||
return Ok(sub.clone());
|
||||
}
|
||||
}
|
||||
let new_buffer = self.device.new_buffer(size as NSUInteger, option);
|
||||
let new_buffer = Arc::new(new_buffer);
|
||||
subbuffers.push(new_buffer.clone());
|
||||
|
||||
for subbuffers in buffers.values_mut() {
|
||||
let newbuffers = subbuffers
|
||||
.iter()
|
||||
.filter(|s| Arc::strong_count(s) > 1)
|
||||
.map(Arc::clone)
|
||||
.collect();
|
||||
*subbuffers = newbuffers;
|
||||
}
|
||||
Ok(new_buffer)
|
||||
}
|
||||
|
||||
@ -323,8 +305,6 @@ pub struct MetalStorage {
|
||||
buffer: Arc<metal::Buffer>,
|
||||
/// a reference to the device owning this buffer
|
||||
device: MetalDevice,
|
||||
/// The count of allocated elements in the buffer
|
||||
count: usize,
|
||||
/// The dtype is kept since buffers are untyped.
|
||||
dtype: DType,
|
||||
}
|
||||
@ -406,7 +386,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
Ok(Self::new(buffer, device.clone(), el, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn powf(&self, layout: &Layout, pow: f64) -> Result<Self> {
|
||||
@ -422,7 +402,6 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_f32",
|
||||
DType::F16 => "powf_f16",
|
||||
DType::BF16 => "powf_bf16",
|
||||
dtype => crate::bail!("Metal contiguous powf {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_powf(
|
||||
@ -440,7 +419,6 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "powf_f32_strided",
|
||||
DType::F16 => "powf_f16_strided",
|
||||
DType::BF16 => "powf_bf16_strided",
|
||||
dtype => crate::bail!("Metal strided powf {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_powf_strided(
|
||||
@ -457,7 +435,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
Ok(Self::new(buffer, device.clone(), el, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn elu(&self, layout: &Layout, alpha: f64) -> Result<Self> {
|
||||
@ -473,7 +451,6 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_f32",
|
||||
DType::F16 => "elu_f16",
|
||||
DType::BF16 => "elu_bf16",
|
||||
dtype => crate::bail!("Metal contiguous elu {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_elu(
|
||||
@ -491,7 +468,6 @@ impl BackendStorage for MetalStorage {
|
||||
let name = match self.dtype {
|
||||
DType::F32 => "elu_f32_strided",
|
||||
DType::F16 => "elu_f16_strided",
|
||||
DType::BF16 => "elu_bf16_strided",
|
||||
dtype => crate::bail!("Metal strided elu {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_elu_strided(
|
||||
@ -508,7 +484,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
Ok(Self::new(buffer, device.clone(), el, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn reduce_op(&self, op: ReduceOp, layout: &Layout, sum_dims: &[usize]) -> Result<Self> {
|
||||
@ -586,7 +562,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::new(buffer, device, dst_el, dtype))
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
|
||||
fn cmp(&self, op: CmpOp, rhs: &Self, lhs_l: &Layout, rhs_l: &Layout) -> Result<Self> {
|
||||
@ -609,41 +585,29 @@ impl BackendStorage for MetalStorage {
|
||||
let command_buffer = device.command_buffer()?;
|
||||
if layout.is_contiguous() && layout.start_offset() == 0 {
|
||||
let kernel_name = match (self.dtype, dtype) {
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||
(DType::U32, DType::F32) => "cast_u32_f32",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::U8) => "cast_u32_u8",
|
||||
(DType::U32, DType::I64) => "cast_u32_i64",
|
||||
(DType::U32, DType::F16) => "cast_u32_f16",
|
||||
(DType::U32, DType::BF16) => "cast_u32_bf16",
|
||||
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
(DType::U8, DType::F16) => "cast_u8_f16",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::F32) => "cast_u8_f32",
|
||||
(DType::U8, DType::I64) => "cast_u8_i64",
|
||||
(DType::U8, DType::U32) => "cast_u8_u32",
|
||||
(DType::U8, DType::BF16) => "cast_u8_bf16",
|
||||
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
(DType::F32, DType::F16) => "cast_f32_f16",
|
||||
(DType::F32, DType::I64) => "cast_f32_i64",
|
||||
(DType::F32, DType::U32) => "cast_f32_u32",
|
||||
(DType::F32, DType::U8) => "cast_f32_u8",
|
||||
(DType::F32, DType::BF16) => "cast_f32_bf16",
|
||||
|
||||
(DType::I64, DType::BF16) => "cast_i64_bf16",
|
||||
(DType::I64, DType::F16) => "cast_i64_f16",
|
||||
(DType::I64, DType::F32) => "cast_i64_f32",
|
||||
(DType::I64, DType::U32) => "cast_i64_u32",
|
||||
(DType::I64, DType::U8) => "cast_i64_u8",
|
||||
|
||||
(DType::F16, DType::BF16) => "cast_f16_bf16",
|
||||
(DType::F16, DType::F32) => "cast_f16_f32",
|
||||
(DType::F16, DType::I64) => "cast_f16_i64",
|
||||
(DType::F16, DType::U32) => "cast_f16_u32",
|
||||
(DType::F16, DType::U8) => "cast_f16_u8",
|
||||
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::F16) => "cast_bf16_f16",
|
||||
(DType::BF16, DType::F32) => "cast_bf16_f32",
|
||||
(DType::BF16, DType::I64) => "cast_bf16_i64",
|
||||
(DType::BF16, DType::U32) => "cast_bf16_u32",
|
||||
(DType::BF16, DType::U8) => "cast_bf16_u8",
|
||||
|
||||
(left, right) => {
|
||||
crate::bail!("Metal contiguous to_dtype {left:?} {right:?} not implemented")
|
||||
@ -691,7 +655,7 @@ impl BackendStorage for MetalStorage {
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
command_buffer.set_label("to_dtype");
|
||||
Ok(Self::new(buffer, device.clone(), el_count, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn unary_impl<B: UnaryOpT>(&self, layout: &Layout) -> Result<Self> {
|
||||
@ -775,7 +739,6 @@ impl BackendStorage for MetalStorage {
|
||||
("ufloor", DType::F32) => strided::floor::FLOAT,
|
||||
("urelu", DType::F32) => strided::relu::FLOAT,
|
||||
("uround", DType::F32) => strided::round::FLOAT,
|
||||
("utanh", DType::F32) => strided::tanh::FLOAT,
|
||||
("ucos", DType::F16) => strided::cos::HALF,
|
||||
("usin", DType::F16) => strided::sin::HALF,
|
||||
("usqr", DType::F16) => strided::sqr::HALF,
|
||||
@ -792,7 +755,6 @@ impl BackendStorage for MetalStorage {
|
||||
("ufloor", DType::F16) => strided::floor::HALF,
|
||||
("urelu", DType::F16) => strided::relu::HALF,
|
||||
("uround", DType::F16) => strided::round::HALF,
|
||||
("utanh", DType::F16) => strided::tanh::HALF,
|
||||
(name, dtype) => {
|
||||
crate::bail!("Metal strided unary {name} {dtype:?} not implemented")
|
||||
}
|
||||
@ -811,7 +773,7 @@ impl BackendStorage for MetalStorage {
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
}
|
||||
Ok(Self::new(buffer, device.clone(), el_count, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn binary_impl<B: BinaryOpT>(
|
||||
@ -866,13 +828,13 @@ impl BackendStorage for MetalStorage {
|
||||
layout.start_offset() * self.dtype.size_in_bytes(),
|
||||
),
|
||||
&t.buffer,
|
||||
(t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
(&t_l.stride(), t_l.start_offset() * t.dtype.size_in_bytes()),
|
||||
&f.buffer,
|
||||
(f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
(&f_l.stride(), f_l.start_offset() * f.dtype.size_in_bytes()),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, device, el, dtype))
|
||||
Ok(Self::new(buffer, device, dtype))
|
||||
}
|
||||
|
||||
fn conv1d(
|
||||
@ -917,7 +879,6 @@ impl BackendStorage for MetalStorage {
|
||||
let col = Self {
|
||||
buffer: dst,
|
||||
device,
|
||||
count: dst_el,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
let l_out = params.l_out();
|
||||
@ -1002,7 +963,6 @@ impl BackendStorage for MetalStorage {
|
||||
let col = Self {
|
||||
buffer: dst,
|
||||
device,
|
||||
count: dst_el,
|
||||
dtype: self.dtype,
|
||||
};
|
||||
let h_out = params.out_h();
|
||||
@ -1088,7 +1048,7 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, self.device.clone(), dst_el, self.dtype))
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype))
|
||||
}
|
||||
|
||||
fn gather(&self, src_l: &Layout, ids: &Self, ids_l: &Layout, dim: usize) -> Result<Self> {
|
||||
@ -1122,7 +1082,7 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn scatter_add(
|
||||
@ -1145,15 +1105,7 @@ impl BackendStorage for MetalStorage {
|
||||
None => Err(crate::Error::RequiresContiguous { op: "scatter-add" }.bt())?,
|
||||
};
|
||||
let name = match (ids.dtype, self.dtype) {
|
||||
(DType::U8, DType::F32) => "sa_u8_f32",
|
||||
(DType::U8, DType::F16) => "sa_u8_f16",
|
||||
(DType::U8, DType::BF16) => "sa_u8_bf16",
|
||||
(DType::U32, DType::F32) => "sa_u32_f32",
|
||||
(DType::U32, DType::F16) => "sa_u32_f16",
|
||||
(DType::U32, DType::BF16) => "sa_u32_bf16",
|
||||
(DType::I64, DType::F32) => "sa_i64_f32",
|
||||
(DType::I64, DType::F16) => "sa_i64_f16",
|
||||
(DType::I64, DType::BF16) => "sa_i64_bf16",
|
||||
_ => Err(MetalError::UnexpectedDType {
|
||||
msg: "scatter-add ids should be u8/u32/i64",
|
||||
expected: DType::U32,
|
||||
@ -1219,7 +1171,7 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(buffer, device.clone(), dst_el, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
fn index_add(
|
||||
@ -1301,73 +1253,7 @@ impl BackendStorage for MetalStorage {
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
Ok(Self::new(
|
||||
buffer,
|
||||
self.device.clone(),
|
||||
b * m * n,
|
||||
self.dtype(),
|
||||
))
|
||||
}
|
||||
|
||||
fn copy2d(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o: usize,
|
||||
dst_o: usize,
|
||||
) -> Result<()> {
|
||||
if self.dtype() != dst.dtype() {
|
||||
crate::bail!(
|
||||
"copy2d with inconsistent dtypes {:?} {:?}",
|
||||
self.dtype(),
|
||||
dst.dtype()
|
||||
)
|
||||
}
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
if src_s == d2 && dst_s == d2 {
|
||||
command_buffer.set_label("copy2d_contiguous");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("copy2d_contiguous");
|
||||
let src_offset = (src_o * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let length = (d1 * d2 * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_o * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
|
||||
blit.end_encoding();
|
||||
} else {
|
||||
let el_count = d1 * d2;
|
||||
if el_count == 0 {
|
||||
return Ok(());
|
||||
}
|
||||
let kernel_name = match self.dtype {
|
||||
DType::F32 => candle_metal_kernels::copy2d::FLOAT,
|
||||
DType::F16 => candle_metal_kernels::copy2d::HALF,
|
||||
DType::BF16 => candle_metal_kernels::copy2d::BFLOAT,
|
||||
DType::I64 => candle_metal_kernels::copy2d::I64,
|
||||
DType::U32 => candle_metal_kernels::copy2d::U32,
|
||||
DType::U8 => candle_metal_kernels::copy2d::U8,
|
||||
dtype => crate::bail!("Metal copy2d {dtype:?} not implemented"),
|
||||
};
|
||||
candle_metal_kernels::call_copy2d(
|
||||
&self.device.device,
|
||||
&command_buffer,
|
||||
&self.device.kernels,
|
||||
kernel_name,
|
||||
&self.buffer,
|
||||
&dst.buffer,
|
||||
d1,
|
||||
d2,
|
||||
src_s,
|
||||
dst_s,
|
||||
src_o * self.dtype.size_in_bytes(),
|
||||
dst_o * self.dtype.size_in_bytes(),
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
command_buffer.set_label("copy2d");
|
||||
}
|
||||
Ok(())
|
||||
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
|
||||
}
|
||||
|
||||
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {
|
||||
@ -1379,7 +1265,7 @@ impl BackendStorage for MetalStorage {
|
||||
let src_offset = (src_l.start_offset() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let length = (src_l.shape().elem_count() * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
let dst_offset = (dst_offset * dst.dtype().size_in_bytes()) as NSUInteger;
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, dst.buffer(), dst_offset, length);
|
||||
blit.copy_from_buffer(&self.buffer, src_offset, &dst.buffer(), dst_offset, length);
|
||||
blit.end_encoding();
|
||||
} else {
|
||||
let src_shape = src_l.shape();
|
||||
@ -1416,11 +1302,10 @@ impl BackendStorage for MetalStorage {
|
||||
}
|
||||
|
||||
impl MetalStorage {
|
||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, count: usize, dtype: DType) -> Self {
|
||||
pub fn new(buffer: Arc<Buffer>, device: MetalDevice, dtype: DType) -> Self {
|
||||
Self {
|
||||
buffer,
|
||||
device,
|
||||
count,
|
||||
dtype,
|
||||
}
|
||||
}
|
||||
@ -1635,23 +1520,29 @@ impl MetalStorage {
|
||||
(buffer, dtype)
|
||||
};
|
||||
command_buffer.set_label("binary");
|
||||
Ok(Self::new(buffer, device.clone(), el_count, dtype))
|
||||
Ok(Self::new(buffer, device.clone(), dtype))
|
||||
}
|
||||
|
||||
pub(crate) fn to_cpu<T: Clone>(&self) -> Result<Vec<T>> {
|
||||
let size = (self.count * self.dtype.size_in_bytes()) as NSUInteger;
|
||||
|
||||
let buffer = self.device.new_buffer_managed(size)?;
|
||||
let length = self.buffer.length() as usize;
|
||||
let size = self.dtype.size_in_bytes();
|
||||
if length % size != 0 {
|
||||
crate::bail!(
|
||||
"The Metal buffer length is not aligned with dtype {:?}",
|
||||
self.dtype
|
||||
);
|
||||
}
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
{
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
let blit = command_buffer.new_blit_command_encoder();
|
||||
blit.set_label("blit_to_cpu");
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, size);
|
||||
blit.copy_from_buffer(&self.buffer, 0, &buffer, 0, self.buffer.length());
|
||||
blit.end_encoding();
|
||||
}
|
||||
self.device.wait_until_completed()?;
|
||||
Ok(read_to_vec(&buffer, self.count))
|
||||
Ok(read_to_vec(&buffer, length / size))
|
||||
}
|
||||
}
|
||||
|
||||
@ -1669,7 +1560,7 @@ impl BackendDevice for MetalDevice {
|
||||
let buffers = Arc::new(RwLock::new(HashMap::new()));
|
||||
let compute_per_buffer = match std::env::var("CANDLE_METAL_COMPUTE_PER_BUFFER") {
|
||||
Ok(val) => val.parse()?,
|
||||
_ => 50,
|
||||
_ => 10,
|
||||
};
|
||||
let seed = Arc::new(Mutex::new(device.new_buffer_with_data(
|
||||
[299792458].as_ptr() as *const c_void,
|
||||
@ -1701,12 +1592,7 @@ impl BackendDevice for MetalDevice {
|
||||
fn zeros_impl(&self, shape: &Shape, dtype: DType) -> Result<MetalStorage> {
|
||||
let size = shape.elem_count() * dtype.size_in_bytes();
|
||||
let buffer = self.allocate_zeros(size)?;
|
||||
Ok(MetalStorage::new(
|
||||
buffer,
|
||||
self.clone(),
|
||||
shape.elem_count(),
|
||||
dtype,
|
||||
))
|
||||
Ok(MetalStorage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn ones_impl(&self, shape: &Shape, dtype: DType) -> Result<Self::Storage> {
|
||||
@ -1716,21 +1602,16 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
|
||||
fn storage_from_cpu_storage(&self, storage: &CpuStorage) -> Result<Self::Storage> {
|
||||
let (count, buffer) = match storage {
|
||||
CpuStorage::U8(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorage::U32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorage::I64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorage::BF16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorage::F16(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorage::F32(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
CpuStorage::F64(storage) => (storage.len(), self.new_buffer_with_data(storage)),
|
||||
};
|
||||
Ok(Self::Storage::new(
|
||||
buffer?,
|
||||
self.clone(),
|
||||
count,
|
||||
storage.dtype(),
|
||||
))
|
||||
let buffer = match storage {
|
||||
CpuStorage::U8(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::U32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::I64(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::BF16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F16(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F32(storage) => self.new_buffer_with_data(storage),
|
||||
CpuStorage::F64(storage) => self.new_buffer_with_data(storage),
|
||||
}?;
|
||||
Ok(Self::Storage::new(buffer, self.clone(), storage.dtype()))
|
||||
}
|
||||
|
||||
fn rand_uniform(
|
||||
@ -1756,17 +1637,12 @@ impl BackendDevice for MetalDevice {
|
||||
min as f32,
|
||||
max as f32,
|
||||
shape.elem_count(),
|
||||
&self.seed.lock().unwrap(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(
|
||||
buffer,
|
||||
self.clone(),
|
||||
shape.elem_count(),
|
||||
dtype,
|
||||
))
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn rand_normal(
|
||||
@ -1792,17 +1668,12 @@ impl BackendDevice for MetalDevice {
|
||||
mean as f32,
|
||||
stddev as f32,
|
||||
shape.elem_count(),
|
||||
&self.seed.lock().unwrap(),
|
||||
&*self.seed.lock().unwrap(),
|
||||
&buffer,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
|
||||
Ok(Self::Storage::new(
|
||||
buffer,
|
||||
self.clone(),
|
||||
shape.elem_count(),
|
||||
dtype,
|
||||
))
|
||||
Ok(Self::Storage::new(buffer, self.clone(), dtype))
|
||||
}
|
||||
|
||||
fn set_seed(&self, seed: u64) -> Result<()> {
|
||||
@ -1813,7 +1684,7 @@ impl BackendDevice for MetalDevice {
|
||||
let seed_buffer = self.seed.try_lock().map_err(MetalError::from)?;
|
||||
let contents = seed_buffer.contents();
|
||||
unsafe {
|
||||
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 1);
|
||||
std::ptr::copy([seed].as_ptr(), contents as *mut u32, 4);
|
||||
}
|
||||
seed_buffer.did_modify_range(metal::NSRange::new(0, 4));
|
||||
|
||||
@ -1821,10 +1692,6 @@ impl BackendDevice for MetalDevice {
|
||||
}
|
||||
}
|
||||
|
||||
fn buf_size(size: NSUInteger) -> NSUInteger {
|
||||
(size - 1).next_power_of_two() as NSUInteger
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &Buffer, n: usize) -> Vec<T> {
|
||||
let ptr = buffer.contents() as *const T;
|
||||
assert!(!ptr.is_null());
|
||||
|
@ -132,10 +132,7 @@ pub enum Op {
|
||||
stride: (usize, usize),
|
||||
},
|
||||
|
||||
UpsampleNearest1D {
|
||||
arg: Tensor,
|
||||
target_size: usize,
|
||||
},
|
||||
UpsampleNearest1D(Tensor),
|
||||
UpsampleNearest2D {
|
||||
arg: Tensor,
|
||||
target_h: usize,
|
||||
|
@ -42,7 +42,7 @@ pub enum OpCode {
|
||||
Stop = b'.',
|
||||
NewObj = 0x81,
|
||||
EmptyList = b']',
|
||||
BinFloat = b'G',
|
||||
BinFloat = b'g',
|
||||
Append = b'a',
|
||||
Appends = b'e',
|
||||
}
|
||||
@ -462,10 +462,7 @@ impl Stack {
|
||||
self.push(Object::Int(arg))
|
||||
}
|
||||
OpCode::BinFloat => {
|
||||
// Somehow floats are encoded using BigEndian whereas int types use LittleEndian.
|
||||
// https://github.com/python/cpython/blob/0c80da4c14d904a367968955544dd6ae58c8101c/Lib/pickletools.py#L855
|
||||
// https://github.com/pytorch/pytorch/blob/372d078f361e726bb4ac0884ac334b04c58179ef/torch/_weights_only_unpickler.py#L243
|
||||
let arg = r.read_f64::<byteorder::BigEndian>()?;
|
||||
let arg = r.read_f64::<LittleEndian>()?;
|
||||
self.push(Object::Float(arg))
|
||||
}
|
||||
OpCode::BinUnicode => {
|
||||
|
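For context on the BigEndian/LittleEndian comment above, here is a minimal, hypothetical sketch using the `byteorder` crate (the same crate the pickle reader relies on): pickle's BINFLOAT payload is 8 big-endian bytes, while integer opcodes such as BININT carry little-endian payloads.

```rust
use byteorder::{BigEndian, LittleEndian, ReadBytesExt};
use std::io::Cursor;

fn main() {
    // 1.0f64 as stored by pickle's BINFLOAT opcode: 8 big-endian bytes.
    let mut r = Cursor::new(vec![0x3f, 0xf0, 0, 0, 0, 0, 0, 0]);
    assert_eq!(r.read_f64::<BigEndian>().unwrap(), 1.0);

    // Integer opcodes such as BININT store their payload little-endian instead.
    let mut r = Cursor::new(vec![0x2a, 0, 0, 0]);
    assert_eq!(r.read_i32::<LittleEndian>().unwrap(), 42);
}
```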
@ -1,343 +0,0 @@
|
||||
use super::{GgmlDType, QStorage};
|
||||
use crate::{backend::BackendDevice, cuda_backend::WrapErr};
|
||||
use crate::{CudaDevice, CudaStorage, Result};
|
||||
|
||||
use cudarc::driver::{CudaSlice, DeviceSlice};
|
||||
|
||||
pub struct QCudaStorage {
|
||||
data: CudaSlice<u8>,
|
||||
dtype: GgmlDType,
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
pub const WARP_SIZE: usize = 32;
|
||||
pub const MMQ_X_Q4_0_AMPERE: usize = 4;
|
||||
pub const MMQ_Y_Q4_0_AMPERE: usize = 32;
|
||||
pub const NWARPS_Q4_0_AMPERE: usize = 4;
|
||||
pub const GGML_CUDA_MMV_X: usize = 32;
|
||||
pub const GGML_CUDA_MMV_Y: usize = 1;
|
||||
pub const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
|
||||
|
||||
fn dequantize(
|
||||
data: &CudaSlice<u8>,
|
||||
dtype: GgmlDType,
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1", false, 32, nb),
|
||||
GgmlDType::Q5_0 => {
|
||||
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
|
||||
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
|
||||
(
|
||||
"dequantize_block_q5_0",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
nb,
|
||||
)
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1)
|
||||
/ (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
|
||||
(
|
||||
"dequantize_block_q5_1",
|
||||
false,
|
||||
CUDA_DEQUANTIZE_BLOCK_SIZE,
|
||||
nb,
|
||||
)
|
||||
}
|
||||
GgmlDType::Q8_0 => ("dequantize_block_q8_0", false, 32, nb),
|
||||
GgmlDType::Q2K => ("dequantize_block_q2_K", true, 64, nb),
|
||||
GgmlDType::Q3K => ("dequantize_block_q3_K", true, 64, nb),
|
||||
GgmlDType::Q4K => ("dequantize_block_q4_K", true, 32, nb),
|
||||
GgmlDType::Q5K => ("dequantize_block_q5_K", true, 64, nb),
|
||||
GgmlDType::Q6K => ("dequantize_block_q6_K", true, 64, nb),
|
||||
GgmlDType::Q8K => ("dequantize_block_q8_K", true, 32, nb),
|
||||
_ => crate::bail!("unsupported dtype for dequantize {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let dst = dev.alloc_zeros::<f32>(elem_count).w()?;
|
||||
// See e.g.
|
||||
// https://github.com/ggerganov/llama.cpp/blob/cbbd1efa06f8c09f9dff58ff9d9af509cc4c152b/ggml-cuda.cu#L7270
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (num_blocks as u32, 1, 1),
|
||||
block_dim: (block_dim as u32, 1, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
if is_k {
|
||||
let params = (data, &dst);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
} else {
|
||||
let nb32 = match dtype {
|
||||
GgmlDType::Q5_0 | GgmlDType::Q5_1 => elem_count,
|
||||
_ => elem_count / 32,
|
||||
};
|
||||
let params = (data, &dst, nb32 as i32);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
}
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
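As a side note, the grid-size arithmetic in the Q5_0/Q5_1 branch above rounds the element count up to whole blocks; the factor of two suggests each thread dequantizes two elements. A small sketch with a hypothetical element count (not part of the diff):

```rust
fn main() {
    const CUDA_DEQUANTIZE_BLOCK_SIZE: usize = 256;
    let elem_count = 4096; // hypothetical tensor size
    let nb = (elem_count + 2 * CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2 * CUDA_DEQUANTIZE_BLOCK_SIZE);
    // grid_dim = (8, 1, 1), block_dim = (256, 1, 1) for this case.
    assert_eq!(nb, 8);
}
```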
fn dequantize_mul_mat_vec(
|
||||
data: &CudaSlice<u8>,
|
||||
y: &cudarc::driver::CudaView<f32>,
|
||||
dtype: GgmlDType,
|
||||
ncols: usize,
|
||||
nrows: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
use cudarc::driver::LaunchAsync;
|
||||
|
||||
let kernel_name = match dtype {
|
||||
GgmlDType::Q4_0 => "dequantize_mul_mat_vec_q4_0_cuda",
|
||||
GgmlDType::Q4_1 => "dequantize_mul_mat_vec_q4_1_cuda",
|
||||
GgmlDType::Q5_0 => "dequantize_mul_mat_vec_q5_0_cuda",
|
||||
GgmlDType::Q5_1 => "dequantize_mul_mat_vec_q5_1_cuda",
|
||||
GgmlDType::Q8_0 => "dequantize_mul_mat_vec_q8_0_cuda",
|
||||
GgmlDType::Q2K => "dequantize_mul_mat_vec_q2_k",
|
||||
GgmlDType::Q3K => "dequantize_mul_mat_vec_q3_k",
|
||||
GgmlDType::Q4K => "dequantize_mul_mat_vec_q4_k",
|
||||
GgmlDType::Q5K => "dequantize_mul_mat_vec_q5_k",
|
||||
GgmlDType::Q6K => "dequantize_mul_mat_vec_q6_k",
|
||||
_ => crate::bail!("unsupported dtype for quantized matmul {dtype:?}"),
|
||||
};
|
||||
let func = dev.get_or_load_func(kernel_name, candle_kernels::QUANTIZED)?;
|
||||
let dst = dev.alloc_zeros::<f32>(nrows).w()?;
|
||||
let block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
grid_dim: (block_num_y as u32, 1, 1),
|
||||
block_dim: (WARP_SIZE as u32, GGML_CUDA_MMV_Y as u32, 1),
|
||||
shared_mem_bytes: 0,
|
||||
};
|
||||
|
||||
let params = (data, y, &dst, ncols as i32, nrows as i32);
|
||||
unsafe { func.launch(cfg, params) }.w()?;
|
||||
Ok(CudaStorage::wrap_cuda_slice(dst, dev.clone()))
|
||||
}
|
||||
|
||||
impl QCudaStorage {
|
||||
pub fn zeros(device: &CudaDevice, el_count: usize, dtype: GgmlDType) -> Result<Self> {
|
||||
let size_in_bytes = el_count * dtype.type_size() / dtype.block_size();
|
||||
let data = device.alloc_zeros::<u8>(size_in_bytes).w()?;
|
||||
Ok(QCudaStorage {
|
||||
data,
|
||||
device: device.clone(),
|
||||
dtype,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &CudaDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<CudaStorage> {
|
||||
let fast_kernel = matches!(
|
||||
self.dtype,
|
||||
GgmlDType::Q4_0
|
||||
| GgmlDType::Q4_1
|
||||
| GgmlDType::Q5_0
|
||||
| GgmlDType::Q5_1
|
||||
| GgmlDType::Q8_0
|
||||
| GgmlDType::Q2K
|
||||
| GgmlDType::Q3K
|
||||
| GgmlDType::Q4K
|
||||
| GgmlDType::Q5K
|
||||
| GgmlDType::Q6K
|
||||
| GgmlDType::Q8K
|
||||
);
|
||||
if fast_kernel {
|
||||
return dequantize(&self.data, self.dtype, elem_count, self.device());
|
||||
}
|
||||
// Run the dequantization on cpu.
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
|
||||
let buffer = self.device.dtoh_sync_copy(&self.data).w()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
let block_len = elem_count / self.dtype.block_size();
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => {
|
||||
let slice =
|
||||
unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const f32, block_len) };
|
||||
out.copy_from_slice(slice)
|
||||
}
|
||||
GgmlDType::F16 => {
|
||||
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
|
||||
half::f16::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8K => {
|
||||
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
|
||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
}
|
||||
|
||||
self.device
|
||||
.storage_from_cpu_storage(&crate::CpuStorage::F32(out))
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &CudaStorage) -> Result<()> {
|
||||
// Run the quantization on cpu.
|
||||
let src = match &src.slice {
|
||||
crate::cuda_backend::CudaStorageSlice::F32(data) => {
|
||||
self.device.dtoh_sync_copy(data).w()?
|
||||
}
|
||||
_ => crate::bail!("only f32 can be quantized"),
|
||||
};
|
||||
let src_len = src.len();
|
||||
let src = crate::Storage::Cpu(crate::CpuStorage::F32(src));
|
||||
let mut qcpu_storage = crate::Device::Cpu.qzeros(src_len, self.dtype)?;
|
||||
qcpu_storage.quantize(&src)?;
|
||||
let data = qcpu_storage.data()?;
|
||||
let data = self.device.htod_sync_copy(data.as_ref()).w()?;
|
||||
self.data = data;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
self.data.len()
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
self_shape: &crate::Shape,
|
||||
storage: &CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
if matches!(layout.shape().dims(), [1, 1, _] | [1, _]) {
|
||||
self.dequantize_matmul_vec(self_shape, storage, layout)
|
||||
} else {
|
||||
self.dequantize_matmul(self_shape, storage, layout)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl QCudaStorage {
|
||||
fn dequantize_matmul_vec(
|
||||
&self,
|
||||
self_shape: &crate::Shape,
|
||||
rhs: &CudaStorage,
|
||||
rhs_l: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
let (nrows, ncols) = self_shape.dims2()?;
|
||||
let rhs = rhs.as_cuda_slice::<f32>()?;
|
||||
let rhs = match rhs_l.contiguous_offsets() {
|
||||
Some((o1, o2)) => rhs.slice(o1..o2),
|
||||
None => Err(crate::Error::RequiresContiguous { op: "dmmv" }.bt())?,
|
||||
};
|
||||
let (with_batch, k) = match rhs_l.shape().dims() {
|
||||
[1, 1, k] => (true, k),
|
||||
[1, k] => (false, k),
|
||||
_ => crate::bail!("unexpected rhs shape in dmmv {:?}", rhs_l.shape()),
|
||||
};
|
||||
if ncols != *k {
|
||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", rhs_l.shape())
|
||||
}
|
||||
|
||||
let out =
|
||||
dequantize_mul_mat_vec(&self.data, &rhs, self.dtype, ncols, nrows, self.device())?;
|
||||
let out_shape = if with_batch {
|
||||
vec![1, 1, nrows]
|
||||
} else {
|
||||
vec![1, nrows]
|
||||
};
|
||||
Ok((out, out_shape.into()))
|
||||
}
|
||||
|
||||
fn dequantize_matmul(
|
||||
&self,
|
||||
self_shape: &crate::Shape,
|
||||
storage: &CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
use crate::backend::BackendStorage;
|
||||
let (n, k) = self_shape.dims2()?;
|
||||
let (b, m, k2) = match layout.shape().dims() {
|
||||
&[b, m, k2] => (b, m, k2),
|
||||
&[m, k2] => (1, m, k2),
|
||||
s => crate::bail!("unexpected shape for input {s:?}"),
|
||||
};
|
||||
if k2 != k {
|
||||
crate::bail!("mismatch on matmul dim {self_shape:?} {:?}", layout.shape())
|
||||
}
|
||||
|
||||
let data_f32 = self.dequantize(n * k)?;
|
||||
let rhs_l = crate::Layout::new((k, n).into(), vec![1, k], 0).broadcast_as((b, k, n))?;
|
||||
let out = storage.matmul(&data_f32, (b, m, n, k), layout, &rhs_l)?;
|
||||
let mut out_shape = layout.shape().dims().to_vec();
|
||||
out_shape.pop();
|
||||
out_shape.push(n);
|
||||
Ok((out, out_shape.into()))
|
||||
}
|
||||
}
|
||||
|
||||
fn read_to_vec<T: Clone>(buffer: &[u8], n: usize) -> Vec<T> {
|
||||
let slice = unsafe { std::slice::from_raw_parts(buffer.as_ptr() as *const T, n) };
|
||||
slice.to_vec()
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
device: &CudaDevice,
|
||||
data: &[T],
|
||||
) -> Result<super::QStorage> {
|
||||
let data = unsafe {
|
||||
std::slice::from_raw_parts(data.as_ptr() as *const u8, core::mem::size_of_val(data))
|
||||
};
|
||||
let data = device.htod_sync_copy(data).w()?;
|
||||
Ok(QStorage::Cuda(QCudaStorage {
|
||||
data,
|
||||
device: device.clone(),
|
||||
dtype: T::DTYPE,
|
||||
}))
|
||||
}
|
@ -1,50 +0,0 @@
|
||||
#![allow(unused)]
|
||||
use super::GgmlDType;
|
||||
use crate::{CudaDevice, CudaStorage, Error, Result};
|
||||
|
||||
pub struct QCudaStorage {
|
||||
dtype: GgmlDType,
|
||||
device: CudaDevice,
|
||||
}
|
||||
|
||||
impl QCudaStorage {
|
||||
pub fn zeros(_: &CudaDevice, _: usize, _: GgmlDType) -> Result<Self> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn dtype(&self) -> GgmlDType {
|
||||
self.dtype
|
||||
}
|
||||
|
||||
pub fn device(&self) -> &CudaDevice {
|
||||
&self.device
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, _elem_count: usize) -> Result<CudaStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, _src: &CudaStorage) -> Result<()> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
|
||||
pub fn storage_size_in_bytes(&self) -> usize {
|
||||
0
|
||||
}
|
||||
|
||||
pub fn fwd(
|
||||
&self,
|
||||
_self_shape: &crate::Shape,
|
||||
_storage: &CudaStorage,
|
||||
_layout: &crate::Layout,
|
||||
) -> Result<(CudaStorage, crate::Shape)> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
_device: &CudaDevice,
|
||||
_data: &[T],
|
||||
) -> Result<super::QStorage> {
|
||||
Err(Error::NotCompiledWithCudaSupport)
|
||||
}
|
@ -41,10 +41,3 @@ impl QMetalStorage {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
_device: &MetalDevice,
|
||||
_data: &[T],
|
||||
) -> Result<super::QStorage> {
|
||||
Err(Error::NotCompiledWithMetalSupport)
|
||||
}
|
||||
|
@ -1,5 +1,7 @@
|
||||
//! Support for the GGML file format.
|
||||
|
||||
#[cfg(feature = "metal")]
|
||||
use super::metal::load_quantized_metal;
|
||||
use super::{k_quants, GgmlDType, QStorage};
|
||||
use crate::{Device, Result};
|
||||
use byteorder::{LittleEndian, ReadBytesExt};
|
||||
@ -128,8 +130,13 @@ fn from_raw_data<T: super::GgmlType + Send + Sync + 'static>(
|
||||
let data = unsafe { std::slice::from_raw_parts(raw_data_ptr as *const T, n_blocks) };
|
||||
let data: QStorage = match device {
|
||||
Device::Cpu => QStorage::Cpu(Box::new(data.to_vec())),
|
||||
Device::Metal(metal) => super::metal::load_quantized(metal, data)?,
|
||||
Device::Cuda(cuda) => super::cuda::load_quantized(cuda, data)?,
|
||||
#[cfg(feature = "metal")]
|
||||
Device::Metal(metal) => load_quantized_metal(metal, data)?,
|
||||
#[cfg(not(feature = "metal"))]
|
||||
Device::Metal(_metal) => {
|
||||
crate::bail!("Metal backend requires `metal` feature")
|
||||
}
|
||||
device => unimplemented!("Implement quantized tensor for device {device:?}"),
|
||||
};
|
||||
super::QTensor::new(data, dims)
|
||||
}
|
||||
|
@ -34,8 +34,6 @@ impl QMetalStorage {
|
||||
}
|
||||
|
||||
pub fn dequantize(&self, elem_count: usize) -> Result<MetalStorage> {
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
|
||||
let buffer = self.device.new_buffer_managed(self.buffer.length())?;
|
||||
let command_buffer = self.device.command_buffer()?;
|
||||
command_buffer.set_label("to_cpu");
|
||||
@ -45,73 +43,87 @@ impl QMetalStorage {
|
||||
blit.end_encoding();
|
||||
self.device.wait_until_completed()?;
|
||||
let mut out = vec![0.0; elem_count];
|
||||
let block_len = elem_count / self.dtype.block_size();
|
||||
match self.dtype {
|
||||
GgmlDType::F32 => {
|
||||
let vec: Vec<f32> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<f32> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
f32::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::F16 => {
|
||||
let vec: Vec<half::f16> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<half::f16> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
half::f16::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ4_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ4_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ5_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ5_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_0 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ8_0> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_0::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8_1 => {
|
||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ8_1> = read_to_vec(&buffer, elem_count);
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8_1::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q2K => {
|
||||
let vec: Vec<crate::quantized::BlockQ2K> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ2K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ2K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q3K => {
|
||||
let vec: Vec<crate::quantized::BlockQ3K> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ3K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ3K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q4K => {
|
||||
let vec: Vec<crate::quantized::BlockQ4K> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ4K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ4K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q5K => {
|
||||
let vec: Vec<crate::quantized::BlockQ5K> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ5K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ5K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q6K => {
|
||||
let vec: Vec<crate::quantized::BlockQ6K> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ6K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ6K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
GgmlDType::Q8K => {
|
||||
let vec: Vec<crate::quantized::BlockQ8K> = read_to_vec(&buffer, block_len);
|
||||
let vec: Vec<crate::quantized::BlockQ8K> =
|
||||
read_to_vec(&buffer, elem_count / self.dtype.block_size());
|
||||
use crate::quantized::k_quants::GgmlType;
|
||||
crate::quantized::BlockQ8K::to_float(&vec, &mut out)?;
|
||||
}
|
||||
}
|
||||
|
||||
let buffer = self.device.new_buffer_with_data(&out)?;
|
||||
Ok(MetalStorage::new(
|
||||
buffer,
|
||||
self.device.clone(),
|
||||
elem_count,
|
||||
DType::F32,
|
||||
))
|
||||
Ok(MetalStorage::new(buffer, self.device.clone(), DType::F32))
|
||||
}
|
||||
|
||||
pub fn quantize(&mut self, src: &MetalStorage) -> Result<()> {
|
||||
@ -175,12 +187,12 @@ impl QMetalStorage {
|
||||
&dst,
|
||||
)
|
||||
.map_err(MetalError::from)?;
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, dst_shape.elem_count(), DType::F32);
|
||||
let dst_storage = crate::MetalStorage::new(dst, device, DType::F32);
|
||||
Ok((dst_storage, dst_shape))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn load_quantized<T: super::GgmlType + Send + Sync + 'static>(
|
||||
pub fn load_quantized_metal<T: super::GgmlType + Send + Sync + 'static>(
|
||||
device: &MetalDevice,
|
||||
data: &[T],
|
||||
) -> Result<QStorage> {
|
||||
|
@ -4,7 +4,6 @@ use std::borrow::Cow;
|
||||
|
||||
#[cfg(target_feature = "avx")]
|
||||
pub mod avx;
|
||||
mod dummy_cuda;
|
||||
mod dummy_metal;
|
||||
pub mod ggml_file;
|
||||
pub mod gguf_file;
|
||||
@ -15,13 +14,6 @@ pub mod metal;
|
||||
mod metal {
|
||||
pub use super::dummy_metal::*;
|
||||
}
|
||||
#[cfg(feature = "cuda")]
|
||||
pub mod cuda;
|
||||
#[cfg(not(feature = "cuda"))]
|
||||
mod cuda {
|
||||
pub use super::dummy_cuda::*;
|
||||
}
|
||||
|
||||
#[cfg(target_feature = "neon")]
|
||||
pub mod neon;
|
||||
#[cfg(target_feature = "simd128")]
|
||||
@ -47,9 +39,8 @@ impl Device {
|
||||
let storage = metal::QMetalStorage::zeros(metal, elem_count, dtype)?;
|
||||
Ok(QStorage::Metal(storage))
|
||||
}
|
||||
Device::Cuda(cuda) => {
|
||||
let storage = cuda::QCudaStorage::zeros(cuda, elem_count, dtype)?;
|
||||
Ok(QStorage::Cuda(storage))
|
||||
Device::Cuda(_cuda) => {
|
||||
crate::bail!("Cuda ggml quantization not supported");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -58,7 +49,6 @@ impl Device {
|
||||
pub enum QStorage {
|
||||
Cpu(Box<dyn QuantizedType>),
|
||||
Metal(metal::QMetalStorage),
|
||||
Cuda(cuda::QCudaStorage),
|
||||
}
|
||||
|
||||
impl QStorage {
|
||||
@ -66,7 +56,6 @@ impl QStorage {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.block_size(),
|
||||
QStorage::Metal(storage) => storage.dtype().block_size(),
|
||||
QStorage::Cuda(storage) => storage.dtype().block_size(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -74,7 +63,6 @@ impl QStorage {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.dtype(),
|
||||
QStorage::Metal(storage) => storage.dtype(),
|
||||
QStorage::Cuda(storage) => storage.dtype(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,7 +70,6 @@ impl QStorage {
|
||||
match self {
|
||||
QStorage::Cpu(_storage) => Device::Cpu,
|
||||
QStorage::Metal(storage) => Device::Metal(storage.device().clone()),
|
||||
QStorage::Cuda(storage) => Device::Cuda(storage.device().clone()),
|
||||
}
|
||||
}
|
||||
|
||||
@ -90,7 +77,6 @@ impl QStorage {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => storage.storage_size_in_bytes(),
|
||||
QStorage::Metal(storage) => storage.storage_size_in_bytes(),
|
||||
QStorage::Cuda(storage) => storage.storage_size_in_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -100,7 +86,6 @@ impl QStorage {
|
||||
storage.from_float(src.as_slice::<f32>()?)?;
|
||||
}
|
||||
(QStorage::Metal(storage), Storage::Metal(src)) => storage.quantize(src)?,
|
||||
(QStorage::Cuda(storage), Storage::Cuda(src)) => storage.quantize(src)?,
|
||||
_ => crate::bail!("Invalid dequantize storage locations do not match"),
|
||||
}
|
||||
Ok(())
|
||||
@ -110,7 +95,6 @@ impl QStorage {
|
||||
match self {
|
||||
QStorage::Cpu(storage) => Ok(Storage::Cpu(storage.dequantize(elem_count)?)),
|
||||
QStorage::Metal(storage) => Ok(Storage::Metal(storage.dequantize(elem_count)?)),
|
||||
QStorage::Cuda(storage) => Ok(Storage::Cuda(storage.dequantize(elem_count)?)),
|
||||
}
|
||||
}
|
||||
|
||||
@ -122,7 +106,7 @@ impl QStorage {
|
||||
let data = unsafe { std::slice::from_raw_parts(data_ptr, size_in_bytes) };
|
||||
Ok(Cow::from(data))
|
||||
}
|
||||
QStorage::Metal(_) | QStorage::Cuda(_) => {
|
||||
QStorage::Metal(_storage) => {
|
||||
crate::bail!("not implemented");
|
||||
}
|
||||
}
|
||||
@ -398,7 +382,7 @@ impl QMatMul {
|
||||
_ => DEQUANTIZE_ALL.with(|b| *b),
|
||||
};
|
||||
let t = if dequantize {
|
||||
let tensor = qtensor.dequantize(&qtensor.device())?;
|
||||
let tensor = qtensor.dequantize(&Device::Cpu)?;
|
||||
Self::Tensor(tensor)
|
||||
} else {
|
||||
Self::QTensor(qtensor)
|
||||
@ -440,7 +424,7 @@ impl crate::CustomOp1 for QTensor {
|
||||
#[allow(clippy::infallible_destructuring_match)]
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cpu(storage) => storage,
|
||||
QStorage::Metal(_) | QStorage::Cuda(_) => crate::bail!("Invalid storage"),
|
||||
QStorage::Metal(_) => crate::bail!("Invalid storage"),
|
||||
};
|
||||
let slice = storage.as_slice::<f32>()?;
|
||||
let slice = &slice[layout.start_offset()..layout.start_offset() + src_shape.elem_count()];
|
||||
@ -460,18 +444,6 @@ impl crate::CustomOp1 for QTensor {
|
||||
};
|
||||
self_storage.fwd(&self.shape, storage, layout)
|
||||
}
|
||||
|
||||
fn cuda_fwd(
|
||||
&self,
|
||||
storage: &crate::CudaStorage,
|
||||
layout: &crate::Layout,
|
||||
) -> Result<(crate::CudaStorage, Shape)> {
|
||||
let self_storage = match &self.storage {
|
||||
QStorage::Cuda(cuda) => cuda,
|
||||
_ => unreachable!("Cannot call cuda matmul on non cuda QTensor"),
|
||||
};
|
||||
self_storage.fwd(&self.shape, storage, layout)
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Module for QMatMul {
|
||||
|
@ -352,10 +352,6 @@ impl Storage {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Cuda(s))
|
||||
}
|
||||
(Storage::Metal(inp), Storage::Metal(kernel)) => {
|
||||
let s = inp.conv_transpose1d(l, kernel, kernel_l, params)?;
|
||||
Ok(Self::Metal(s))
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
@ -701,32 +697,4 @@ impl Storage {
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) fn copy2d(
|
||||
&self,
|
||||
dst: &mut Self,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o: usize,
|
||||
dst_o: usize,
|
||||
) -> Result<()> {
|
||||
match (self, dst) {
|
||||
(Self::Cpu(src), Self::Cpu(dst)) => src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o),
|
||||
(Self::Cuda(src), Self::Cuda(dst)) => {
|
||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||
}
|
||||
(Self::Metal(src), Self::Metal(dst)) => {
|
||||
Ok(src.copy2d(dst, d1, d2, src_s, dst_s, src_o, dst_o)?)
|
||||
}
|
||||
(lhs, rhs) => Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: lhs.device().location(),
|
||||
rhs: rhs.device().location(),
|
||||
op: "copy2d",
|
||||
}
|
||||
.bt()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -666,7 +666,7 @@ impl Tensor {
|
||||
Ok(from_storage(storage, self.shape(), op, false))
|
||||
}
|
||||
|
||||
pub(crate) fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
fn check_dim(&self, dim: usize, op: &'static str) -> Result<()> {
|
||||
if dim >= self.dims().len() {
|
||||
Err(Error::DimOutOfRange {
|
||||
shape: self.shape().clone(),
|
||||
@ -1015,7 +1015,7 @@ impl Tensor {
|
||||
/// tensor also has three dimensions, `(batch, channels, target_size)`.
|
||||
pub fn interpolate1d(&self, target_size: usize) -> Result<Self> {
|
||||
let (n, c, _l) = self.dims3()?;
|
||||
let op = BackpropOp::new1(self, |arg| Op::UpsampleNearest1D { arg, target_size });
|
||||
let op = BackpropOp::new1(self, Op::UpsampleNearest1D);
|
||||
let storage = self
|
||||
.storage()
|
||||
.upsample_nearest1d(self.layout(), target_size)?;
|
||||
@ -2149,6 +2149,152 @@ impl Tensor {
|
||||
Self::cat(&args, dim)
|
||||
}
|
||||
|
||||
/// Concatenates two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must be of the same rank, and the output will have
|
||||
/// the same rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg0.rank() != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: arg0.rank(),
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
if dim == 0 {
|
||||
Self::cat0(args)
|
||||
} else {
|
||||
// TODO: Avoid these transpositions and have an implementation that works
|
||||
// for dim != 0...
|
||||
let args: Vec<Tensor> = args
|
||||
.iter()
|
||||
.map(|a| a.as_ref().transpose(0, dim))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cat = Self::cat0(&args)?;
|
||||
cat.transpose(0, dim)
|
||||
}
|
||||
}
|
||||
|
||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
let device = arg0.device();
|
||||
let dtype = arg0.dtype();
|
||||
let first_dims = arg0.shape().dims();
|
||||
let mut cat_dims = first_dims.to_vec();
|
||||
cat_dims[0] = 0;
|
||||
let mut offsets = vec![0usize];
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg.dtype() != dtype {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: dtype,
|
||||
rhs: arg.dtype(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if arg.device().location() != device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: device.location(),
|
||||
rhs: arg.device().location(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if rank != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: rank,
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx == 0 {
|
||||
cat_dims[0] += v2;
|
||||
}
|
||||
if dim_idx != 0 && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
||||
offsets.push(next_offset);
|
||||
}
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = BackpropOp::new(args, |args| Op::Cat(args, 0));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
arg.storage()
|
||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||
}
|
||||
Ok(from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
/// Pad the input tensor using 0s along dimension `dim`. This adds `left` elements before the
|
||||
/// input tensor values and `right` elements after.
|
||||
pub fn pad_with_zeros<D: Dim>(&self, dim: D, left: usize, right: usize) -> Result<Self> {
|
||||
|
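A short usage sketch for `pad_with_zeros` (assuming a CPU tensor), mirroring the doctest style used elsewhere in this diff:

```rust
# use candle_core::{Tensor, Device};
let t = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
// Two zeros before the values along dim 0, one zero after.
let padded = t.pad_with_zeros(0, 2, 1)?;
assert_eq!(padded.to_vec1::<f32>()?, [0., 0., 1., 2., 3., 0.]);
# Ok::<(), candle_core::Error>(())
```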
@ -1,240 +0,0 @@
|
||||
use crate::{shape::Dim, Error, Result, Shape, Tensor};
|
||||
|
||||
impl Tensor {
|
||||
/// Concatenates two or more tensors along a particular dimension.
|
||||
///
|
||||
/// All tensors must be of the same rank, and the output will have
|
||||
/// the same rank
|
||||
///
|
||||
/// ```rust
|
||||
/// # use candle_core::{Tensor, DType, Device};
|
||||
/// let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
/// let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 0)?;
|
||||
/// assert_eq!(c.shape().dims(), &[4, 3]);
|
||||
///
|
||||
/// let c = Tensor::cat(&[&a, &b], 1)?;
|
||||
/// assert_eq!(c.shape().dims(), &[2, 6]);
|
||||
/// # Ok::<(), candle_core::Error>(())
|
||||
/// ```
|
||||
pub fn cat<A: AsRef<Tensor>, D: Dim>(args: &[A], dim: D) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let dim = dim.to_index(arg0.shape(), "cat")?;
|
||||
for arg in args {
|
||||
arg.as_ref().check_dim(dim, "cat")?;
|
||||
}
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg0.rank() != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: arg0.rank(),
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
if dim == 0 {
|
||||
Self::cat0(args)
|
||||
} else {
|
||||
let all_contiguous = args.iter().all(|v| v.as_ref().is_contiguous());
|
||||
if all_contiguous {
|
||||
Self::cat_contiguous(args, dim)
|
||||
} else {
|
||||
let args: Vec<Tensor> = args
|
||||
.iter()
|
||||
.map(|a| a.as_ref().transpose(0, dim))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let cat = Self::cat0(&args)?;
|
||||
cat.transpose(0, dim)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cat0<A: AsRef<Tensor>>(args: &[A]) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
let device = arg0.device();
|
||||
let dtype = arg0.dtype();
|
||||
let first_dims = arg0.shape().dims();
|
||||
let mut cat_dims = first_dims.to_vec();
|
||||
cat_dims[0] = 0;
|
||||
let mut offsets = vec![0usize];
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg.dtype() != dtype {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: dtype,
|
||||
rhs: arg.dtype(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if arg.device().location() != device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: device.location(),
|
||||
rhs: arg.device().location(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if rank != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: rank,
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx == 0 {
|
||||
cat_dims[0] += v2;
|
||||
}
|
||||
if dim_idx != 0 && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
let next_offset = offsets.last().unwrap() + arg.elem_count();
|
||||
offsets.push(next_offset);
|
||||
}
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, 0));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
for (arg, &offset) in args.iter().zip(offsets.iter()) {
|
||||
let arg = arg.as_ref();
|
||||
arg.storage()
|
||||
.copy_strided_src(&mut storage, offset, arg.layout())?;
|
||||
}
|
||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||
}
|
||||
|
||||
fn cat_contiguous<A: AsRef<Tensor>>(args: &[A], dim: usize) -> Result<Self> {
|
||||
if args.is_empty() {
|
||||
Err(Error::OpRequiresAtLeastOneTensor { op: "cat" }.bt())?
|
||||
}
|
||||
let arg0 = args[0].as_ref();
|
||||
if args.len() == 1 {
|
||||
return Ok(arg0.clone());
|
||||
}
|
||||
let rank = arg0.rank();
|
||||
let device = arg0.device();
|
||||
let dtype = arg0.dtype();
|
||||
let first_dims = arg0.shape().dims();
|
||||
let mut cat_dims = first_dims.to_vec();
|
||||
cat_dims[dim] = 0;
|
||||
for (arg_idx, arg) in args.iter().enumerate() {
|
||||
let arg = arg.as_ref();
|
||||
if arg.dtype() != dtype {
|
||||
Err(Error::DTypeMismatchBinaryOp {
|
||||
lhs: dtype,
|
||||
rhs: arg.dtype(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if arg.device().location() != device.location() {
|
||||
Err(Error::DeviceMismatchBinaryOp {
|
||||
lhs: device.location(),
|
||||
rhs: arg.device().location(),
|
||||
op: "cat",
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
if rank != arg.rank() {
|
||||
Err(Error::UnexpectedNumberOfDims {
|
||||
expected: rank,
|
||||
got: arg.rank(),
|
||||
shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
for (dim_idx, (v1, v2)) in arg0
|
||||
.shape()
|
||||
.dims()
|
||||
.iter()
|
||||
.zip(arg.shape().dims().iter())
|
||||
.enumerate()
|
||||
{
|
||||
if dim_idx == dim {
|
||||
cat_dims[dim] += v2;
|
||||
}
|
||||
if dim_idx != dim && v1 != v2 {
|
||||
Err(Error::ShapeMismatchCat {
|
||||
dim: dim_idx,
|
||||
first_shape: arg0.shape().clone(),
|
||||
n: arg_idx + 1,
|
||||
nth_shape: arg.shape().clone(),
|
||||
}
|
||||
.bt())?
|
||||
}
|
||||
}
|
||||
}
|
||||
let cat_target_dim_len = cat_dims[dim];
|
||||
let block_size: usize = cat_dims.iter().skip(1 + dim).product();
|
||||
let shape = Shape::from(cat_dims);
|
||||
let op = crate::op::BackpropOp::new(args, |args| crate::op::Op::Cat(args, dim));
|
||||
let mut storage = device.zeros(&shape, dtype)?;
|
||||
let mut dst_o = 0;
|
||||
for arg in args.iter() {
|
||||
let arg = arg.as_ref();
|
||||
let arg_dims = arg.shape().dims();
|
||||
let d1: usize = arg_dims.iter().take(dim).product();
|
||||
let d2 = block_size * arg_dims[dim];
|
||||
let dst_s = block_size * cat_target_dim_len;
|
||||
let src_o = arg.layout().start_offset();
|
||||
arg.storage().copy2d(
|
||||
&mut storage,
|
||||
d1,
|
||||
d2,
|
||||
/* src_s */ d2,
|
||||
dst_s,
|
||||
src_o,
|
||||
dst_o,
|
||||
)?;
|
||||
dst_o += d2;
|
||||
}
|
||||
Ok(crate::tensor::from_storage(storage, shape, op, false))
|
||||
}
|
||||
}
|
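To make the `copy2d` bookkeeping in `cat_contiguous` above concrete, here is a small sketch with hypothetical shapes (not part of the diff): concatenating `a: (2, 3, 4)` and `b: (2, 5, 4)` along `dim = 1`.

```rust
fn main() {
    let (a_dims, b_dims, dim) = ([2usize, 3, 4], [2usize, 5, 4], 1usize);
    let cat_target_dim_len = a_dims[dim] + b_dims[dim]; // 8
    let block_size: usize = a_dims[dim + 1..].iter().product(); // 4 elements per innermost block
    let dst_s = block_size * cat_target_dim_len; // 32: destination stride between "rows"
    for (dims, dst_o) in [(a_dims, 0usize), (b_dims, block_size * a_dims[dim])] {
        let d1: usize = dims[..dim].iter().product(); // 2 rows to copy
        let d2 = block_size * dims[dim]; // 12 for a, 20 for b: contiguous elements per row
        // Matches the call above: copy2d(&mut dst, d1, d2, /* src_s */ d2, dst_s, src_o, dst_o)
        println!("d1={d1} d2={d2} src_s={d2} dst_s={dst_s} dst_o={dst_o}");
    }
}
```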
@ -18,9 +18,6 @@ w_t = w.transpose(0, 1)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
res = torch.nn.functional.conv_transpose1d(t, w_t, groups=2)
|
||||
print(res.shape)
|
||||
print(res)
|
||||
*/
|
||||
fn conv1d(dev: &Device) -> Result<()> {
|
||||
let t = Tensor::new(
|
||||
@ -53,36 +50,15 @@ fn conv1d(dev: &Device) -> Result<()> {
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[2.4509, 2.6357, -1.3336, 4.1393, 0.5657, 1.8091, -1.1784, 3.5675, 0.5069, 3.3352]
|
||||
);
|
||||
|
||||
// conv-transposes are not implemented for metal.
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let w = w.transpose(0, 1)?;
|
||||
// The CPU kernels applied in the contiguous and non-contiguous cases are different.
|
||||
for w in [w.clone(), w.contiguous()?] {
|
||||
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
let res = t.conv_transpose1d(&w, 0, 0, 1, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 4, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec2_round(&res.squeeze(0)?, 4)?,
|
||||
[
|
||||
[-1.5596, -1.8099, 2.0407, 4.8764, -0.1743, -0.735, -0.7819],
|
||||
[0.7816, 3.8152, -0.5926, 2.2515, -5.1844, -0.3157, 1.4721],
|
||||
[1.6295, 0.52, 6.2611, 0.7109, 2.6315, -1.8793, 0.7113],
|
||||
[1.0949, 1.0166, 1.7464, 2.4561, -0.79, -0.5119, 0.1488]
|
||||
]
|
||||
);
|
||||
}
|
||||
let res = t.conv_transpose1d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&res.flatten_all()?, 4)?,
|
||||
[
|
||||
0.0699, -1.2899, 8.3018, 5.5873, 2.4572, -2.6143, -0.0706, 1.8765, 4.8318, 1.1538,
|
||||
4.7076, -5.9745, -0.8276, 1.621
|
||||
],
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -168,33 +144,31 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
10.389, 3.6023, -4.2808, 0.2672, 5.3646, -5.2023, -2.1955, -9.4075
|
||||
]
|
||||
);
|
||||
if !dev.is_metal() {
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 7, 7]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
[
|
||||
[
|
||||
[
|
||||
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
|
||||
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
|
||||
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
|
||||
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
|
||||
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
|
||||
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
|
||||
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
|
||||
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
|
||||
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
|
||||
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
|
||||
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
|
||||
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
|
||||
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
|
||||
]
|
||||
[-1.9918, 2.6797, -0.4599, -1.6037, 1.4131, -2.4012, 2.9277],
|
||||
[1.8016, -3.5361, 1.0757, 3.5395, -8.2168, -3.2023, 0.5375],
|
||||
[0.8243, 1.8675, 7.8929, -4.0746, -6.4415, 5.1139, 1.6889],
|
||||
[0.2722, 8.9679, 3.3477, 1.8514, -4.2896, -3.8228, -7.5632],
|
||||
[-8.5412, -5.8142, -7.1587, -1.6095, 0.4651, 0.2748, -2.0985],
|
||||
[2.0833, -0.6482, -12.1692, -4.1284, -2.9765, -0.0656, -4.5114],
|
||||
[5.307, 2.6957, 2.3087, 1.0478, 0.7808, -1.1519, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, 0.1872, -0.6408, -0.9897, 0.8503, 1.1019, -0.9211],
|
||||
[-0.1741, -0.2915, 4.2472, 1.9417, 1.65, 0.6303, -4.7131],
|
||||
[1.6555, 2.4026, -2.9293, 2.9953, 0.5328, 3.5873, -0.9621],
|
||||
[-1.4289, -3.2787, 4.1747, -6.0341, -4.6341, -5.7945, 4.142],
|
||||
[7.5973, 6.4431, 5.9872, 2.1639, -8.6566, 3.3143, -3.4059],
|
||||
[-0.8775, -3.048, 11.6543, 0.6442, 2.3218, -0.4765, 1.1516],
|
||||
[-5.5423, -2.5188, 1.0754, -0.0563, -2.9386, -1.1504, 1.0171]
|
||||
]
|
||||
);
|
||||
}
|
||||
]
|
||||
);
|
||||
// Dilations.
|
||||
let res = t.conv2d(&w, 0, 1, 2, 1)?;
|
||||
assert_eq!(res.dims(), [1, 2, 1, 1]);
|
||||
@ -203,44 +177,36 @@ fn conv2d(dev: &Device) -> Result<()> {
|
||||
[2.45, -2.3504],
|
||||
);
|
||||
|
||||
if !dev.is_metal() {
|
||||
// Transpose and dilations.
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 2, 9, 9]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
// Transpose and dilations.
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 2)?;
|
||||
assert_eq!(res.dims(), [1, 2, 9, 9]);
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(&res.i(0)?, 4)?,
|
||||
[
|
||||
[
|
||||
[
|
||||
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
|
||||
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
|
||||
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
|
||||
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
|
||||
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
|
||||
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
|
||||
[
|
||||
-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51,
|
||||
-3.5024
|
||||
],
|
||||
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
|
||||
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
|
||||
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
|
||||
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
|
||||
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
|
||||
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
|
||||
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
|
||||
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
|
||||
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
|
||||
[
|
||||
-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827,
|
||||
1.0171
|
||||
]
|
||||
]
|
||||
[-1.9918, 3.1652, -0.6778, -4.3442, 4.4351, 0.6652, -3.0124, -0.6031, 2.9277],
|
||||
[2.7036, -1.7156, -0.3969, 1.0516, 1.6381, -2.8886, -0.205, 2.4682, -1.0499],
|
||||
[-0.9459, 3.1631, 3.707, -4.8369, -8.5166, -1.4496, -2.7559, -3.2698, 1.4376],
|
||||
[-0.2157, 3.7786, -2.0252, -4.2633, 3.6731, -1.5142, 5.9391, -0.2622, -0.141],
|
||||
[-6.8121, -3.1744, 1.5945, 3.0637, -9.6088, 1.4446, 2.9489, -3.0082, -7.3822],
|
||||
[0.2371, 3.3303, 0.3861, 2.2646, -4.6784, 4.1235, -0.0109, 0.3176, -0.03],
|
||||
[-2.5339, -2.9564, -3.4518, -4.4594, -9.1873, -1.9709, -0.4676, 0.51, -3.5024],
|
||||
[4.007, 0.3067, -2.2954, 1.1105, -0.1992, 1.6372, -2.9268, 0.2807, -1.2787],
|
||||
[5.307, 1.1317, 1.3518, 0.9049, 3.8116, -0.4075, -0.8874, -0.2241, -0.9579]
|
||||
],
|
||||
[
|
||||
[1.089, -0.6483, 0.0726, -0.4752, -1.3283, 1.7103, 1.0703, 0.1076, -0.9211],
|
||||
[-0.8629, 0.1376, 0.3202, 2.0955, 0.9696, 2.8988, -1.0012, 1.5049, -0.1278],
|
||||
[1.9286, -1.5255, -2.9563, 2.4589, 3.3611, -0.6951, 0.3525, -1.7724, -5.9861],
|
||||
[1.1226, 2.1561, 3.6417, 4.7546, -0.692, 4.4126, -5.1902, 6.0805, 2.3185],
|
||||
[1.0111, 0.3604, 0.6432, -3.6605, 7.9517, -9.2955, -5.2988, -3.7803, -2.0642],
|
||||
[3.3172, -1.7967, -3.6576, -2.0942, 1.3158, 0.112, -1.7405, 2.9167, 0.7957],
|
||||
[5.1001, 1.8995, -1.8639, 1.1262, 9.9629, 2.683, -3.6319, -1.1607, 0.5856],
|
||||
[-4.8445, -0.5642, 4.2317, 0.0856, 1.2267, -0.5712, 1.736, 1.0997, 0.6908],
|
||||
[-5.5423, -1.1831, -1.2176, 0.0843, 0.0446, -0.7545, -2.4798, -0.0827, 1.0171]
|
||||
]
|
||||
);
|
||||
}
|
||||
]
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -294,12 +260,6 @@ fn conv2d_small(dev: &Device) -> Result<()> {
|
||||
0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000
|
||||
]
|
||||
);
|
||||
|
||||
// conv-transposes are not implemented for metal
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let res = t.conv_transpose2d(&w.transpose(0, 1)?, 0, 0, 1, 1)?;
|
||||
assert_eq!(res.dims(), [1, 1, 3, 3]);
|
||||
assert_eq!(
|
||||
@ -401,10 +361,6 @@ print(w.grad.shape)
|
||||
print(w.grad[0])
|
||||
*/
|
||||
fn conv2d_grad(dev: &Device) -> Result<()> {
|
||||
// conv-transposes are not implemented for metal
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
use candle_core::Var;
|
||||
let t = Var::from_slice(
|
||||
&[
|
||||
|
@ -1,4 +1,3 @@
|
||||
#![allow(clippy::approx_constant)]
|
||||
use anyhow::{Context, Result};
|
||||
use candle_core::{test_device, test_utils, Device, Shape, Tensor, Var};
|
||||
|
||||
@ -97,24 +96,24 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[20.0855, 2.7183, 54.5982, 1.1618]
|
||||
y.to_vec1::<f32>()?,
|
||||
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
||||
);
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 4)?,
|
||||
[20.0855, 2.7183, 54.5982, 1.1618]
|
||||
grad_x.to_vec1::<f32>()?,
|
||||
[20.085537, 2.7182817, 54.59815, 1.1618342]
|
||||
);
|
||||
let y = x.exp()?.sqr()?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(x).context("no grad for x")?;
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 3)?,
|
||||
[403.429, 7.389, 2980.958, 1.35]
|
||||
y.to_vec1::<f32>()?,
|
||||
[403.4288, 7.3890557, 2980.9578, 1.3498588]
|
||||
);
|
||||
// exp(x)^2 = exp(2*x)
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(grad_x, 2)?,
|
||||
[806.86, 14.78, 5961.92, 2.7]
|
||||
grad_x.to_vec1::<f32>()?,
|
||||
[806.8576, 14.778111, 5961.9155, 2.6997175]
|
||||
);
|
||||
let y = x.sin()?;
|
||||
let grads = y.backward()?;
|
||||
@ -262,7 +261,6 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let y = elu_x.elu(2.)?;
|
||||
let grads = y.backward()?;
|
||||
let grad_x = grads.get(&elu_x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec1_round(&y, 4)?,
|
||||
[-1.2642, 0.0000, -1.7293, 3.0000]
|
||||
@ -285,38 +283,19 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
[1.0881, 0.9277, 1.0527, 0.5747],
|
||||
);
|
||||
|
||||
if device.is_cpu() {
|
||||
let x = Var::new(&[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]], device)?;
|
||||
let y = x.interpolate1d(12)?.reshape(36)?;
|
||||
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16.,
|
||||
17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32.,
|
||||
33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
||||
let loss = y.unsqueeze(1)?.transpose(0, 1)?.matmul(&z.unsqueeze(1)?)?;
|
||||
let grads = loss.backward()?;
|
||||
let grad_x = grads.get(&x).context("no grad for x")?;
|
||||
|
||||
assert_eq!(
|
||||
test_utils::to_vec3_round(grad_x, 4)?,
|
||||
[[[10_f32, 26., 42.], [58., 74., 90.], [106., 122., 138.]]]
|
||||
);
|
||||
}
|
||||
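The expected gradient in the `interpolate1d` check above can be verified by hand: nearest-neighbor upsampling from length 3 to length 12 repeats each input element 4 times, so with the loss `y · z` the gradient of each input element is the sum of the 4 matching entries of `z`. A tiny standalone check (assuming the same `z` values as the test):

```rust
fn main() {
    let z: Vec<f32> = (1..=36).map(|v| v as f32).collect();
    // The first channel of x maps to z[0..12]; each element covers 4 upsampled positions.
    let grad_row: Vec<f32> = z[..12].chunks(4).map(|c| c.iter().sum::<f32>()).collect();
    assert_eq!(grad_row, vec![10.0, 26.0, 42.0]); // matches the first row of grad_x in the test
}
```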
|
||||
// manually checked: see comments
|
||||
let x = Var::new(&[[[[1f32, 2., 3.], [4., 5., 6.], [7., 8., 9.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||
35., 36.,
|
||||
1_f32, 02., 03., 04., 05., 06.,
|
||||
07., 08., 09., 10., 11., 12.,
|
||||
13., 14., 15., 16., 17., 18.,
|
||||
19., 20., 21., 22., 23., 24.,
|
||||
25., 26., 27., 28., 29., 30.,
|
||||
31., 32., 33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
@ -347,11 +326,15 @@ fn unary_grad(device: &Device) -> Result<()> {
|
||||
let x = Var::new(&[[[[1f32, 2.], [4., 5.]]]], device)?;
|
||||
let y = x.interpolate2d(6, 6)?.reshape(36)?;
|
||||
|
||||
#[rustfmt::skip]
|
||||
let z = Tensor::new(
|
||||
&[
|
||||
1_f32, 02., 03., 04., 05., 06., 07., 08., 09., 10., 11., 12., 13., 14., 15., 16., 17.,
|
||||
18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 32., 33., 34.,
|
||||
35., 36.,
|
||||
1_f32, 02., 03., 04., 05., 06.,
|
||||
07., 08., 09., 10., 11., 12.,
|
||||
13., 14., 15., 16., 17., 18.,
|
||||
19., 20., 21., 22., 23., 24.,
|
||||
25., 26., 27., 28., 29., 30.,
|
||||
31., 32., 33., 34., 35., 36.,
|
||||
],
|
||||
device,
|
||||
)?;
|
||||
|
@ -2,9 +2,6 @@ use candle_core::{test_device, test_utils, Device, IndexOp, Result, Tensor};
|
||||
|
||||
// https://github.com/huggingface/candle/issues/364
|
||||
fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
let data: Vec<f32> = vec![
|
||||
1., 1., 1., 1., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
|
||||
];
|
||||
@ -22,9 +19,6 @@ fn avg_pool2d(dev: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn max_pool2d(dev: &Device) -> Result<()> {
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
let data: Vec<f32> = vec![
|
||||
1., 2., 1., 3., 0., 0., 1., 1., 1., 1., 1., 1., 5., 1., 1., 1.,
|
||||
];
|
||||
@ -49,9 +43,6 @@ res = torch.nn.functional.avg_pool2d(t, 2)
|
||||
print(res)
|
||||
*/
|
||||
fn avg_pool2d_pytorch(dev: &Device) -> Result<()> {
|
||||
if dev.is_metal() {
|
||||
return Ok(());
|
||||
}
|
||||
let t = Tensor::new(
|
||||
&[
|
||||
0.4056f32, -0.8689, -0.0773, -1.5630, -2.8012, -1.5059, 0.3972, 1.0852, 0.4997, 3.0616,
|
||||
|
@ -178,6 +178,10 @@ test_device!(
|
||||
);
|
||||
|
||||
fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
@ -205,6 +209,10 @@ fn quantize_q4_0(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q4_1)?;
|
||||
@ -231,6 +239,10 @@ fn quantize_q4_1(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_0)?;
|
||||
@ -257,6 +269,10 @@ fn quantize_q5_0(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q5_1(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let src = (0..32 * 4).map(|v| v as f32).collect::<Vec<_>>();
|
||||
let src = Tensor::from_slice(&src, (32 * 4,), device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, GgmlDType::Q5_1)?;
|
||||
@ -357,6 +373,10 @@ fn ggml_quantization_error_test(dtype: GgmlDType, device: &Device, max_error: f3
|
||||
}
|
||||
|
||||
fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q2K;
|
||||
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
@ -391,6 +411,10 @@ fn quantize_q2k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q3K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -424,6 +448,10 @@ fn quantize_q3k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q4K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -457,6 +485,10 @@ fn quantize_q4k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q5K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -490,6 +522,10 @@ fn quantize_q5k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q6K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -523,6 +559,10 @@ fn quantize_q6k(device: &Device) -> Result<()> {
|
||||
}
|
||||
|
||||
fn quantize_q8k(device: &Device) -> Result<()> {
|
||||
// TODO Enable this later when we enable cuda.
|
||||
if device.is_cuda() {
|
||||
return Ok(());
|
||||
}
|
||||
let dtype = GgmlDType::Q8K;
|
||||
let src = get_test_vector2(0.5, 1024, device)?;
|
||||
let quant = quantized::QTensor::quantize(&src, dtype)?;
|
||||
@ -738,6 +778,10 @@ macro_rules! quantized_matmul {
|
||||
// stable. https://github.com/rust-lang/rust/issues/29599
|
||||
($fn_name: ident, $fn_name_cpu: ident, $fn_name_cuda: ident, $fn_name_metal: ident, $dtype: expr) => {
|
||||
fn $fn_name(device: &Device) -> Result<()> {
|
||||
if device.is_cuda() {
|
||||
// TODO Enable Cuda GGML sometime maybe.
|
||||
return Ok(());
|
||||
}
|
||||
test_matmul(device, (1, 3, 4, 256), $dtype)?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -672,31 +672,6 @@ fn cat(device: &Device) -> Result<()> {
|
||||
[2.0, 7.0, 1.0, 8.0, 2.0, 2.0, 7.0, 1.0, 8.0, 2.0]
|
||||
]
|
||||
);
|
||||
|
||||
// 3D
|
||||
let t1 = Tensor::arange(0, 48i64, device)?.reshape((2, 6, 4))?;
|
||||
let t2 = Tensor::arange(100, 124i64, device)?.reshape((2, 3, 4))?;
|
||||
let t3 = Tensor::arange(10000, 10032i64, device)?.reshape((2, 4, 4))?;
|
||||
|
||||
let t_cat = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||
|
||||
let t1 = t1.t()?.contiguous()?.t()?;
|
||||
let t2 = t2.t()?.contiguous()?.t()?;
|
||||
let t3 = t3.t()?.contiguous()?.t()?;
|
||||
let t_cat2 = Tensor::cat(&[&t1, &t2, &t3], 1)?;
|
||||
|
||||
let diff = t_cat.eq(&t_cat2)?.to_dtype(DType::F32)?.sum_all()?;
|
||||
assert_eq!(diff.to_vec0::<f32>()?, 104.0);
|
||||
assert_eq!(t_cat.i((0, 0, 0))?.to_vec0::<i64>()?, 0);
|
||||
assert_eq!(t_cat.i((0, 4, 0))?.to_vec0::<i64>()?, 16);
|
||||
assert_eq!(t_cat.i((0, 5, 0))?.to_vec0::<i64>()?, 20);
|
||||
assert_eq!(t_cat.i((1, 5, 0))?.to_vec0::<i64>()?, 44);
|
||||
assert_eq!(t_cat.i((0, 6, 0))?.to_vec0::<i64>()?, 100);
|
||||
assert_eq!(t_cat.i((1, 6, 0))?.to_vec0::<i64>()?, 112);
|
||||
assert_eq!(t_cat.i((0, 6, 1))?.to_vec0::<i64>()?, 101);
|
||||
assert_eq!(t_cat.i((0, 7, 1))?.to_vec0::<i64>()?, 105);
|
||||
assert_eq!(t_cat.i((0, 12, 1))?.to_vec0::<i64>()?, 10013);
|
||||
assert_eq!(t_cat.i((1, 12, 3))?.to_vec0::<i64>()?, 10031);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1105,33 +1080,8 @@ fn broadcasting(device: &Device) -> Result<()> {
|
||||
fn randn(device: &Device) -> Result<()> {
|
||||
let tensor = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||
assert_eq!(tensor.dims(), [5, 3]);
|
||||
// Check that the seed gets updated by checking that
|
||||
// a new series of numbers is generated each time
|
||||
let tensor2 = Tensor::randn(0f32, 1f32, (5, 3), device)?;
|
||||
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
||||
let tensor = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||
assert_eq!(tensor.dims(), [5, 3]);
|
||||
// Check that the seed gets updated by checking that
|
||||
// a new series of numbers is generated each time
|
||||
let tensor2 = Tensor::rand(0f32, 1f32, (5, 3), device)?;
|
||||
assert_ne!(tensor.to_vec2::<f32>()?, tensor2.to_vec2::<f32>()?);
|
||||
// We do not expect deterministic elements at any index.
|
||||
// There once was a bug that had a deterministic zero element in evenly sized tensors.
|
||||
const N: usize = 2;
|
||||
let v = (0..100)
|
||||
.map(|_| Tensor::randn(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
assert!(
|
||||
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
||||
"There are deterministic values in the randn tensors"
|
||||
);
|
||||
let v = (0..100)
|
||||
.map(|_| Tensor::rand(0f32, 1f32, N, device).and_then(|t| t.to_vec1::<f32>()))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
assert!(
|
||||
(0..N).all(|i| v.windows(2).any(|pair| pair[0][i] != pair[1][i])),
|
||||
"There are deterministic values in the rand tensors"
|
||||
);
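// The windows(2).any(..) condition only requires that, at every index, at least one pair of
// consecutive draws differs; that is enough to catch the historical bug mentioned above where
// one element of evenly sized tensors came out deterministically zero.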
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
@ -12,7 +12,7 @@ readme = "README.md"
|
||||
[dependencies]
|
||||
accelerate-src = { workspace = true, optional = true }
|
||||
candle = { workspace = true }
|
||||
candle-datasets = { workspace = true, optional = true }
|
||||
candle-datasets = { workspace = true }
|
||||
candle-nn = { workspace = true }
|
||||
candle-transformers = { workspace = true }
|
||||
candle-flash-attn = { workspace = true, optional = true }
|
||||
@ -30,7 +30,7 @@ rayon = { workspace = true }
|
||||
safetensors = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
symphonia = { version = "0.5.3", features = ["all"], optional = true }
|
||||
symphonia = { version = "0.5.3", features = ["all"] }
|
||||
tokenizers = { workspace = true, features = ["onig"] }
|
||||
cpal= { version = "0.15.2", optional = true }
|
||||
|
||||
@ -80,24 +80,6 @@ required-features = ["onnx"]
|
||||
name = "onnx_basics"
|
||||
required-features = ["onnx"]
|
||||
|
||||
[[example]]
|
||||
name = "whisper"
|
||||
required-features = ["symphonia"]
|
||||
|
||||
[[example]]
|
||||
name = "whisper-microphone"
|
||||
required-features = ["microphone"]
|
||||
|
||||
[[example]]
|
||||
name = "mnist-training"
|
||||
required-features = ["candle-datasets"]
|
||||
|
||||
[[example]]
|
||||
name = "llama2-c"
|
||||
required-features = ["candle-datasets"]
|
||||
|
||||
[[example]]
|
||||
name = "encodec"
|
||||
required-features = ["symphonia"]
|
||||
|
||||
|
||||
|
@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -93,7 +93,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -1 +0,0 @@
|
||||
pub const LAYERNORM_KERNELS: &str = include_str!(concat!(env!("OUT_DIR"), "/layernorm_kernels.ptx"));
|
||||
|
@ -31,7 +31,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -47,7 +47,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -1,20 +0,0 @@
# candle-efficientvit

[EfficientViT: Memory Efficient Vision Transformer with Cascaded Group Attention](https://arxiv.org/abs/2305.07027).

This candle implementation uses a pre-trained EfficientViT (from Microsoft Research Asia) network for inference.
The classification head has been trained on the ImageNet dataset and returns the probabilities for the top-5 classes.

## Running an example

```
$ cargo run --example efficientvit --release -- --image candle-examples/examples/yolo-v8/assets/bike.jpg --which m1

loaded image Tensor[dims 3, 224, 224; f32]
model built
mountain bike, all-terrain bike, off-roader: 69.80%
unicycle, monocycle : 13.03%
bicycle-built-for-two, tandem bicycle, tandem: 9.28%
crash helmet : 2.25%
alp : 0.46%
```
@ -1,99 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle::{DType, IndexOp, D};
|
||||
use candle_nn::{Module, VarBuilder};
|
||||
use candle_transformers::models::efficientvit;
|
||||
|
||||
#[derive(Clone, Copy, Debug, ValueEnum)]
|
||||
enum Which {
|
||||
M0,
|
||||
M1,
|
||||
M2,
|
||||
M3,
|
||||
M4,
|
||||
M5,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn model_filename(&self) -> String {
|
||||
let name = match self {
|
||||
Self::M0 => "m0",
|
||||
Self::M1 => "m1",
|
||||
Self::M2 => "m2",
|
||||
Self::M3 => "m3",
|
||||
Self::M4 => "m4",
|
||||
Self::M5 => "m5",
|
||||
};
|
||||
format!("timm/efficientvit_{}.r224_in1k", name)
|
||||
}
|
||||
|
||||
fn config(&self) -> efficientvit::Config {
|
||||
match self {
|
||||
Self::M0 => efficientvit::Config::m0(),
|
||||
Self::M1 => efficientvit::Config::m1(),
|
||||
Self::M2 => efficientvit::Config::m2(),
|
||||
Self::M3 => efficientvit::Config::m3(),
|
||||
Self::M4 => efficientvit::Config::m4(),
|
||||
Self::M5 => efficientvit::Config::m5(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser)]
|
||||
struct Args {
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
image: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
#[arg(value_enum, long, default_value_t=Which::M0)]
|
||||
which: Which,
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
None => {
|
||||
let model_name = args.which.model_filename();
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name);
|
||||
api.get("model.safetensors")?
|
||||
}
|
||||
Some(model) => model.into(),
|
||||
};
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
|
||||
let model = efficientvit::efficientvit(&args.which.config(), 1000, vb)?;
|
||||
println!("model built");
|
||||
let logits = model.forward(&image.unsqueeze(0)?)?;
|
||||
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?
|
||||
.i(0)?
|
||||
.to_vec1::<f32>()?;
|
||||
let mut prs = prs.iter().enumerate().collect::<Vec<_>>();
|
||||
prs.sort_by(|(_, p1), (_, p2)| p2.total_cmp(p1));
|
||||
for &(category_idx, pr) in prs.iter().take(5) {
|
||||
println!(
|
||||
"{:24}: {:.2}%",
|
||||
candle_examples::imagenet::CLASSES[category_idx],
|
||||
100. * pr
|
||||
);
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -1,20 +0,0 @@
# candle-encodec

[EnCodec](https://huggingface.co/facebook/encodec_24khz) is a high-quality audio
compression model using an encoder/decoder architecture with residual vector
quantization.

## Running one example

```bash
cargo run --example encodec --features symphonia --release -- code-to-audio \
    candle-examples/examples/encodec/jfk-codes.safetensors \
    jfk.wav
```

This decodes the EnCodec tokens stored in `jfk-codes.safetensors` and generates
an output wav file containing the audio data. Instead of `code-to-audio` one
can use:
- `audio-to-audio in.mp3 out.wav`: encodes the input audio file then decodes it to a wav file.
- `audio-to-code in.mp3 out.safetensors`: generates a safetensors file
  containing EnCodec tokens for the input audio file.
Binary file not shown.
@ -1,142 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::encodec::{Config, Model};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
use symphonia::core::audio::Signal;
|
||||
use symphonia::core::conv::FromSample;
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
|
||||
let src = std::fs::File::open(path)?;
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
let mut format = probed.format;
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &Default::default())
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
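// pcm_decode returns mono samples (channel 0 only) from the first supported audio track,
// together with the stream's sample rate; the caller in this example only warns when that rate
// does not match EnCodec's expected 24 kHz, it does not resample.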
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Action {
|
||||
AudioToAudio,
|
||||
AudioToCode,
|
||||
CodeToAudio,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The action to be performed, specifies the format for the input and output data.
|
||||
action: Action,
|
||||
|
||||
/// The input file, either an audio file or some encodec tokens stored as safetensors.
|
||||
in_file: String,
|
||||
|
||||
/// The output file, either a wave audio file or some encodec tokens stored as safetensors.
|
||||
out_file: String,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.model("facebook/encodec_24khz".to_string())
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let config = Config::default();
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
let codes = match args.action {
|
||||
Action::CodeToAudio => {
|
||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||
codes.get("codes").expect("no codes in input file").clone()
|
||||
}
|
||||
Action::AudioToCode | Action::AudioToAudio => {
|
||||
let (pcm, sample_rate) = pcm_decode(args.in_file)?;
|
||||
if sample_rate != 24_000 {
|
||||
println!("WARNING: encodec uses a 24khz sample rate, input uses {sample_rate}")
|
||||
}
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
}
|
||||
};
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
|
||||
match args.action {
|
||||
Action::AudioToCode => {
|
||||
codes.save_safetensors("codes", &args.out_file)?;
|
||||
}
|
||||
Action::AudioToAudio | Action::CodeToAudio => {
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -1,27 +0,0 @@
# candle-gemma: 2b and 7b LLMs from Google DeepMind

[Gemma](https://ai.google.dev/gemma/docs) is a collection of lightweight open
models published by Google Deepmind with a 2b and a 7b variant.

In order to use the example below, you have to accept the license on the
[HuggingFace Hub Gemma repo](https://huggingface.co/google/gemma-7b) and set up
your access token via the [HuggingFace cli login
command](https://huggingface.co/docs/huggingface_hub/guides/cli#huggingface-cli-login).

## Running the example

```bash
$ cargo run --example gemma --release -- --prompt "fn count_primes(max_n: usize)"
fn count_primes(max_n: usize) -> usize {
    let mut primes = vec![true; max_n];
    for i in 2..=max_n {
        if primes[i] {
            for j in i * i..max_n {
                primes[j] = false;
            }
        }
    }
    primes.len()
}
```
@ -1,256 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::gemma::{Config, Model};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<eos>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <eos> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
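// On the first iteration the full prompt is fed to the model; afterwards only the most recent
// token is passed and start_pos offsets into the sequence, relying on the model keeping its own
// key/value cache for the earlier positions.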
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match &args.model_id {
|
||||
Some(model_id) => match model_id.as_str() {
|
||||
"7b-it" => "google/gemma-7b-it".to_string(),
|
||||
"7b" => "google/gemma-7b".to_string(),
|
||||
"2b-it" => "google/gemma-2b-it".to_string(),
|
||||
"2b" => "google/gemma-2b".to_string(),
|
||||
_ => model_id.to_string(),
|
||||
},
|
||||
None => "google/gemma-2b".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let config_filename = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let config: Config = serde_json::from_reader(std::fs::File::open(config_filename)?)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -120,7 +120,7 @@ fn main() -> Result<()> {
|
||||
Some(dtype) => bail!("Unsupported dtype {dtype}"),
|
||||
None => DType::F16,
|
||||
};
|
||||
let (llama, tokenizer_filename, mut cache) = {
|
||||
let (llama, tokenizer_filename, cache) = {
|
||||
let api = Api::new()?;
|
||||
let model_id = args.model_id.unwrap_or_else(|| match args.which {
|
||||
Which::V1 => "Narsil/amall-7b".to_string(),
|
||||
@ -146,7 +146,7 @@ fn main() -> Result<()> {
|
||||
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
|
||||
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
(Llama::load(vb, &config)?, tokenizer_filename, cache)
|
||||
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
|
||||
};
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
let eos_token_id = tokenizer.token_to_id(EOS_TOKEN);
|
||||
@ -172,7 +172,7 @@ fn main() -> Result<()> {
|
||||
};
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = llama.forward(&input, context_index, &mut cache)?;
|
||||
let logits = llama.forward(&input, context_index)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
let logits = if args.repeat_penalty == 1. {
|
||||
logits
|
||||
|
@ -19,7 +19,7 @@ use candle_transformers::generation::LogitsProcessor;
|
||||
use std::io::Write;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
use model::{Cache, Config, Llama};
|
||||
use model::{Config, Llama};
|
||||
use qmodel::QLlama;
|
||||
use weights::TransformerWeights;
|
||||
|
||||
@ -160,10 +160,10 @@ enum Model {
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, xs: &Tensor, pos: usize, cache: &mut Cache) -> anyhow::Result<Tensor> {
|
||||
fn forward(&self, xs: &Tensor, pos: usize) -> anyhow::Result<Tensor> {
|
||||
match self {
|
||||
Self::Llama(l) => Ok(l.forward(xs, pos, cache)?),
|
||||
Self::QLlama(l) => Ok(l.forward(xs, pos, cache)?),
|
||||
Self::Llama(l) => Ok(l.forward(xs, pos)?),
|
||||
Self::QLlama(l) => Ok(l.forward(xs, pos)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -188,8 +188,8 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, config)?;
|
||||
let cache = model::Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
|
||||
let tokens = match &args.pretokenized_dir {
|
||||
None => {
|
||||
@ -235,7 +235,7 @@ fn run_eval(args: &EvaluationCmd, common_args: &Args) -> Result<()> {
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
for inp_tgt in batch_iter {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0, &mut cache)?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
println!("{}", loss.to_vec0::<f32>()?);
|
||||
}
|
||||
@ -261,7 +261,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let is_safetensors = config_path
|
||||
.extension()
|
||||
.map_or(false, |v| v == "safetensors");
|
||||
let (model, config, mut cache) = if is_gguf {
|
||||
let (model, config) = if is_gguf {
|
||||
let vb = qmodel::VarBuilder::from_gguf(config_path, &device)?;
|
||||
let (_vocab_size, dim) = vb
|
||||
.get_no_shape("model.embed_tokens.weight")?
|
||||
@ -298,15 +298,15 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
&device,
|
||||
);
|
||||
let cache = model::Cache::new(true, &config, fake_vb)?;
|
||||
let model = Model::QLlama(QLlama::load(vb, config.clone())?);
|
||||
(model, config, cache)
|
||||
let model = Model::QLlama(QLlama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
} else if is_safetensors {
|
||||
let config = Config::tiny_15m();
|
||||
let tensors = candle::safetensors::load(config_path, &device)?;
|
||||
let vb = candle_nn::VarBuilder::from_tensors(tensors, candle::DType::F32, &device);
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
||||
(model, config, cache)
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
} else {
|
||||
let mut file = std::fs::File::open(config_path)?;
|
||||
let config = Config::from_reader(&mut file)?;
|
||||
@ -314,8 +314,8 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let weights = TransformerWeights::from_reader(&mut file, &config, &device)?;
|
||||
let vb = weights.var_builder(&config, &device)?;
|
||||
let cache = model::Cache::new(true, &config, vb.pp("rot"))?;
|
||||
let model = Model::Llama(Llama::load(vb, config.clone())?);
|
||||
(model, config, cache)
|
||||
let model = Model::Llama(Llama::load(vb, &cache, config.clone())?);
|
||||
(model, config)
|
||||
};
|
||||
|
||||
println!("starting the inference loop");
|
||||
@ -338,7 +338,7 @@ fn run_inference(args: &InferenceCmd, common_args: &Args) -> Result<()> {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let ctxt = &tokens[tokens.len().saturating_sub(context_size)..];
|
||||
let input = Tensor::new(ctxt, &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, index_pos, &mut cache)?;
|
||||
let logits = model.forward(&input, index_pos)?;
|
||||
let logits = logits.i((0, logits.dim(1)? - 1))?;
|
||||
let logits = if common_args.repeat_penalty == 1. || tokens.is_empty() {
|
||||
logits
|
||||
|
@ -8,7 +8,6 @@ fn valid_loss(
|
||||
model: &Llama,
|
||||
args: &crate::TrainingCmd,
|
||||
device: &Device,
|
||||
cache: &mut Cache,
|
||||
) -> Result<f64> {
|
||||
let iter = DatasetRandomIter::new(dataset, true, model.config.seq_len, device.clone());
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
@ -16,7 +15,7 @@ fn valid_loss(
|
||||
let mut cnt = 0usize;
|
||||
for inp_tgt in batch_iter.take(50) {
|
||||
let (inp, tgt) = inp_tgt?;
|
||||
let logits = model.forward(&inp, 0, cache)?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
sum_ce += loss.to_vec0::<f32>()? as f64;
|
||||
cnt += 1;
|
||||
@ -38,8 +37,8 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
let iter = DatasetRandomIter::new(&dataset, false, config.seq_len, device.clone());
|
||||
let batch_iter = candle_datasets::Batcher::new_r2(iter).batch_size(args.batch_size);
|
||||
|
||||
let mut cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, config)?;
|
||||
let cache = Cache::new(false, &config, vb.pp("rot"))?;
|
||||
let model = Llama::load(vb, &cache, config)?;
|
||||
let params = candle_nn::ParamsAdamW {
|
||||
lr: args.learning_rate,
|
||||
..Default::default()
|
||||
@ -47,14 +46,14 @@ pub fn run(args: &crate::TrainingCmd, common_args: &crate::Args) -> Result<()> {
|
||||
let mut opt = candle_nn::AdamW::new(varmap.all_vars(), params)?;
|
||||
for (batch_index, batch) in batch_iter.enumerate() {
|
||||
let (inp, tgt) = batch?;
|
||||
let logits = model.forward(&inp, 0, &mut cache)?;
|
||||
let logits = model.forward(&inp, 0)?;
|
||||
let loss = candle_nn::loss::cross_entropy(&logits.flatten_to(1)?, &tgt.flatten_to(1)?)?;
|
||||
opt.backward_step(&loss)?;
|
||||
|
||||
if batch_index > 0 && batch_index % 100 == 0 {
|
||||
// TODO: Add a way to deactivate the backprop graph tracking when computing the
|
||||
// validation loss.
|
||||
let loss = valid_loss(&dataset, &model, args, &device, &mut cache)?;
|
||||
let loss = valid_loss(&dataset, &model, args, &device)?;
|
||||
println!("{batch_index} {loss}");
|
||||
}
|
||||
if batch_index > 0 && batch_index % 1000 == 0 {
|
||||
|
@ -1,18 +0,0 @@
# candle-metavoice

MetaVoice-1B is a text-to-speech model trained on 100K hours of speech; more
details are on the [model
card](https://huggingface.co/metavoiceio/metavoice-1B-v0.1).

Note that the current candle implementation suffers from some limitations as of
2024-03-02:
- The speaker embeddings are hardcoded.
- The generated audio file quality is weaker than the Python implementation,
  probably because of some implementation discrepancies.

## Run an example

```bash
cargo run --example metavoice --release -- \
    --prompt "This is a demo of text to speech by MetaVoice-1B, an open-source foundational audio model."
```
@ -1,277 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use clap::Parser;
|
||||
use std::io::Write;
|
||||
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use candle_transformers::models::encodec;
|
||||
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
|
||||
use candle_transformers::models::quantized_metavoice::transformer as qtransformer;
|
||||
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use hf_hub::api::sync::Api;
|
||||
use rand::{distributions::Distribution, SeedableRng};
|
||||
|
||||
pub const ENCODEC_NTOKENS: u32 = 1024;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum ArgDType {
|
||||
F32,
|
||||
F16,
|
||||
Bf16,
|
||||
}
|
||||
|
||||
enum Transformer {
|
||||
Normal(transformer::Model),
|
||||
Quantized(qtransformer::Model),
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// Use the quantized version of the model.
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// The guidance scale.
|
||||
#[arg(long, default_value_t = 3.0)]
|
||||
guidance_scale: f64,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long, default_value_t = 1.0)]
|
||||
temperature: f64,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The maximum number of tokens to generate for the first stage.
|
||||
#[arg(long, default_value_t = 2000)]
|
||||
max_tokens: u64,
|
||||
|
||||
/// The output file using the wav format.
|
||||
#[arg(long, default_value = "out.wav")]
|
||||
out_file: String,
|
||||
|
||||
#[arg(long)]
|
||||
first_stage_meta: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
first_stage_weights: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
second_stage_weights: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
encodec_weights: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
spk_emb: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "f32")]
|
||||
dtype: ArgDType,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let api = Api::new()?;
|
||||
let repo = api.model("lmz/candle-metavoice".to_string());
|
||||
let first_stage_meta = match &args.first_stage_meta {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("first_stage.meta.json")?,
|
||||
};
|
||||
let first_stage_meta: serde_json::Value =
|
||||
serde_json::from_reader(&std::fs::File::open(first_stage_meta)?)?;
|
||||
let first_stage_tokenizer = match first_stage_meta.as_object() {
|
||||
None => anyhow::bail!("not a json object"),
|
||||
Some(j) => match j.get("tokenizer") {
|
||||
None => anyhow::bail!("no tokenizer key"),
|
||||
Some(j) => j,
|
||||
},
|
||||
};
|
||||
let fs_tokenizer = tokenizers::BPE::from_json(first_stage_tokenizer, 512)?;
|
||||
|
||||
let second_stage_weights = match &args.second_stage_weights {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("second_stage.safetensors")?,
|
||||
};
|
||||
let encodec_weights = match args.encodec_weights {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => Api::new()?
|
||||
.model("facebook/encodec_24khz".to_string())
|
||||
.get("model.safetensors")?,
|
||||
};
|
||||
let dtype = match args.dtype {
|
||||
ArgDType::F32 => DType::F32,
|
||||
ArgDType::F16 => DType::F16,
|
||||
ArgDType::Bf16 => DType::BF16,
|
||||
};
|
||||
|
||||
let first_stage_config = transformer::Config::cfg1b_v0_1();
|
||||
let mut first_stage_model = if args.quantized {
|
||||
let filename = match &args.first_stage_weights {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("first_stage_q4k.gguf")?,
|
||||
};
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
let first_stage_model = qtransformer::Model::new(&first_stage_config, vb)?;
|
||||
Transformer::Quantized(first_stage_model)
|
||||
} else {
|
||||
let first_stage_weights = match &args.first_stage_weights {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("first_stage.safetensors")?,
|
||||
};
|
||||
let first_stage_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[first_stage_weights], dtype, &device)? };
|
||||
let first_stage_model = transformer::Model::new(&first_stage_config, first_stage_vb)?;
|
||||
Transformer::Normal(first_stage_model)
|
||||
};
|
||||
|
||||
let second_stage_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[second_stage_weights], dtype, &device)? };
|
||||
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
||||
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
|
||||
|
||||
let encodec_device = if device.is_metal() {
|
||||
&candle::Device::Cpu
|
||||
} else {
|
||||
&device
|
||||
};
|
||||
let encodec_vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[encodec_weights], dtype, encodec_device)? };
|
||||
let encodec_config = encodec::Config::default();
|
||||
let encodec_model = encodec::Model::new(&encodec_config, encodec_vb)?;
|
||||
|
||||
println!("prompt: '{}'", args.prompt);
|
||||
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
|
||||
let mut tokens = prompt_tokens.clone();
|
||||
println!("{tokens:?}");
|
||||
let spk_emb_file = match &args.spk_emb {
|
||||
Some(w) => std::path::PathBuf::from(w),
|
||||
None => repo.get("spk_emb.safetensors")?,
|
||||
};
|
||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
||||
let spk_emb = match spk_emb.get("spk_emb") {
|
||||
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
||||
Some(spk_emb) => spk_emb.to_dtype(dtype)?,
|
||||
};
|
||||
let spk_emb = spk_emb.to_device(&device)?;
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
|
||||
|
||||
// First stage generation.
|
||||
for index in 0..args.max_tokens {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &device)?;
|
||||
let input = Tensor::stack(&[&input, &input], 0)?;
|
||||
let logits = match &mut first_stage_model {
|
||||
Transformer::Normal(m) => m.forward(&input, &spk_emb, tokens.len() - context_size)?,
|
||||
Transformer::Quantized(m) => {
|
||||
m.forward(&input, &spk_emb, tokens.len() - context_size)?
|
||||
}
|
||||
};
|
||||
let logits0 = logits.i((0, 0))?;
|
||||
let logits1 = logits.i((1, 0))?;
|
||||
let logits = ((logits0 * args.guidance_scale)? + logits1 * (1. - args.guidance_scale))?;
|
||||
let logits = logits.to_dtype(DType::F32)?;
|
||||
let next_token = logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
print!(".");
|
||||
std::io::stdout().flush()?;
|
||||
if next_token == 2048 {
|
||||
break;
|
||||
}
|
||||
}
|
||||
println!();
|
||||
let fie2c = adapters::FlattenedInterleavedEncodec2Codebook::new(ENCODEC_NTOKENS);
|
||||
let (text_ids, ids1, ids2) = fie2c.decode(&tokens);
|
||||
println!("text ids len: {}", text_ids.len());
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
|
||||
// TODO: Use the config rather than hardcoding the offset here.
|
||||
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
|
||||
let mut hierarchies_in1 =
|
||||
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
|
||||
let mut hierarchies_in2 = [
|
||||
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
|
||||
ids2.as_slice(),
|
||||
&[ENCODEC_NTOKENS],
|
||||
]
|
||||
.concat();
|
||||
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
|
||||
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
||||
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
||||
let logits = second_stage_model.forward(&in_x)?;
|
||||
println!("sampling from logits...");
|
||||
let mut codes = vec![];
|
||||
for logits in logits.iter() {
|
||||
let logits = logits.squeeze(0)?;
|
||||
let (seq_len, _) = logits.dims2()?;
|
||||
let mut codes_ = Vec::with_capacity(seq_len);
|
||||
for step in 0..seq_len {
|
||||
let logits = logits.i(step)?.to_dtype(DType::F32)?;
|
||||
let logits = &(&logits / 1.0)?;
|
||||
let prs = candle_nn::ops::softmax_last_dim(logits)?.to_vec1::<f32>()?;
|
||||
let distr = rand::distributions::WeightedIndex::new(prs.as_slice())?;
|
||||
let sample = distr.sample(&mut rng) as u32;
|
||||
codes_.push(sample)
|
||||
}
|
||||
codes.push(codes_)
|
||||
}
|
||||
|
||||
let codes = Tensor::new(codes, &device)?.unsqueeze(0)?;
|
||||
let codes = Tensor::cat(&[in_x, codes], 1)?;
|
||||
println!("codes: {codes}");
|
||||
let tilted_encodec = adapters::TiltedEncodec::new(ENCODEC_NTOKENS);
|
||||
let codes = codes.i(0)?.to_vec2::<u32>()?;
|
||||
let (text_ids, audio_ids) = tilted_encodec.decode(&codes);
|
||||
println!("text_ids len: {:?}", text_ids.len());
|
||||
let audio_ids = Tensor::new(audio_ids, encodec_device)?.unsqueeze(0)?;
|
||||
println!("audio_ids shape: {:?}", audio_ids.shape());
|
||||
let pcm = encodec_model.decode(&audio_ids)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||
Ok(())
|
||||
}
|
@ -63,7 +63,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
580
candle-examples/examples/musicgen/encodec_model.rs
Normal file
@ -0,0 +1,580 @@
|
||||
use crate::nn::conv1d_weight_norm;
|
||||
use candle::{DType, IndexOp, Module, Result, Tensor};
|
||||
use candle_nn::{conv1d, Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Encodec Model
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/models/encodec/modeling_encodec.py
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
enum NormType {
|
||||
WeightNorm,
|
||||
TimeGroupNorm,
|
||||
None,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub struct Config {
|
||||
target_bandwidths: Vec<f64>,
|
||||
sampling_rate: usize,
|
||||
audio_channels: usize,
|
||||
normalize: bool,
|
||||
chunk_length_s: Option<usize>,
|
||||
overlap: Option<usize>,
|
||||
hidden_size: usize,
|
||||
num_filters: usize,
|
||||
num_residual_layers: usize,
|
||||
upsampling_ratios: Vec<usize>,
|
||||
norm_type: NormType,
|
||||
kernel_size: usize,
|
||||
last_kernel_size: usize,
|
||||
residual_kernel_size: usize,
|
||||
dilation_growth_rate: usize,
|
||||
use_causal_conv: bool,
|
||||
pad_mode: &'static str,
|
||||
compress: usize,
|
||||
num_lstm_layers: usize,
|
||||
trim_right_ratio: f64,
|
||||
codebook_size: usize,
|
||||
codebook_dim: Option<usize>,
|
||||
use_conv_shortcut: bool,
|
||||
}
|
||||
|
||||
impl Default for Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
target_bandwidths: vec![1.5, 3.0, 6.0, 12.0, 24.0],
|
||||
sampling_rate: 24_000,
|
||||
audio_channels: 1,
|
||||
normalize: false,
|
||||
chunk_length_s: None,
|
||||
overlap: None,
|
||||
hidden_size: 128,
|
||||
num_filters: 32,
|
||||
num_residual_layers: 1,
|
||||
upsampling_ratios: vec![8, 5, 4, 2],
|
||||
norm_type: NormType::WeightNorm,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
residual_kernel_size: 3,
|
||||
dilation_growth_rate: 2,
|
||||
use_causal_conv: true,
|
||||
pad_mode: "reflect",
|
||||
compress: 2,
|
||||
num_lstm_layers: 2,
|
||||
trim_right_ratio: 1.0,
|
||||
codebook_size: 1024,
|
||||
codebook_dim: None,
|
||||
use_conv_shortcut: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
||||
pub fn musicgen_small() -> Self {
|
||||
Self {
|
||||
audio_channels: 1,
|
||||
chunk_length_s: None,
|
||||
codebook_dim: Some(128),
|
||||
codebook_size: 2048,
|
||||
compress: 2,
|
||||
dilation_growth_rate: 2,
|
||||
hidden_size: 128,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
norm_type: NormType::WeightNorm,
|
||||
normalize: false,
|
||||
num_filters: 64,
|
||||
num_lstm_layers: 2,
|
||||
num_residual_layers: 1,
|
||||
overlap: None,
|
||||
pad_mode: "reflect",
|
||||
residual_kernel_size: 3,
|
||||
sampling_rate: 32_000,
|
||||
target_bandwidths: vec![2.2],
|
||||
trim_right_ratio: 1.0,
|
||||
upsampling_ratios: vec![8, 5, 4, 4],
|
||||
use_causal_conv: false,
|
||||
use_conv_shortcut: false,
|
||||
}
|
||||
}
|
||||
|
||||
fn codebook_dim(&self) -> usize {
self.codebook_dim.unwrap_or(self.codebook_size)
}

fn frame_rate(&self) -> usize {
let hop_length: usize = self.upsampling_ratios.iter().product();
(self.sampling_rate + hop_length - 1) / hop_length
}

fn num_quantizers(&self) -> usize {
let num = 1000f64
* self
.target_bandwidths
.last()
.expect("empty target_bandwidths");
(num as usize) / (self.frame_rate() * 10)
}
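// Worked example with the musicgen_small() values above: upsampling_ratios [8, 5, 4, 4] give a
// hop length of 640, so frame_rate() = ceil(32_000 / 640) = 50; with a last target bandwidth of
// 2.2, num_quantizers() = (1000. * 2.2) as usize / (50 * 10) = 4 codebooks.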
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L340
|
||||
#[derive(Debug)]
|
||||
struct EncodecEuclideanCodebook {
|
||||
inited: Tensor,
|
||||
cluster_size: Tensor,
|
||||
embed: Tensor,
|
||||
embed_avg: Tensor,
|
||||
}
|
||||
|
||||
impl EncodecEuclideanCodebook {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let inited = vb.get(1, "inited")?;
|
||||
let cluster_size = vb.get(cfg.codebook_size, "cluster_size")?;
|
||||
let e_shape = (cfg.codebook_size, cfg.codebook_dim());
|
||||
let embed = vb.get(e_shape, "embed")?;
|
||||
let embed_avg = vb.get(e_shape, "embed_avg")?;
|
||||
Ok(Self {
|
||||
inited,
|
||||
cluster_size,
|
||||
embed,
|
||||
embed_avg,
|
||||
})
|
||||
}
|
||||
|
||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.embed.embedding(embed_ind)?;
|
||||
Ok(quantize)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecVectorQuantization {
|
||||
codebook: EncodecEuclideanCodebook,
|
||||
}
|
||||
|
||||
impl EncodecVectorQuantization {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let codebook = EncodecEuclideanCodebook::load(vb.pp("codebook"), cfg)?;
|
||||
Ok(Self { codebook })
|
||||
}
|
||||
|
||||
fn decode(&self, embed_ind: &Tensor) -> Result<Tensor> {
|
||||
let quantize = self.codebook.decode(embed_ind)?;
|
||||
let quantize = quantize.transpose(1, 2)?;
|
||||
Ok(quantize)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecResidualVectorQuantizer {
|
||||
layers: Vec<EncodecVectorQuantization>,
|
||||
}
|
||||
|
||||
impl EncodecResidualVectorQuantizer {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("layers");
|
||||
let layers = (0..cfg.num_quantizers())
|
||||
.map(|i| EncodecVectorQuantization::load(vb.pp(&i.to_string()), cfg))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
Ok(Self { layers })
|
||||
}
|
||||
|
||||
fn decode(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut quantized_out = Tensor::zeros((), DType::F32, codes.device())?;
|
||||
if codes.dim(0)? != self.layers.len() {
|
||||
candle::bail!(
|
||||
"codes shape {:?} does not match the number of quantization layers {}",
|
||||
codes.shape(),
|
||||
self.layers.len()
|
||||
)
|
||||
}
|
||||
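// Residual vector quantization: each layer decodes its own codes and the
// partial reconstructions are summed into the final embedding.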
for (i, layer) in self.layers.iter().enumerate() {
|
||||
let quantized = layer.decode(&codes.i(i)?)?;
|
||||
quantized_out = quantized.broadcast_add(&quantized_out)?;
|
||||
}
|
||||
Ok(quantized_out)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/abaca9f9432a84cfaa95531de4c72334f38a42f2/src/transformers/models/encodec/modeling_encodec.py#L226
|
||||
#[derive(Debug)]
|
||||
struct EncodecLSTM {
|
||||
layers: Vec<candle_nn::LSTM>,
|
||||
}
|
||||
|
||||
impl EncodecLSTM {
|
||||
fn load(dim: usize, vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let vb = &vb.pp("lstm");
|
||||
let mut layers = vec![];
|
||||
for layer_idx in 0..cfg.num_lstm_layers {
|
||||
let config = candle_nn::LSTMConfig {
|
||||
layer_idx,
|
||||
..Default::default()
|
||||
};
|
||||
let lstm = candle_nn::lstm(dim, dim, config, vb.clone())?;
|
||||
layers.push(lstm)
|
||||
}
|
||||
Ok(Self { layers })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecLSTM {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
use candle_nn::RNN;
|
||||
let mut xs = xs.clone();
|
||||
for layer in self.layers.iter() {
|
||||
let states = layer.seq(&xs)?;
|
||||
xs = layer.states_to_tensor(&states)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecConvTranspose1d {
|
||||
weight_g: Tensor,
|
||||
weight_v: Tensor,
|
||||
bias: Tensor,
|
||||
}
|
||||
|
||||
impl EncodecConvTranspose1d {
|
||||
fn load(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
k: usize,
|
||||
_stride: usize,
|
||||
vb: VarBuilder,
|
||||
_cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let vb = &vb.pp("conv");
|
||||
let weight_g = vb.get((in_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((in_c, out_c, k), "weight_v")?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Self {
|
||||
weight_g,
|
||||
weight_v,
|
||||
bias,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConvTranspose1d {
|
||||
fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecConv1d {
|
||||
causal: bool,
|
||||
conv: Conv1d,
|
||||
norm: Option<candle_nn::GroupNorm>,
|
||||
}
|
||||
|
||||
impl EncodecConv1d {
|
||||
fn load(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
stride: usize,
|
||||
vb: VarBuilder,
|
||||
cfg: &Config,
|
||||
) -> Result<Self> {
|
||||
let conv = match cfg.norm_type {
|
||||
NormType::WeightNorm => conv1d_weight_norm(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
NormType::None | NormType::TimeGroupNorm => conv1d(
|
||||
in_c,
|
||||
out_c,
|
||||
kernel_size,
|
||||
Conv1dConfig {
|
||||
padding: 0,
|
||||
stride,
|
||||
groups: 1,
|
||||
dilation: 1,
|
||||
},
|
||||
vb.pp("conv"),
|
||||
)?,
|
||||
};
|
||||
let norm = match cfg.norm_type {
|
||||
NormType::None | NormType::WeightNorm => None,
|
||||
NormType::TimeGroupNorm => {
|
||||
let gn = candle_nn::group_norm(1, out_c, 1e-5, vb.pp("norm"))?;
|
||||
Some(gn)
|
||||
}
|
||||
};
|
||||
Ok(Self {
|
||||
causal: cfg.use_causal_conv,
|
||||
conv,
|
||||
norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecConv1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
// TODO: padding, depending on causal.
|
||||
let xs = self.conv.forward(xs)?;
|
||||
match &self.norm {
|
||||
None => Ok(xs),
|
||||
Some(norm) => xs.apply(norm),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecResnetBlock {
|
||||
block_conv1: EncodecConv1d,
|
||||
block_conv2: EncodecConv1d,
|
||||
shortcut: Option<EncodecConv1d>,
|
||||
}
|
||||
|
||||
impl EncodecResnetBlock {
|
||||
fn load(dim: usize, dilations: &[usize], vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let h = dim / cfg.compress;
|
||||
let mut layer = Layer::new(vb.pp("block"));
|
||||
if dilations.len() != 2 {
|
||||
candle::bail!("expected dilations of size 2")
|
||||
}
|
||||
// TODO: Apply dilations!
|
||||
layer.inc();
|
||||
let block_conv1 =
|
||||
EncodecConv1d::load(dim, h, cfg.residual_kernel_size, 1, layer.next(), cfg)?;
|
||||
layer.inc();
|
||||
let block_conv2 = EncodecConv1d::load(h, dim, 1, 1, layer.next(), cfg)?;
|
||||
let shortcut = if cfg.use_conv_shortcut {
|
||||
let conv = EncodecConv1d::load(dim, dim, 1, 1, vb.pp("shortcut"), cfg)?;
|
||||
Some(conv)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
block_conv1,
|
||||
block_conv2,
|
||||
shortcut,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for EncodecResnetBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let residual = xs.clone();
|
||||
let xs = xs.elu(1.)?;
|
||||
let xs = self.block_conv1.forward(&xs)?;
|
||||
let xs = xs.elu(1.)?;
|
||||
let xs = self.block_conv2.forward(&xs)?;
|
||||
let xs = match &self.shortcut {
|
||||
None => (xs + residual)?,
|
||||
Some(shortcut) => xs.add(&shortcut.forward(&residual)?)?,
|
||||
};
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
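// Helper that walks numbered VarBuilder sub-paths: `next()` returns the
// VarBuilder for the current index and advances it, while `inc()` just skips
// an index (used for weight-less layers such as ELU activations).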
struct Layer<'a> {
|
||||
vb: VarBuilder<'a>,
|
||||
cnt: usize,
|
||||
}
|
||||
|
||||
impl<'a> Layer<'a> {
|
||||
fn new(vb: VarBuilder<'a>) -> Self {
|
||||
Self { vb, cnt: 0 }
|
||||
}
|
||||
|
||||
fn inc(&mut self) {
|
||||
self.cnt += 1;
|
||||
}
|
||||
|
||||
fn next(&mut self) -> VarBuilder {
|
||||
let vb = self.vb.pp(&self.cnt.to_string());
|
||||
self.cnt += 1;
|
||||
vb
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecEncoder {
|
||||
init_conv: EncodecConv1d,
|
||||
sampling_layers: Vec<(Vec<EncodecResnetBlock>, EncodecConv1d)>,
|
||||
final_lstm: EncodecLSTM,
|
||||
final_conv: EncodecConv1d,
|
||||
}
|
||||
|
||||
impl EncodecEncoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new(vb.pp("layers"));
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.audio_channels,
|
||||
cfg.num_filters,
|
||||
cfg.kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let mut sampling_layers = vec![];
|
||||
let mut scaling = 1;
|
||||
for &ratio in cfg.upsampling_ratios.iter().rev() {
|
||||
let current_scale = scaling * cfg.num_filters;
|
||||
let mut resnets = vec![];
|
||||
for j in 0..(cfg.num_residual_layers as u32) {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
}
|
||||
layer.inc(); // ELU
|
||||
let conv1d = EncodecConv1d::load(
|
||||
current_scale,
|
||||
current_scale * 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
sampling_layers.push((resnets, conv1d));
|
||||
scaling *= 2;
|
||||
}
|
||||
let final_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||
layer.inc(); // ELU
|
||||
let final_conv = EncodecConv1d::load(
|
||||
cfg.num_filters * scaling,
|
||||
cfg.hidden_size,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
init_conv,
|
||||
sampling_layers,
|
||||
final_conv,
|
||||
final_lstm,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?;
|
||||
for (resnets, conv) in self.sampling_layers.iter() {
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?;
|
||||
}
|
||||
xs = xs.elu(1.0)?.apply(conv)?;
|
||||
}
|
||||
xs.apply(&self.final_lstm)?
|
||||
.elu(1.0)?
|
||||
.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
struct EncodecDecoder {
|
||||
init_conv: EncodecConv1d,
|
||||
init_lstm: EncodecLSTM,
|
||||
sampling_layers: Vec<(EncodecConvTranspose1d, Vec<EncodecResnetBlock>)>,
|
||||
final_conv: EncodecConv1d,
|
||||
}
|
||||
|
||||
impl EncodecDecoder {
|
||||
fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let mut layer = Layer::new(vb.pp("layers"));
|
||||
let mut scaling = usize::pow(2, cfg.upsampling_ratios.len() as u32);
|
||||
let init_conv = EncodecConv1d::load(
|
||||
cfg.hidden_size,
|
||||
cfg.num_filters * scaling,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let init_lstm = EncodecLSTM::load(cfg.num_filters * scaling, layer.next(), cfg)?;
|
||||
let mut sampling_layers = vec![];
|
||||
for &ratio in cfg.upsampling_ratios.iter() {
|
||||
let current_scale = scaling * cfg.num_filters;
|
||||
layer.inc(); // ELU
|
||||
let conv1d = EncodecConvTranspose1d::load(
|
||||
current_scale,
|
||||
current_scale / 2,
|
||||
ratio * 2,
|
||||
ratio,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
let mut resnets = vec![];
|
||||
for j in 0..(cfg.num_residual_layers as u32) {
|
||||
let resnet = EncodecResnetBlock::load(
|
||||
current_scale / 2,
|
||||
&[cfg.dilation_growth_rate.pow(j), 1],
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
resnets.push(resnet)
|
||||
}
|
||||
sampling_layers.push((conv1d, resnets));
|
||||
scaling /= 2;
|
||||
}
|
||||
layer.inc(); // ELU
|
||||
let final_conv = EncodecConv1d::load(
|
||||
cfg.num_filters,
|
||||
cfg.audio_channels,
|
||||
cfg.last_kernel_size,
|
||||
1,
|
||||
layer.next(),
|
||||
cfg,
|
||||
)?;
|
||||
Ok(Self {
|
||||
init_conv,
|
||||
init_lstm,
|
||||
sampling_layers,
|
||||
final_conv,
|
||||
})
|
||||
}
|
||||
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let mut xs = xs.apply(&self.init_conv)?.apply(&self.init_lstm)?;
|
||||
for (conv, resnets) in self.sampling_layers.iter() {
|
||||
xs = xs.elu(1.)?.apply(conv)?;
|
||||
for resnet in resnets.iter() {
|
||||
xs = xs.apply(resnet)?
|
||||
}
|
||||
}
|
||||
xs.elu(1.)?.apply(&self.final_conv)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct EncodecModel {
|
||||
encoder: EncodecEncoder,
|
||||
decoder: EncodecDecoder,
|
||||
quantizer: EncodecResidualVectorQuantizer,
|
||||
}
|
||||
|
||||
impl EncodecModel {
|
||||
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
|
||||
let encoder = EncodecEncoder::load(vb.pp("encoder"), cfg)?;
|
||||
let decoder = EncodecDecoder::load(vb.pp("decoder"), cfg)?;
|
||||
let quantizer = EncodecResidualVectorQuantizer::load(vb.pp("quantizer"), cfg)?;
|
||||
Ok(Self {
|
||||
encoder,
|
||||
decoder,
|
||||
quantizer,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn forward(&self, _xs: &Tensor) -> Result<Tensor> {
|
||||
todo!()
|
||||
}
|
||||
}
|
@ -10,7 +10,9 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

mod encodec_model;
mod musicgen_model;
mod nn;

use musicgen_model::{GenConfig, MusicgenForConditionalGeneration};
@ -1,9 +1,10 @@
use crate::encodec_model;
use candle::{DType, Device, Result, Tensor, D};
use candle_nn::{
    embedding, layer_norm, linear_no_bias, Activation, Embedding, LayerNorm, Linear, Module,
    VarBuilder,
};
use candle_transformers::models::{encodec, t5};
use candle_transformers::models::t5;

// https://github.com/huggingface/transformers/blob/cd4584e3c809bb9e1392ccd3fe38b40daba5519a/src/transformers/models/musicgen/configuration_musicgen.py#L83
#[derive(Debug, Clone, PartialEq)]
@ -371,7 +372,7 @@ impl MusicgenForCausalLM {
#[derive(Debug)]
pub struct MusicgenForConditionalGeneration {
    pub text_encoder: t5::T5EncoderModel,
    pub audio_encoder: encodec::Model,
    pub audio_encoder: crate::encodec_model::EncodecModel,
    pub decoder: MusicgenForCausalLM,
    cfg: GenConfig,
}
@ -380,42 +381,15 @@ pub struct MusicgenForConditionalGeneration {
|
||||
pub struct GenConfig {
|
||||
musicgen: Config,
|
||||
t5: t5::Config,
|
||||
encodec: encodec::Config,
|
||||
encodec: crate::encodec_model::Config,
|
||||
}
|
||||
|
||||
impl GenConfig {
|
||||
pub fn small() -> Self {
|
||||
// https://huggingface.co/facebook/musicgen-small/blob/495da4ad086b3416a27c6187f9239f9fd96f3962/config.json#L6
|
||||
let encodec = encodec::Config {
|
||||
audio_channels: 1,
|
||||
chunk_length_s: None,
|
||||
codebook_dim: Some(128),
|
||||
codebook_size: 2048,
|
||||
compress: 2,
|
||||
dilation_growth_rate: 2,
|
||||
hidden_size: 128,
|
||||
kernel_size: 7,
|
||||
last_kernel_size: 7,
|
||||
norm_type: encodec::NormType::WeightNorm,
|
||||
normalize: false,
|
||||
num_filters: 64,
|
||||
num_lstm_layers: 2,
|
||||
num_residual_layers: 1,
|
||||
overlap: None,
|
||||
// This should be Reflect and not Replicate but Reflect does not work yet.
|
||||
pad_mode: encodec::PadMode::Replicate,
|
||||
residual_kernel_size: 3,
|
||||
sampling_rate: 32_000,
|
||||
target_bandwidths: vec![2.2],
|
||||
trim_right_ratio: 1.0,
|
||||
upsampling_ratios: vec![8, 5, 4, 4],
|
||||
use_causal_conv: false,
|
||||
use_conv_shortcut: false,
|
||||
};
|
||||
Self {
|
||||
musicgen: Config::musicgen_small(),
|
||||
t5: t5::Config::musicgen_small(),
|
||||
encodec,
|
||||
encodec: encodec_model::Config::musicgen_small(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -427,7 +401,8 @@ impl MusicgenForConditionalGeneration {
|
||||
|
||||
pub fn load(vb: VarBuilder, cfg: GenConfig) -> Result<Self> {
|
||||
let text_encoder = t5::T5EncoderModel::load(vb.pp("text_encoder"), &cfg.t5)?;
|
||||
let audio_encoder = encodec::Model::new(&cfg.encodec, vb.pp("audio_encoder"))?;
|
||||
let audio_encoder =
|
||||
encodec_model::EncodecModel::load(vb.pp("audio_encoder"), &cfg.encodec)?;
|
||||
let decoder = MusicgenForCausalLM::load(vb.pp("decoder"), &cfg.musicgen)?;
|
||||
Ok(Self {
|
||||
text_encoder,
|
||||
|
20
candle-examples/examples/musicgen/nn.rs
Normal file
@ -0,0 +1,20 @@
|
||||
use candle::Result;
|
||||
use candle_nn::{Conv1d, Conv1dConfig, VarBuilder};
|
||||
|
||||
// Applies weight norm for inference by recomputing the weight tensor. This
|
||||
// does not apply to training.
|
||||
// https://pytorch.org/docs/stable/generated/torch.nn.utils.weight_norm.html
|
||||
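// The reparameterization is w = g * v / ||v||, with the norm of v taken per
// output channel over the (in_c, kernel_size) dimensions.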
pub fn conv1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
@ -212,14 +212,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
verbose_prompt: bool,
|
||||
|
||||
/// Process prompt elements separately.
|
||||
#[arg(long)]
|
||||
split_prompt: bool,
|
||||
|
||||
/// Run on CPU rather than GPU even if a GPU is available.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
@ -369,7 +361,7 @@ fn main() -> anyhow::Result<()> {
|
||||
let model_path = args.model()?;
|
||||
let mut file = std::fs::File::open(&model_path)?;
|
||||
let start = std::time::Instant::now();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let device = candle_examples::device(false)?;
|
||||
|
||||
let mut model = match model_path.extension().and_then(|v| v.to_str()) {
|
||||
Some("gguf") => {
|
||||
@ -495,20 +487,11 @@ fn main() -> anyhow::Result<()> {
|
||||
let mut logits_processor = LogitsProcessor::new(args.seed, temperature, args.top_p);
|
||||
|
||||
let start_prompt_processing = std::time::Instant::now();
|
||||
let mut next_token = if !args.split_prompt {
|
||||
let mut next_token = {
|
||||
let input = Tensor::new(prompt_tokens.as_slice(), &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, 0)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
logits_processor.sample(&logits)?
|
||||
} else {
|
||||
let mut next_token = 0;
|
||||
for (pos, token) in prompt_tokens.iter().enumerate() {
|
||||
let input = Tensor::new(&[*token], &device)?.unsqueeze(0)?;
|
||||
let logits = model.forward(&input, pos)?;
|
||||
let logits = logits.squeeze(0)?;
|
||||
next_token = logits_processor.sample(&logits)?
|
||||
}
|
||||
next_token
|
||||
};
|
||||
let prompt_dt = start_prompt_processing.elapsed();
|
||||
all_tokens.push(next_token);
|
||||
|
@ -78,7 +78,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -45,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -2,8 +2,8 @@

The [RWKV model](https://wiki.rwkv.com/) is a recurrent neural network model
with performance on par with transformer architectures. Several variants are
available, candle implements the v5 and v6 versions and can be used with
Eagle 7B([blog post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).
available, candle implements the v5 version and can be used with Eagle 7B([blog
post](https://blog.rwkv.com/p/eagle-7b-soaring-past-transformers)).

```bash
$ cargo run --example rwkv --release -- --prompt "The smallest prime is "
@ -7,36 +7,13 @@ extern crate accelerate_src;
|
||||
use anyhow::Result;
|
||||
use clap::{Parser, ValueEnum};
|
||||
|
||||
use candle_transformers::models::quantized_rwkv_v5::Model as Q5;
|
||||
use candle_transformers::models::quantized_rwkv_v6::Model as Q6;
|
||||
use candle_transformers::models::rwkv_v5::{Config, Model as M5, State, Tokenizer};
|
||||
use candle_transformers::models::rwkv_v6::Model as M6;
|
||||
use candle_transformers::models::rwkv_v5::{Config, Model, State, Tokenizer};
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
|
||||
const EOS_TOKEN_ID: u32 = 261;
|
||||
|
||||
enum Model {
|
||||
M5(M5),
|
||||
Q5(Q5),
|
||||
M6(M6),
|
||||
Q6(Q6),
|
||||
}
|
||||
|
||||
impl Model {
|
||||
fn forward(&self, xs: &Tensor, state: &mut State) -> candle::Result<Tensor> {
|
||||
match self {
|
||||
Self::M5(m) => m.forward(xs, state),
|
||||
Self::Q5(m) => m.forward(xs, state),
|
||||
Self::M6(m) => m.forward(xs, state),
|
||||
Self::Q6(m) => m.forward(xs, state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
config: Config,
|
||||
@ -106,9 +83,6 @@ impl TextGeneration {
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == EOS_TOKEN_ID || next_token == 0 {
|
||||
break;
|
||||
}
|
||||
print!("{}", self.tokenizer.decode(&[next_token])?);
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
@ -129,7 +103,6 @@ enum Which {
|
||||
Eagle7b,
|
||||
World1b5,
|
||||
World3b,
|
||||
World6_1b6,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for Which {
|
||||
@ -141,10 +114,9 @@ impl std::fmt::Display for Which {
|
||||
impl Which {
|
||||
fn model_id(&self) -> &'static str {
|
||||
match self {
|
||||
Self::Eagle7b => "RWKV/v5-Eagle-7B-HF",
|
||||
Self::Eagle7b => "RWKV/HF_v5-Eagle-7B",
|
||||
Self::World1b5 => "RWKV/rwkv-5-world-1b5",
|
||||
Self::World3b => "RWKV/rwkv-5-world-3b",
|
||||
Self::World6_1b6 => "paperfun/rwkv",
|
||||
}
|
||||
}
|
||||
|
||||
@ -152,7 +124,6 @@ impl Which {
|
||||
match self {
|
||||
Self::Eagle7b => "refs/pr/1",
|
||||
Self::World1b5 | Self::World3b => "refs/pr/2",
|
||||
Self::World6_1b6 => "main",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -205,9 +176,6 @@ struct Args {
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
quantized: bool,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
@ -268,27 +236,7 @@ fn main() -> Result<()> {
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => {
|
||||
if args.quantized {
|
||||
vec![match args.which {
|
||||
Which::World1b5 => api
|
||||
.model("lmz/candle-rwkv".to_string())
|
||||
.get("world1b5-q4k.gguf")?,
|
||||
Which::World3b => api
|
||||
.model("lmz/candle-rwkv".to_string())
|
||||
.get("world3b-q4k.gguf")?,
|
||||
Which::Eagle7b => api
|
||||
.model("lmz/candle-rwkv".to_string())
|
||||
.get("eagle7b-q4k.gguf")?,
|
||||
Which::World6_1b6 => repo.get("rwkv-6-world-1b6-q4k.gguf")?,
|
||||
}]
|
||||
} else {
|
||||
vec![match args.which {
|
||||
Which::World1b5 | Which::World3b | Which::Eagle7b => {
|
||||
repo.get("model.safetensors")?
|
||||
}
|
||||
Which::World6_1b6 => repo.get("rwkv-6-world-1b6.safetensors")?,
|
||||
}]
|
||||
}
|
||||
vec![repo.get("model.safetensors")?]
|
||||
}
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
@ -297,21 +245,8 @@ fn main() -> Result<()> {
|
||||
let start = std::time::Instant::now();
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model = if args.quantized {
|
||||
let filename = &filenames[0];
|
||||
let vb =
|
||||
candle_transformers::quantized_var_builder::VarBuilder::from_gguf(filename, &device)?;
|
||||
match args.which {
|
||||
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::Q5(Q5::new(&config, vb)?),
|
||||
Which::World6_1b6 => Model::Q6(Q6::new(&config, vb)?),
|
||||
}
|
||||
} else {
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
match args.which {
|
||||
Which::World1b5 | Which::World3b | Which::Eagle7b => Model::M5(M5::new(&config, vb)?),
|
||||
Which::World6_1b6 => Model::M6(M6::new(&config, vb)?),
|
||||
}
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
|
@ -1,28 +0,0 @@
# candle-segformer

- [HuggingFace Segformer Model Card][segformer]
- [`mit-b0` - An encoder-only pretrained model][encoder]
- [`segformer-b0-finetuned-ade-512-512` - A fine-tuned model for segmentation][ade512]

## How to run the example

If you want, you can use the example images from this [pull request][pr]: download them and supply the path to the image as an argument to the example.

```bash
# run the image classification task
cargo run --example segformer classify <path-to-image>
# run the segmentation task
cargo run --example segformer segment <path-to-image>
```

Example output for classification:

```text
classification logits [3.275261e-5, 0.0008562019, 0.0008868563, 0.9977506, 0.0002465068, 0.0002241473, 2.846596e-6]
label: hamburger
```

[pr]: https://github.com/huggingface/candle/pull/1617
[segformer]: https://huggingface.co/docs/transformers/model_doc/segformer
[encoder]: https://huggingface.co/nvidia/mit-b0
[ade512]: https://huggingface.co/nvidia/segformer-b0-finetuned-ade-512-512
@ -1,752 +0,0 @@
|
||||
[
|
||||
{
|
||||
"index": 1,
|
||||
"color": "#787878",
|
||||
"label": "wall"
|
||||
},
|
||||
{
|
||||
"index": 2,
|
||||
"color": "#B47878",
|
||||
"label": "building;edifice"
|
||||
},
|
||||
{
|
||||
"index": 3,
|
||||
"color": "#06E6E6",
|
||||
"label": "sky"
|
||||
},
|
||||
{
|
||||
"index": 4,
|
||||
"color": "#503232",
|
||||
"label": "floor;flooring"
|
||||
},
|
||||
{
|
||||
"index": 5,
|
||||
"color": "#04C803",
|
||||
"label": "tree"
|
||||
},
|
||||
{
|
||||
"index": 6,
|
||||
"color": "#787850",
|
||||
"label": "ceiling"
|
||||
},
|
||||
{
|
||||
"index": 7,
|
||||
"color": "#8C8C8C",
|
||||
"label": "road;route"
|
||||
},
|
||||
{
|
||||
"index": 8,
|
||||
"color": "#CC05FF",
|
||||
"label": "bed"
|
||||
},
|
||||
{
|
||||
"index": 9,
|
||||
"color": "#E6E6E6",
|
||||
"label": "windowpane;window"
|
||||
},
|
||||
{
|
||||
"index": 10,
|
||||
"color": "#04FA07",
|
||||
"label": "grass"
|
||||
},
|
||||
{
|
||||
"index": 11,
|
||||
"color": "#E005FF",
|
||||
"label": "cabinet"
|
||||
},
|
||||
{
|
||||
"index": 12,
|
||||
"color": "#EBFF07",
|
||||
"label": "sidewalk;pavement"
|
||||
},
|
||||
{
|
||||
"index": 13,
|
||||
"color": "#96053D",
|
||||
"label": "person;individual;someone;somebody;mortal;soul"
|
||||
},
|
||||
{
|
||||
"index": 14,
|
||||
"color": "#787846",
|
||||
"label": "earth;ground"
|
||||
},
|
||||
{
|
||||
"index": 15,
|
||||
"color": "#08FF33",
|
||||
"label": "door;double;door"
|
||||
},
|
||||
{
|
||||
"index": 16,
|
||||
"color": "#FF0652",
|
||||
"label": "table"
|
||||
},
|
||||
{
|
||||
"index": 17,
|
||||
"color": "#8FFF8C",
|
||||
"label": "mountain;mount"
|
||||
},
|
||||
{
|
||||
"index": 18,
|
||||
"color": "#CCFF04",
|
||||
"label": "plant;flora;plant;life"
|
||||
},
|
||||
{
|
||||
"index": 19,
|
||||
"color": "#FF3307",
|
||||
"label": "curtain;drape;drapery;mantle;pall"
|
||||
},
|
||||
{
|
||||
"index": 20,
|
||||
"color": "#CC4603",
|
||||
"label": "chair"
|
||||
},
|
||||
{
|
||||
"index": 21,
|
||||
"color": "#0066C8",
|
||||
"label": "car;auto;automobile;machine;motorcar"
|
||||
},
|
||||
{
|
||||
"index": 22,
|
||||
"color": "#3DE6FA",
|
||||
"label": "water"
|
||||
},
|
||||
{
|
||||
"index": 23,
|
||||
"color": "#FF0633",
|
||||
"label": "painting;picture"
|
||||
},
|
||||
{
|
||||
"index": 24,
|
||||
"color": "#0B66FF",
|
||||
"label": "sofa;couch;lounge"
|
||||
},
|
||||
{
|
||||
"index": 25,
|
||||
"color": "#FF0747",
|
||||
"label": "shelf"
|
||||
},
|
||||
{
|
||||
"index": 26,
|
||||
"color": "#FF09E0",
|
||||
"label": "house"
|
||||
},
|
||||
{
|
||||
"index": 27,
|
||||
"color": "#0907E6",
|
||||
"label": "sea"
|
||||
},
|
||||
{
|
||||
"index": 28,
|
||||
"color": "#DCDCDC",
|
||||
"label": "mirror"
|
||||
},
|
||||
{
|
||||
"index": 29,
|
||||
"color": "#FF095C",
|
||||
"label": "rug;carpet;carpeting"
|
||||
},
|
||||
{
|
||||
"index": 30,
|
||||
"color": "#7009FF",
|
||||
"label": "field"
|
||||
},
|
||||
{
|
||||
"index": 31,
|
||||
"color": "#08FFD6",
|
||||
"label": "armchair"
|
||||
},
|
||||
{
|
||||
"index": 32,
|
||||
"color": "#07FFE0",
|
||||
"label": "seat"
|
||||
},
|
||||
{
|
||||
"index": 33,
|
||||
"color": "#FFB806",
|
||||
"label": "fence;fencing"
|
||||
},
|
||||
{
|
||||
"index": 34,
|
||||
"color": "#0AFF47",
|
||||
"label": "desk"
|
||||
},
|
||||
{
|
||||
"index": 35,
|
||||
"color": "#FF290A",
|
||||
"label": "rock;stone"
|
||||
},
|
||||
{
|
||||
"index": 36,
|
||||
"color": "#07FFFF",
|
||||
"label": "wardrobe;closet;press"
|
||||
},
|
||||
{
|
||||
"index": 37,
|
||||
"color": "#E0FF08",
|
||||
"label": "lamp"
|
||||
},
|
||||
{
|
||||
"index": 38,
|
||||
"color": "#6608FF",
|
||||
"label": "bathtub;bathing;tub;bath;tub"
|
||||
},
|
||||
{
|
||||
"index": 39,
|
||||
"color": "#FF3D06",
|
||||
"label": "railing;rail"
|
||||
},
|
||||
{
|
||||
"index": 40,
|
||||
"color": "#FFC207",
|
||||
"label": "cushion"
|
||||
},
|
||||
{
|
||||
"index": 41,
|
||||
"color": "#FF7A08",
|
||||
"label": "base;pedestal;stand"
|
||||
},
|
||||
{
|
||||
"index": 42,
|
||||
"color": "#00FF14",
|
||||
"label": "box"
|
||||
},
|
||||
{
|
||||
"index": 43,
|
||||
"color": "#FF0829",
|
||||
"label": "column;pillar"
|
||||
},
|
||||
{
|
||||
"index": 44,
|
||||
"color": "#FF0599",
|
||||
"label": "signboard;sign"
|
||||
},
|
||||
{
|
||||
"index": 45,
|
||||
"color": "#0633FF",
|
||||
"label": "chest;of;drawers;chest;bureau;dresser"
|
||||
},
|
||||
{
|
||||
"index": 46,
|
||||
"color": "#EB0CFF",
|
||||
"label": "counter"
|
||||
},
|
||||
{
|
||||
"index": 47,
|
||||
"color": "#A09614",
|
||||
"label": "sand"
|
||||
},
|
||||
{
|
||||
"index": 48,
|
||||
"color": "#00A3FF",
|
||||
"label": "sink"
|
||||
},
|
||||
{
|
||||
"index": 49,
|
||||
"color": "#8C8C8C",
|
||||
"label": "skyscraper"
|
||||
},
|
||||
{
|
||||
"index": 50,
|
||||
"color": "#FA0A0F",
|
||||
"label": "fireplace;hearth;open;fireplace"
|
||||
},
|
||||
{
|
||||
"index": 51,
|
||||
"color": "#14FF00",
|
||||
"label": "refrigerator;icebox"
|
||||
},
|
||||
{
|
||||
"index": 52,
|
||||
"color": "#1FFF00",
|
||||
"label": "grandstand;covered;stand"
|
||||
},
|
||||
{
|
||||
"index": 53,
|
||||
"color": "#FF1F00",
|
||||
"label": "path"
|
||||
},
|
||||
{
|
||||
"index": 54,
|
||||
"color": "#FFE000",
|
||||
"label": "stairs;steps"
|
||||
},
|
||||
{
|
||||
"index": 55,
|
||||
"color": "#99FF00",
|
||||
"label": "runway"
|
||||
},
|
||||
{
|
||||
"index": 56,
|
||||
"color": "#0000FF",
|
||||
"label": "case;display;case;showcase;vitrine"
|
||||
},
|
||||
{
|
||||
"index": 57,
|
||||
"color": "#FF4700",
|
||||
"label": "pool;table;billiard;table;snooker;table"
|
||||
},
|
||||
{
|
||||
"index": 58,
|
||||
"color": "#00EBFF",
|
||||
"label": "pillow"
|
||||
},
|
||||
{
|
||||
"index": 59,
|
||||
"color": "#00ADFF",
|
||||
"label": "screen;door;screen"
|
||||
},
|
||||
{
|
||||
"index": 60,
|
||||
"color": "#1F00FF",
|
||||
"label": "stairway;staircase"
|
||||
},
|
||||
{
|
||||
"index": 61,
|
||||
"color": "#0BC8C8",
|
||||
"label": "river"
|
||||
},
|
||||
{
|
||||
"index": 62,
|
||||
"color": "#FF5200",
|
||||
"label": "bridge;span"
|
||||
},
|
||||
{
|
||||
"index": 63,
|
||||
"color": "#00FFF5",
|
||||
"label": "bookcase"
|
||||
},
|
||||
{
|
||||
"index": 64,
|
||||
"color": "#003DFF",
|
||||
"label": "blind;screen"
|
||||
},
|
||||
{
|
||||
"index": 65,
|
||||
"color": "#00FF70",
|
||||
"label": "coffee;table;cocktail;table"
|
||||
},
|
||||
{
|
||||
"index": 66,
|
||||
"color": "#00FF85",
|
||||
"label": "toilet;can;commode;crapper;pot;potty;stool;throne"
|
||||
},
|
||||
{
|
||||
"index": 67,
|
||||
"color": "#FF0000",
|
||||
"label": "flower"
|
||||
},
|
||||
{
|
||||
"index": 68,
|
||||
"color": "#FFA300",
|
||||
"label": "book"
|
||||
},
|
||||
{
|
||||
"index": 69,
|
||||
"color": "#FF6600",
|
||||
"label": "hill"
|
||||
},
|
||||
{
|
||||
"index": 70,
|
||||
"color": "#C2FF00",
|
||||
"label": "bench"
|
||||
},
|
||||
{
|
||||
"index": 71,
|
||||
"color": "#008FFF",
|
||||
"label": "countertop"
|
||||
},
|
||||
{
|
||||
"index": 72,
|
||||
"color": "#33FF00",
|
||||
"label": "stove;kitchen;stove;range;kitchen;range;cooking;stove"
|
||||
},
|
||||
{
|
||||
"index": 73,
|
||||
"color": "#0052FF",
|
||||
"label": "palm;palm;tree"
|
||||
},
|
||||
{
|
||||
"index": 74,
|
||||
"color": "#00FF29",
|
||||
"label": "kitchen;island"
|
||||
},
|
||||
{
|
||||
"index": 75,
|
||||
"color": "#00FFAD",
|
||||
"label": "computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system"
|
||||
},
|
||||
{
|
||||
"index": 76,
|
||||
"color": "#0A00FF",
|
||||
"label": "swivel;chair"
|
||||
},
|
||||
{
|
||||
"index": 77,
|
||||
"color": "#ADFF00",
|
||||
"label": "boat"
|
||||
},
|
||||
{
|
||||
"index": 78,
|
||||
"color": "#00FF99",
|
||||
"label": "bar"
|
||||
},
|
||||
{
|
||||
"index": 79,
|
||||
"color": "#FF5C00",
|
||||
"label": "arcade;machine"
|
||||
},
|
||||
{
|
||||
"index": 80,
|
||||
"color": "#FF00FF",
|
||||
"label": "hovel;hut;hutch;shack;shanty"
|
||||
},
|
||||
{
|
||||
"index": 81,
|
||||
"color": "#FF00F5",
|
||||
"label": "bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle"
|
||||
},
|
||||
{
|
||||
"index": 82,
|
||||
"color": "#FF0066",
|
||||
"label": "towel"
|
||||
},
|
||||
{
|
||||
"index": 83,
|
||||
"color": "#FFAD00",
|
||||
"label": "light;light;source"
|
||||
},
|
||||
{
|
||||
"index": 84,
|
||||
"color": "#FF0014",
|
||||
"label": "truck;motortruck"
|
||||
},
|
||||
{
|
||||
"index": 85,
|
||||
"color": "#FFB8B8",
|
||||
"label": "tower"
|
||||
},
|
||||
{
|
||||
"index": 86,
|
||||
"color": "#001FFF",
|
||||
"label": "chandelier;pendant;pendent"
|
||||
},
|
||||
{
|
||||
"index": 87,
|
||||
"color": "#00FF3D",
|
||||
"label": "awning;sunshade;sunblind"
|
||||
},
|
||||
{
|
||||
"index": 88,
|
||||
"color": "#0047FF",
|
||||
"label": "streetlight;street;lamp"
|
||||
},
|
||||
{
|
||||
"index": 89,
|
||||
"color": "#FF00CC",
|
||||
"label": "booth;cubicle;stall;kiosk"
|
||||
},
|
||||
{
|
||||
"index": 90,
|
||||
"color": "#00FFC2",
|
||||
"label": "television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box"
|
||||
},
|
||||
{
|
||||
"index": 91,
|
||||
"color": "#00FF52",
|
||||
"label": "airplane;aeroplane;plane"
|
||||
},
|
||||
{
|
||||
"index": 92,
|
||||
"color": "#000AFF",
|
||||
"label": "dirt;track"
|
||||
},
|
||||
{
|
||||
"index": 93,
|
||||
"color": "#0070FF",
|
||||
"label": "apparel;wearing;apparel;dress;clothes"
|
||||
},
|
||||
{
|
||||
"index": 94,
|
||||
"color": "#3300FF",
|
||||
"label": "pole"
|
||||
},
|
||||
{
|
||||
"index": 95,
|
||||
"color": "#00C2FF",
|
||||
"label": "land;ground;soil"
|
||||
},
|
||||
{
|
||||
"index": 96,
|
||||
"color": "#007AFF",
|
||||
"label": "bannister;banister;balustrade;balusters;handrail"
|
||||
},
|
||||
{
|
||||
"index": 97,
|
||||
"color": "#00FFA3",
|
||||
"label": "escalator;moving;staircase;moving;stairway"
|
||||
},
|
||||
{
|
||||
"index": 98,
|
||||
"color": "#FF9900",
|
||||
"label": "ottoman;pouf;pouffe;puff;hassock"
|
||||
},
|
||||
{
|
||||
"index": 99,
|
||||
"color": "#00FF0A",
|
||||
"label": "bottle"
|
||||
},
|
||||
{
|
||||
"index": 100,
|
||||
"color": "#FF7000",
|
||||
"label": "buffet;counter;sideboard"
|
||||
},
|
||||
{
|
||||
"index": 101,
|
||||
"color": "#8FFF00",
|
||||
"label": "poster;posting;placard;notice;bill;card"
|
||||
},
|
||||
{
|
||||
"index": 102,
|
||||
"color": "#5200FF",
|
||||
"label": "stage"
|
||||
},
|
||||
{
|
||||
"index": 103,
|
||||
"color": "#A3FF00",
|
||||
"label": "van"
|
||||
},
|
||||
{
|
||||
"index": 104,
|
||||
"color": "#FFEB00",
|
||||
"label": "ship"
|
||||
},
|
||||
{
|
||||
"index": 105,
|
||||
"color": "#08B8AA",
|
||||
"label": "fountain"
|
||||
},
|
||||
{
|
||||
"index": 106,
|
||||
"color": "#8500FF",
|
||||
"label": "conveyer;belt;conveyor;belt;conveyer;conveyor;transporter"
|
||||
},
|
||||
{
|
||||
"index": 107,
|
||||
"color": "#00FF5C",
|
||||
"label": "canopy"
|
||||
},
|
||||
{
|
||||
"index": 108,
|
||||
"color": "#B800FF",
|
||||
"label": "washer;automatic;washer;washing;machine"
|
||||
},
|
||||
{
|
||||
"index": 109,
|
||||
"color": "#FF001F",
|
||||
"label": "plaything;toy"
|
||||
},
|
||||
{
|
||||
"index": 110,
|
||||
"color": "#00B8FF",
|
||||
"label": "swimming;pool;swimming;bath;natatorium"
|
||||
},
|
||||
{
|
||||
"index": 111,
|
||||
"color": "#00D6FF",
|
||||
"label": "stool"
|
||||
},
|
||||
{
|
||||
"index": 112,
|
||||
"color": "#FF0070",
|
||||
"label": "barrel;cask"
|
||||
},
|
||||
{
|
||||
"index": 113,
|
||||
"color": "#5CFF00",
|
||||
"label": "basket;handbasket"
|
||||
},
|
||||
{
|
||||
"index": 114,
|
||||
"color": "#00E0FF",
|
||||
"label": "waterfall;falls"
|
||||
},
|
||||
{
|
||||
"index": 115,
|
||||
"color": "#70E0FF",
|
||||
"label": "tent;collapsible;shelter"
|
||||
},
|
||||
{
|
||||
"index": 116,
|
||||
"color": "#46B8A0",
|
||||
"label": "bag"
|
||||
},
|
||||
{
|
||||
"index": 117,
|
||||
"color": "#A300FF",
|
||||
"label": "minibike;motorbike"
|
||||
},
|
||||
{
|
||||
"index": 118,
|
||||
"color": "#9900FF",
|
||||
"label": "cradle"
|
||||
},
|
||||
{
|
||||
"index": 119,
|
||||
"color": "#47FF00",
|
||||
"label": "oven"
|
||||
},
|
||||
{
|
||||
"index": 120,
|
||||
"color": "#FF00A3",
|
||||
"label": "ball"
|
||||
},
|
||||
{
|
||||
"index": 121,
|
||||
"color": "#FFCC00",
|
||||
"label": "food;solid;food"
|
||||
},
|
||||
{
|
||||
"index": 122,
|
||||
"color": "#FF008F",
|
||||
"label": "step;stair"
|
||||
},
|
||||
{
|
||||
"index": 123,
|
||||
"color": "#00FFEB",
|
||||
"label": "tank;storage;tank"
|
||||
},
|
||||
{
|
||||
"index": 124,
|
||||
"color": "#85FF00",
|
||||
"label": "trade;name;brand;name;brand;marque"
|
||||
},
|
||||
{
|
||||
"index": 125,
|
||||
"color": "#FF00EB",
|
||||
"label": "microwave;microwave;oven"
|
||||
},
|
||||
{
|
||||
"index": 126,
|
||||
"color": "#F500FF",
|
||||
"label": "pot;flowerpot"
|
||||
},
|
||||
{
|
||||
"index": 127,
|
||||
"color": "#FF007A",
|
||||
"label": "animal;animate;being;beast;brute;creature;fauna"
|
||||
},
|
||||
{
|
||||
"index": 128,
|
||||
"color": "#FFF500",
|
||||
"label": "bicycle;bike;wheel;cycle"
|
||||
},
|
||||
{
|
||||
"index": 129,
|
||||
"color": "#0ABED4",
|
||||
"label": "lake"
|
||||
},
|
||||
{
|
||||
"index": 130,
|
||||
"color": "#D6FF00",
|
||||
"label": "dishwasher;dish;washer;dishwashing;machine"
|
||||
},
|
||||
{
|
||||
"index": 131,
|
||||
"color": "#00CCFF",
|
||||
"label": "screen;silver;screen;projection;screen"
|
||||
},
|
||||
{
|
||||
"index": 132,
|
||||
"color": "#1400FF",
|
||||
"label": "blanket;cover"
|
||||
},
|
||||
{
|
||||
"index": 133,
|
||||
"color": "#FFFF00",
|
||||
"label": "sculpture"
|
||||
},
|
||||
{
|
||||
"index": 134,
|
||||
"color": "#0099FF",
|
||||
"label": "hood;exhaust;hood"
|
||||
},
|
||||
{
|
||||
"index": 135,
|
||||
"color": "#0029FF",
|
||||
"label": "sconce"
|
||||
},
|
||||
{
|
||||
"index": 136,
|
||||
"color": "#00FFCC",
|
||||
"label": "vase"
|
||||
},
|
||||
{
|
||||
"index": 137,
|
||||
"color": "#2900FF",
|
||||
"label": "traffic;light;traffic;signal;stoplight"
|
||||
},
|
||||
{
|
||||
"index": 138,
|
||||
"color": "#29FF00",
|
||||
"label": "tray"
|
||||
},
|
||||
{
|
||||
"index": 139,
|
||||
"color": "#AD00FF",
|
||||
"label": "ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin"
|
||||
},
|
||||
{
|
||||
"index": 140,
|
||||
"color": "#00F5FF",
|
||||
"label": "fan"
|
||||
},
|
||||
{
|
||||
"index": 141,
|
||||
"color": "#4700FF",
|
||||
"label": "pier;wharf;wharfage;dock"
|
||||
},
|
||||
{
|
||||
"index": 142,
|
||||
"color": "#7A00FF",
|
||||
"label": "crt;screen"
|
||||
},
|
||||
{
|
||||
"index": 143,
|
||||
"color": "#00FFB8",
|
||||
"label": "plate"
|
||||
},
|
||||
{
|
||||
"index": 144,
|
||||
"color": "#005CFF",
|
||||
"label": "monitor;monitoring;device"
|
||||
},
|
||||
{
|
||||
"index": 145,
|
||||
"color": "#B8FF00",
|
||||
"label": "bulletin;board;notice;board"
|
||||
},
|
||||
{
|
||||
"index": 146,
|
||||
"color": "#0085FF",
|
||||
"label": "shower"
|
||||
},
|
||||
{
|
||||
"index": 147,
|
||||
"color": "#FFD600",
|
||||
"label": "radiator"
|
||||
},
|
||||
{
|
||||
"index": 148,
|
||||
"color": "#19C2C2",
|
||||
"label": "glass;drinking;glass"
|
||||
},
|
||||
{
|
||||
"index": 149,
|
||||
"color": "#66FF00",
|
||||
"label": "clock"
|
||||
},
|
||||
{
|
||||
"index": 150,
|
||||
"color": "#5C00FF",
|
||||
"label": "flag"
|
||||
}
|
||||
]
|
@ -1,155 +0,0 @@
|
||||
use candle::Device;
|
||||
use candle::Module;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::segformer::{
|
||||
Config, ImageClassificationModel, SemanticSegmentationModel,
|
||||
};
|
||||
use clap::{Args, Parser, Subcommand};
|
||||
use image::Rgb;
|
||||
use imageproc::integral_image::ArrayData;
|
||||
use std::collections::HashMap;
|
||||
use std::path::PathBuf;
|
||||
|
||||
#[derive(Parser)]
|
||||
#[clap(about, version, long_about = None)]
|
||||
struct CliArgs {
|
||||
#[arg(long, help = "use cpu")]
|
||||
cpu: bool,
|
||||
#[command(subcommand)]
|
||||
command: Commands,
|
||||
}
|
||||
#[derive(Args, Debug)]
|
||||
struct SegmentationArgs {
|
||||
#[arg(
|
||||
long,
|
||||
help = "name of the huggingface hub model",
|
||||
default_value = "nvidia/segformer-b0-finetuned-ade-512-512"
|
||||
)]
|
||||
model_name: String,
|
||||
#[arg(
|
||||
long,
|
||||
help = "path to the label file in json format",
|
||||
default_value = "candle-examples/examples/segformer/assets/labels.json"
|
||||
)]
|
||||
label_path: PathBuf,
|
||||
#[arg(long, help = "path to for the output mask image")]
|
||||
output_path: PathBuf,
|
||||
#[arg(help = "path to image as input")]
|
||||
image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Args, Debug)]
|
||||
struct ClassificationArgs {
|
||||
#[arg(
|
||||
long,
|
||||
help = "name of the huggingface hub model",
|
||||
default_value = "paolinox/segformer-finetuned-food101"
|
||||
)]
|
||||
model_name: String,
|
||||
#[arg(help = "path to image as input")]
|
||||
image: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Subcommand, Debug)]
|
||||
enum Commands {
|
||||
Segment(SegmentationArgs),
|
||||
Classify(ClassificationArgs),
|
||||
}
|
||||
|
||||
fn get_vb_and_config(model_name: String, device: &Device) -> anyhow::Result<(VarBuilder, Config)> {
|
||||
println!("loading model {} via huggingface hub", model_name);
|
||||
let api = hf_hub::api::sync::Api::new()?;
|
||||
let api = api.model(model_name.clone());
|
||||
let model_file = api.get("model.safetensors")?;
|
||||
println!("model {} downloaded and loaded", model_name);
|
||||
let vb =
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], candle::DType::F32, device)? };
|
||||
let config = std::fs::read_to_string(api.get("config.json")?)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
println!("{:?}", config);
|
||||
Ok((vb, config))
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LabelItem {
|
||||
index: u32,
|
||||
color: String,
|
||||
}
|
||||
|
||||
fn segmentation_task(args: SegmentationArgs, device: &Device) -> anyhow::Result<()> {
|
||||
let label_file = std::fs::read_to_string(&args.label_path)?;
|
||||
let label_items: Vec<LabelItem> = serde_json::from_str(&label_file)?;
|
||||
let label_colors: HashMap<u32, Rgb<u8>> = label_items
|
||||
.iter()
|
||||
.map(|x| {
|
||||
(x.index - 1, {
|
||||
let color = x.color.trim_start_matches('#');
|
||||
let r = u8::from_str_radix(&color[0..2], 16).unwrap();
|
||||
let g = u8::from_str_radix(&color[2..4], 16).unwrap();
|
||||
let b = u8::from_str_radix(&color[4..6], 16).unwrap();
|
||||
Rgb([r, g, b])
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?
|
||||
.unsqueeze(0)?
|
||||
.to_device(device)?;
|
||||
let (vb, config) = get_vb_and_config(args.model_name, device)?;
|
||||
let num_labels = label_items.len();
|
||||
|
||||
let model = SemanticSegmentationModel::new(&config, num_labels, vb)?;
|
||||
let segmentations = model.forward(&image)?;
|
||||
|
||||
// generate a mask image
|
||||
let mask = &segmentations.squeeze(0)?.argmax(0)?;
|
||||
let (h, w) = mask.dims2()?;
|
||||
let mask = mask.flatten_all()?.to_vec1::<u32>()?;
|
||||
let mask = mask
|
||||
.iter()
|
||||
.flat_map(|x| label_colors[x].data())
|
||||
.collect::<Vec<u8>>();
|
||||
let mask: image::ImageBuffer<image::Rgb<u8>, Vec<u8>> =
|
||||
image::ImageBuffer::from_raw(w as u32, h as u32, mask).unwrap();
|
||||
// resize
|
||||
let mask = image::DynamicImage::from(mask);
|
||||
let mask = mask.resize_to_fill(
|
||||
w as u32 * 4,
|
||||
h as u32 * 4,
|
||||
image::imageops::FilterType::CatmullRom,
|
||||
);
|
||||
mask.save(args.output_path.clone())?;
|
||||
println!("mask image saved to {:?}", args.output_path);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn classification_task(args: ClassificationArgs, device: &Device) -> anyhow::Result<()> {
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?
|
||||
.unsqueeze(0)?
|
||||
.to_device(device)?;
|
||||
let (vb, config) = get_vb_and_config(args.model_name, device)?;
|
||||
let num_labels = 7;
|
||||
let model = ImageClassificationModel::new(&config, num_labels, vb)?;
|
||||
let classification = model.forward(&image)?;
|
||||
let classification = candle_nn::ops::softmax_last_dim(&classification)?;
|
||||
let classification = classification.squeeze(0)?;
|
||||
println!(
|
||||
"classification logits {:?}",
|
||||
classification.to_vec1::<f32>()?
|
||||
);
|
||||
let label_id = classification.argmax(0)?.to_scalar::<u32>()?;
|
||||
let label_id = format!("{}", label_id);
|
||||
println!("label: {}", config.id2label[&label_id]);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = CliArgs::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
if let Commands::Segment(args) = args.command {
|
||||
segmentation_task(args, &device)?
|
||||
} else if let Commands::Classify(args) = args.command {
|
||||
classification_task(args, &device)?
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -57,7 +57,7 @@ The downside is some long compilation time. You can set the
`/home/user/.candle` to ensure that the compilation artifacts are properly
cached.

Enabling flash-attention requires both a feature flag, `--features flash-attn`
Enabling flash-attention requires both a feature flag, `--feature flash-attn`
and using the command line flag `--use-flash-attn`.

Note that flash-attention-v2 is only compatible with Ampere, Ada, or Hopper GPUs
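As a rough sketch of how the two pieces fit together (the example name and prompt below are illustrative placeholders, not taken from this diff):

```bash
# compile with the flash-attn feature, then opt in at runtime
cargo run --example llama --release --features flash-attn -- --use-flash-attn --prompt "The capital of France is"
```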
@ -96,10 +96,6 @@ struct Args {
|
||||
/// information.
|
||||
#[arg(long, default_value_t = 0.8)]
|
||||
img2img_strength: f64,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long)]
|
||||
seed: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, clap::ValueEnum, PartialEq, Eq)]
|
||||
@ -378,7 +374,6 @@ fn run(args: Args) -> Result<()> {
|
||||
use_flash_attn,
|
||||
img2img,
|
||||
img2img_strength,
|
||||
seed,
|
||||
..
|
||||
} = args;
|
||||
|
||||
@ -432,9 +427,6 @@ fn run(args: Args) -> Result<()> {
|
||||
|
||||
let scheduler = sd_config.build_scheduler(n_steps)?;
|
||||
let device = candle_examples::device(cpu)?;
|
||||
if let Some(seed) = seed {
|
||||
device.set_seed(seed)?;
|
||||
}
|
||||
let use_guide_scale = guidance_scale > 1.0;
|
||||
|
||||
let which = match sd_version {
|
||||
|
@ -10,6 +10,11 @@ order to be able to use it.

Other available models are Stable-Code-3B, StableLM-2 and Zephyr variants.

StableLM-2 uses a Tiktoken based GPT-3.5/GPT-4 tokenizer not supported by
Candle, so to run it you can download a somewhat compatible
[tokenizer.json](https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true)
and pass it via the --tokenizer-file argument.
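A possible end-to-end invocation, assuming the `stable-lm` example and an arbitrary local filename for the downloaded tokenizer (the prompt and flags simply follow the pattern of the other candle examples):

```bash
# fetch the GPT-4 compatible tokenizer, then point the example at it
curl -L -o tokenizer-gpt4.json "https://huggingface.co/Xenova/gpt-4/resolve/main/tokenizer.json?download=true"
cargo run --example stable-lm --release -- --tokenizer-file tokenizer-gpt4.json --prompt "Hello"
```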

## Running some example

```bash
@ -239,7 +239,14 @@ fn main() -> Result<()> {
|
||||
));
|
||||
let tokenizer_filename = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
None => match args.which {
|
||||
Which::V1Orig | Which::V1 | Which::V1Zephyr | Which::Code => {
|
||||
repo.get("tokenizer.json")?
|
||||
}
|
||||
Which::V2 | Which::V2Zephyr => api
|
||||
.model("lmz/candle-stablelm".to_string())
|
||||
.get("tokenizer-gpt4.json")?,
|
||||
},
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
|
@ -1,253 +0,0 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use clap::Parser;
|
||||
|
||||
use candle_transformers::models::starcoder2::Model;
|
||||
|
||||
use candle::{DType, Device, Tensor};
|
||||
use candle_examples::token_output_stream::TokenOutputStream;
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::generation::LogitsProcessor;
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
struct TextGeneration {
|
||||
model: Model,
|
||||
device: Device,
|
||||
tokenizer: TokenOutputStream,
|
||||
logits_processor: LogitsProcessor,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
impl TextGeneration {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn new(
|
||||
model: Model,
|
||||
tokenizer: Tokenizer,
|
||||
seed: u64,
|
||||
temp: Option<f64>,
|
||||
top_p: Option<f64>,
|
||||
repeat_penalty: f32,
|
||||
repeat_last_n: usize,
|
||||
device: &Device,
|
||||
) -> Self {
|
||||
let logits_processor = LogitsProcessor::new(seed, temp, top_p);
|
||||
Self {
|
||||
model,
|
||||
tokenizer: TokenOutputStream::new(tokenizer),
|
||||
logits_processor,
|
||||
repeat_penalty,
|
||||
repeat_last_n,
|
||||
device: device.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> {
|
||||
use std::io::Write;
|
||||
self.tokenizer.clear();
|
||||
let mut tokens = self
|
||||
.tokenizer
|
||||
.tokenizer()
|
||||
.encode(prompt, true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
for &t in tokens.iter() {
|
||||
if let Some(t) = self.tokenizer.next_token(t)? {
|
||||
print!("{t}")
|
||||
}
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
|
||||
let mut generated_tokens = 0usize;
|
||||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") {
|
||||
Some(token) => token,
|
||||
None => anyhow::bail!("cannot find the <|endoftext|> token"),
|
||||
};
|
||||
let start_gen = std::time::Instant::now();
|
||||
for index in 0..sample_len {
|
||||
let context_size = if index > 0 { 1 } else { tokens.len() };
|
||||
let start_pos = tokens.len().saturating_sub(context_size);
|
||||
let ctxt = &tokens[start_pos..];
|
||||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
|
||||
let logits = self.model.forward(&input, start_pos)?;
|
||||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
|
||||
let logits = if self.repeat_penalty == 1. {
|
||||
logits
|
||||
} else {
|
||||
let start_at = tokens.len().saturating_sub(self.repeat_last_n);
|
||||
candle_transformers::utils::apply_repeat_penalty(
|
||||
&logits,
|
||||
self.repeat_penalty,
|
||||
&tokens[start_at..],
|
||||
)?
|
||||
};
|
||||
|
||||
let next_token = self.logits_processor.sample(&logits)?;
|
||||
tokens.push(next_token);
|
||||
generated_tokens += 1;
|
||||
if next_token == eos_token {
|
||||
break;
|
||||
}
|
||||
if let Some(t) = self.tokenizer.next_token(next_token)? {
|
||||
print!("{t}");
|
||||
std::io::stdout().flush()?;
|
||||
}
|
||||
}
|
||||
let dt = start_gen.elapsed();
|
||||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? {
|
||||
print!("{rest}");
|
||||
}
|
||||
std::io::stdout().flush()?;
|
||||
println!(
|
||||
"\n{generated_tokens} tokens generated ({:.2} token/s)",
|
||||
generated_tokens as f64 / dt.as_secs_f64(),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// Enable tracing (generates a trace-timestamp.json file).
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long)]
|
||||
use_flash_attn: bool,
|
||||
|
||||
#[arg(long)]
|
||||
prompt: String,
|
||||
|
||||
/// The temperature used to generate samples.
|
||||
#[arg(long)]
|
||||
temperature: Option<f64>,
|
||||
|
||||
/// Nucleus sampling probability cutoff.
|
||||
#[arg(long)]
|
||||
top_p: Option<f64>,
|
||||
|
||||
/// The seed to use when generating random samples.
|
||||
#[arg(long, default_value_t = 299792458)]
|
||||
seed: u64,
|
||||
|
||||
/// The length of the sample to generate (in tokens).
|
||||
#[arg(long, short = 'n', default_value_t = 10000)]
|
||||
sample_len: usize,
|
||||
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
#[arg(long, default_value = "main")]
|
||||
revision: String,
|
||||
|
||||
#[arg(long)]
|
||||
config_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
tokenizer_file: Option<String>,
|
||||
|
||||
#[arg(long)]
|
||||
weight_files: Option<String>,
|
||||
|
||||
/// Penalty to be applied for repeating tokens, 1. means no penalty.
|
||||
#[arg(long, default_value_t = 1.1)]
|
||||
repeat_penalty: f32,
|
||||
|
||||
/// The context size to consider for the repeat penalty.
|
||||
#[arg(long, default_value_t = 64)]
|
||||
repeat_last_n: usize,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
println!(
|
||||
"avx: {}, neon: {}, simd128: {}, f16c: {}",
|
||||
candle::utils::with_avx(),
|
||||
candle::utils::with_neon(),
|
||||
candle::utils::with_simd128(),
|
||||
candle::utils::with_f16c()
|
||||
);
|
||||
println!(
|
||||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}",
|
||||
args.temperature.unwrap_or(0.),
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n
|
||||
);
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let api = Api::new()?;
|
||||
let model_id = match args.model_id {
|
||||
Some(model_id) => model_id,
|
||||
None => "bigcode/starcoder2-3b".to_string(),
|
||||
};
|
||||
let repo = api.repo(Repo::with_revision(
|
||||
model_id,
|
||||
RepoType::Model,
|
||||
args.revision,
|
||||
));
|
||||
let config_file = match args.config_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("config.json")?,
|
||||
};
|
||||
let tokenizer_file = match args.tokenizer_file {
|
||||
Some(file) => std::path::PathBuf::from(file),
|
||||
None => repo.get("tokenizer.json")?,
|
||||
};
|
||||
let filenames = match args.weight_files {
|
||||
Some(files) => files
|
||||
.split(',')
|
||||
.map(std::path::PathBuf::from)
|
||||
.collect::<Vec<_>>(),
|
||||
None => vec![repo.get("model.safetensors")?],
|
||||
};
|
||||
println!("retrieved the files in {:?}", start.elapsed());
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_file).map_err(E::msg)?;
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let config = serde_json::from_reader(std::fs::File::open(config_file)?)?;
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let dtype = if device.is_cuda() {
|
||||
DType::BF16
|
||||
} else {
|
||||
DType::F32
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
println!("loaded the model in {:?}", start.elapsed());
|
||||
|
||||
let mut pipeline = TextGeneration::new(
|
||||
model,
|
||||
tokenizer,
|
||||
args.seed,
|
||||
args.temperature,
|
||||
args.top_p,
|
||||
args.repeat_penalty,
|
||||
args.repeat_last_n,
|
||||
&device,
|
||||
);
|
||||
pipeline.run(&args.prompt, args.sample_len)?;
|
||||
Ok(())
|
||||
}
|
@ -33,7 +33,7 @@ struct Args {
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
|
@ -28,7 +28,7 @@ pub fn main() -> anyhow::Result<()> {
|
||||
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?.to_device(&device)?;
|
||||
let image = candle_examples::imagenet::load_image224(args.image)?;
|
||||
println!("loaded image {image:?}");
|
||||
|
||||
let model_file = match args.model {
|
||||
|
@ -1,29 +0,0 @@
|
||||
use candle::{Result, Tensor};
|
||||
|
||||
// https://github.com/facebookresearch/audiocraft/blob/69fea8b290ad1b4b40d28f92d1dfc0ab01dbab85/audiocraft/data/audio_utils.py#L57
|
||||
pub fn normalize_loudness(
|
||||
wav: &Tensor,
|
||||
sample_rate: u32,
|
||||
loudness_compressor: bool,
|
||||
) -> Result<Tensor> {
|
||||
let energy = wav.sqr()?.mean_all()?.sqrt()?.to_vec0::<f32>()?;
|
||||
if energy < 2e-3 {
|
||||
return Ok(wav.clone());
|
||||
}
|
||||
let wav_array = wav.to_vec1::<f32>()?;
|
||||
let mut meter = crate::bs1770::ChannelLoudnessMeter::new(sample_rate);
|
||||
meter.push(wav_array.into_iter());
|
||||
let power = meter.as_100ms_windows();
|
||||
let loudness = match crate::bs1770::gated_mean(power) {
|
||||
None => return Ok(wav.clone()),
|
||||
Some(gp) => gp.loudness_lkfs() as f64,
|
||||
};
|
||||
let delta_loudness = -14. - loudness;
|
||||
let gain = 10f64.powf(delta_loudness / 20.);
|
||||
let wav = (wav * gain)?;
|
||||
if loudness_compressor {
|
||||
wav.tanh()
|
||||
} else {
|
||||
Ok(wav)
|
||||
}
|
||||
}
|
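For orientation, a minimal standalone sketch of the gain computation performed by `normalize_loudness` above: the measured loudness is pulled toward a -14 LUFS target via a 20*log10 gain. The helper name and the example loudness value are illustrative and not part of the example crate.

```rust
// Hypothetical helper: linear gain that moves a measured loudness to a target.
fn gain_to_target(measured_lufs: f64, target_lufs: f64) -> f64 {
    let delta_db = target_lufs - measured_lufs;
    10f64.powf(delta_db / 20.0)
}

fn main() {
    // A signal measured at -20 LUFS needs roughly a 2x amplitude boost to reach -14.
    let gain = gain_to_target(-20.0, -14.0);
    assert!((gain - 1.995).abs() < 0.01);
    println!("gain: {gain:.3}");
}
```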
@ -1,506 +0,0 @@
|
||||
// Copied from https://github.com/ruuda/bs1770/blob/master/src/lib.rs
|
||||
// BS1770 -- Loudness analysis library conforming to ITU-R BS.1770
|
||||
// Copyright 2020 Ruud van Asseldonk
|
||||
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// A copy of the License has been included in the root of the repository.
|
||||
|
||||
//! Loudness analysis conforming to [ITU-R BS.1770-4][bs17704].
|
||||
//!
|
||||
//! This library offers the building blocks to perform BS.1770 loudness
|
||||
//! measurements, but you need to put the pieces together yourself.
|
||||
//!
|
||||
//! [bs17704]: https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
|
||||
//!
|
||||
//! # Stereo integrated loudness example
|
||||
//!
|
||||
//! ```ignore
|
||||
//! # fn load_stereo_audio() -> [Vec<i16>; 2] {
|
||||
//! # [vec![0; 48_000], vec![0; 48_000]]
|
||||
//! # }
|
||||
//! #
|
||||
//! let sample_rate_hz = 44_100;
|
||||
//! let bits_per_sample = 16;
|
||||
//! let channel_samples: [Vec<i16>; 2] = load_stereo_audio();
|
||||
//!
|
||||
//! // When converting integer samples to float, note that the maximum amplitude
|
||||
//! // is `1 << (bits_per_sample - 1)`, one bit is the sign bit.
|
||||
//! let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
|
||||
//!
|
||||
//! let channel_power: Vec<_> = channel_samples.iter().map(|samples| {
|
||||
//! let mut meter = bs1770::ChannelLoudnessMeter::new(sample_rate_hz);
|
||||
//! meter.push(samples.iter().map(|&s| s as f32 * normalizer));
|
||||
//! meter.into_100ms_windows()
|
||||
//! }).collect();
|
||||
//!
|
||||
//! let stereo_power = bs1770::reduce_stereo(
|
||||
//! channel_power[0].as_ref(),
|
||||
//! channel_power[1].as_ref(),
|
||||
//! );
|
||||
//!
|
||||
//! let gated_power = bs1770::gated_mean(
|
||||
//! stereo_power.as_ref()
|
||||
//! ).unwrap_or(bs1770::Power(0.0));
|
||||
//! println!("Integrated loudness: {:.1} LUFS", gated_power.loudness_lkfs());
|
||||
//! ```
|
||||
|
||||
use std::f32;
|
||||
|
||||
/// Coefficients for a 2nd-degree infinite impulse response filter.
|
||||
///
|
||||
/// Coefficient a0 is implicitly 1.0.
|
||||
#[derive(Clone)]
|
||||
struct Filter {
|
||||
a1: f32,
|
||||
a2: f32,
|
||||
b0: f32,
|
||||
b1: f32,
|
||||
b2: f32,
|
||||
|
||||
// The past two input and output samples.
|
||||
x1: f32,
|
||||
x2: f32,
|
||||
y1: f32,
|
||||
y2: f32,
|
||||
}
|
||||
|
||||
impl Filter {
|
||||
/// Stage 1 of the BS.1770-4 pre-filter.
|
||||
pub fn high_shelf(sample_rate_hz: f32) -> Filter {
|
||||
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
|
||||
let gain_db = 3.999_843_8;
|
||||
let q = 0.707_175_25;
|
||||
let center_hz = 1_681.974_5;
|
||||
|
||||
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L134-L143.
|
||||
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
|
||||
let vh = 10.0_f32.powf(gain_db / 20.0);
|
||||
let vb = vh.powf(0.499_666_78);
|
||||
let a0 = 1.0 + k / q + k * k;
|
||||
Filter {
|
||||
b0: (vh + vb * k / q + k * k) / a0,
|
||||
b1: 2.0 * (k * k - vh) / a0,
|
||||
b2: (vh - vb * k / q + k * k) / a0,
|
||||
a1: 2.0 * (k * k - 1.0) / a0,
|
||||
a2: (1.0 - k / q + k * k) / a0,
|
||||
|
||||
x1: 0.0,
|
||||
x2: 0.0,
|
||||
y1: 0.0,
|
||||
y2: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Stage 2 of the BS.1770-4 pre-filter.
|
||||
pub fn high_pass(sample_rate_hz: f32) -> Filter {
|
||||
// Coefficients taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/meter.py#L135-L136.
|
||||
let q = 0.500_327_05;
|
||||
let center_hz = 38.135_47;
|
||||
|
||||
// Formula taken from https://github.com/csteinmetz1/pyloudnorm/blob/
|
||||
// 6baa64d59b7794bc812e124438692e7fd2e65c0c/pyloudnorm/iirfilter.py#L145-L151
|
||||
let k = (f32::consts::PI * center_hz / sample_rate_hz).tan();
|
||||
Filter {
|
||||
a1: 2.0 * (k * k - 1.0) / (1.0 + k / q + k * k),
|
||||
a2: (1.0 - k / q + k * k) / (1.0 + k / q + k * k),
|
||||
b0: 1.0,
|
||||
b1: -2.0,
|
||||
b2: 1.0,
|
||||
|
||||
x1: 0.0,
|
||||
x2: 0.0,
|
||||
y1: 0.0,
|
||||
y2: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed the next input sample, get the next output sample.
|
||||
#[inline(always)]
|
||||
pub fn apply(&mut self, x0: f32) -> f32 {
|
||||
let y0 = 0.0 + self.b0 * x0 + self.b1 * self.x1 + self.b2 * self.x2
|
||||
- self.a1 * self.y1
|
||||
- self.a2 * self.y2;
|
||||
|
||||
self.x2 = self.x1;
|
||||
self.x1 = x0;
|
||||
self.y2 = self.y1;
|
||||
self.y1 = y0;
|
||||
|
||||
y0
|
||||
}
|
||||
}
|
||||
|
||||
/// Compensated sum, for summing many values of different orders of magnitude
|
||||
/// accurately.
|
||||
#[derive(Copy, Clone, PartialEq)]
|
||||
struct Sum {
|
||||
sum: f32,
|
||||
residue: f32,
|
||||
}
|
||||
|
||||
impl Sum {
|
||||
#[inline(always)]
|
||||
fn zero() -> Sum {
|
||||
Sum {
|
||||
sum: 0.0,
|
||||
residue: 0.0,
|
||||
}
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn add(&mut self, x: f32) {
|
||||
let sum = self.sum + (self.residue + x);
|
||||
self.residue = (self.residue + x) - (sum - self.sum);
|
||||
self.sum = sum;
|
||||
}
|
||||
}
|
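`Sum::add` above is a compensated (Kahan-style) summation. A small standalone sketch, with made-up numbers, of why the residue matters when accumulating many tiny values into a large f32 sum:

```rust
fn main() {
    let x = 0.01_f32;
    let mut naive = 1.0e8_f32;
    let mut sum = 1.0e8_f32;
    let mut residue = 0.0_f32;
    for _ in 0..1_000_000 {
        // Naive accumulation: 0.01 is below half an ulp of 1e8, so it is simply lost.
        naive += x;
        // Compensated accumulation, mirroring Sum::add above.
        let new_sum = sum + (residue + x);
        residue = (residue + x) - (new_sum - sum);
        sum = new_sum;
    }
    // naive is still 1e8; the compensated sum lands close to 1.0001e8.
    println!("naive: {naive}, compensated: {sum}");
}
```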
||||
|
||||
/// The mean of the squares of the K-weighted samples in a window of time.
|
||||
///
|
||||
/// K-weighted power is equivalent to K-weighted loudness, the only difference
|
||||
/// is one of scale: power is quadratic in sample amplitudes, whereas loudness
|
||||
/// units are logarithmic. `loudness_lkfs` and `from_lkfs` convert between power,
|
||||
/// and K-weighted Loudness Units relative to nominal Full Scale (LKFS).
|
||||
///
|
||||
/// The term “LKFS” (Loudness Units, K-Weighted, relative to nominal Full Scale)
|
||||
/// is used in BS.1770-4 to emphasize K-weighting, but the term is otherwise
|
||||
/// interchangeable with the more widespread term “LUFS” (Loudness Units,
|
||||
/// relative to Full Scale). Loudness units are related to decibels in the
|
||||
/// following sense: boosting a signal that has a loudness of
|
||||
/// -<var>L<sub>K</sub></var> LUFS by <var>L<sub>K</sub></var> dB (by
|
||||
/// multiplying the amplitude by 10<sup><var>L<sub>K</sub></var>/20</sup>) will
|
||||
/// bring the loudness to 0 LUFS.
|
||||
///
|
||||
/// K-weighting refers to a high-shelf and high-pass filter that model the
|
||||
/// effect that humans perceive a certain amount of power in low frequencies to
|
||||
/// be less loud than the same amount of power in higher frequencies. In this
|
||||
/// library the `Power` type is used exclusively to refer to power after applying K-weighting.
|
||||
///
|
||||
/// The nominal “full scale” is the range [-1.0, 1.0]. Because the power is the
|
||||
/// mean square of the samples, if no input samples exceeded the full scale, the
|
||||
/// power will be in the range [0.0, 1.0]. However, the power delivered by
|
||||
/// multiple channels, which is a weighted sum over individual channel powers,
|
||||
/// can exceed this range, because the weighted sum is not normalized.
|
||||
#[derive(Copy, Clone, PartialEq, PartialOrd)]
|
||||
pub struct Power(pub f32);
|
||||
|
||||
impl Power {
|
||||
/// Convert Loudness Units relative to Full Scale into a squared sample amplitude.
|
||||
///
|
||||
/// This is the inverse of `loudness_lkfs`.
|
||||
pub fn from_lkfs(lkfs: f32) -> Power {
|
||||
// The inverse of the formula below.
|
||||
Power(10.0_f32.powf((lkfs + 0.691) * 0.1))
|
||||
}
|
||||
|
||||
/// Return the loudness of this window in Loudness Units, K-weighted, relative to Full Scale.
|
||||
///
|
||||
/// This is the inverse of `from_lkfs`.
|
||||
pub fn loudness_lkfs(&self) -> f32 {
|
||||
// Equation 2 (p.5) of BS.1770-4.
|
||||
-0.691 + 10.0 * self.0.log10()
|
||||
}
|
||||
}
|
||||
|
||||
/// A `T` value for non-overlapping windows of audio, 100ms in length.
|
||||
///
|
||||
/// The `ChannelLoudnessMeter` applies K-weighting and then produces the power
|
||||
/// for non-overlapping windows of 100ms duration.
|
||||
///
|
||||
/// These non-overlapping 100ms windows can later be combined into overlapping
|
||||
/// windows of 400ms, spaced 100ms apart, to compute instantaneous loudness or
|
||||
/// to perform a gated measurement, or they can be combined into even larger
|
||||
/// windows for a momentary loudness measurement.
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Windows100ms<T> {
|
||||
pub inner: T,
|
||||
}
|
||||
|
||||
impl<T> Windows100ms<T> {
|
||||
/// Wrap a new empty vector.
|
||||
pub fn new() -> Windows100ms<Vec<T>> {
|
||||
Windows100ms { inner: Vec::new() }
|
||||
}
|
||||
|
||||
/// Apply `as_ref` to the inner value.
|
||||
pub fn as_ref(&self) -> Windows100ms<&[Power]>
|
||||
where
|
||||
T: AsRef<[Power]>,
|
||||
{
|
||||
Windows100ms {
|
||||
inner: self.inner.as_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply `as_mut` to the inner value.
|
||||
pub fn as_mut(&mut self) -> Windows100ms<&mut [Power]>
|
||||
where
|
||||
T: AsMut<[Power]>,
|
||||
{
|
||||
Windows100ms {
|
||||
inner: self.inner.as_mut(),
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::len_without_is_empty)]
|
||||
/// Apply `len` to the inner value.
|
||||
pub fn len(&self) -> usize
|
||||
where
|
||||
T: AsRef<[Power]>,
|
||||
{
|
||||
self.inner.as_ref().len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Measures K-weighted power of non-overlapping 100ms windows of a single channel of audio.
|
||||
///
|
||||
/// # Output
|
||||
///
|
||||
/// The output of the meter is an intermediate result in the form of power for
|
||||
/// 100ms non-overlapping windows. The windows need to be processed further to
|
||||
/// get one of the instantaneous, momentary, and integrated loudness
|
||||
/// measurements defined in BS.1770.
|
||||
///
|
||||
/// The windows can also be inspected directly; the data is meaningful
|
||||
/// on its own (the K-weighted power delivered in that window of time), but it
|
||||
/// is not something that BS.1770 defines a term for.
|
||||
///
|
||||
/// # Multichannel audio
|
||||
///
|
||||
/// To perform a loudness measurement of multichannel audio, construct a
|
||||
/// `ChannelLoudnessMeter` per channel, and later combine the measured power
|
||||
/// with e.g. `reduce_stereo`.
|
||||
///
|
||||
/// # Instantaneous loudness
|
||||
///
|
||||
/// The instantaneous loudness is the power over a 400ms window, so you can
|
||||
/// average four 100ms windows. No special functionality is implemented to help
|
||||
/// with that at this time. ([Pull requests would be accepted.][contribute])
|
||||
///
|
||||
/// # Momentary loudness
|
||||
///
|
||||
/// The momentary loudness is the power over a 3-second window, so you can
|
||||
/// average thirty 100ms windows. No special functionality is implemented to
|
||||
/// help with that at this time. ([Pull requests would be accepted.][contribute])
|
||||
///
|
||||
/// # Integrated loudness
|
||||
///
|
||||
/// Use `gated_mean` to perform an integrated loudness measurement:
|
||||
///
|
||||
/// ```ignore
|
||||
/// # use std::iter;
|
||||
/// # use bs1770::{ChannelLoudnessMeter, gated_mean};
|
||||
/// # let sample_rate_hz = 44_100;
|
||||
/// # let samples_per_100ms = sample_rate_hz / 10;
|
||||
/// # let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
|
||||
/// # meter.push((0..44_100).map(|i| (i as f32 * 0.01).sin()));
|
||||
/// let integrated_loudness_lkfs = gated_mean(meter.as_100ms_windows())
|
||||
/// .unwrap_or(bs1770::Power(0.0))
|
||||
/// .loudness_lkfs();
|
||||
/// ```
|
||||
///
|
||||
/// [contribute]: https://github.com/ruuda/bs1770/blob/master/CONTRIBUTING.md
|
||||
#[derive(Clone)]
|
||||
pub struct ChannelLoudnessMeter {
|
||||
/// The number of samples that fit in 100ms of audio.
|
||||
samples_per_100ms: u32,
|
||||
|
||||
/// Stage 1 filter (head effects, high shelf).
|
||||
filter_stage1: Filter,
|
||||
|
||||
/// Stage 2 filter (high-pass).
|
||||
filter_stage2: Filter,
|
||||
|
||||
/// Sum of the squares over non-overlapping windows of 100ms.
|
||||
windows: Windows100ms<Vec<Power>>,
|
||||
|
||||
/// The number of samples in the current unfinished window.
|
||||
count: u32,
|
||||
|
||||
/// The sum of the squares of the samples in the current unfinished window.
|
||||
square_sum: Sum,
|
||||
}
|
||||
|
||||
impl ChannelLoudnessMeter {
|
||||
/// Construct a new loudness meter for the given sample rate.
|
||||
pub fn new(sample_rate_hz: u32) -> ChannelLoudnessMeter {
|
||||
ChannelLoudnessMeter {
|
||||
samples_per_100ms: sample_rate_hz / 10,
|
||||
filter_stage1: Filter::high_shelf(sample_rate_hz as f32),
|
||||
filter_stage2: Filter::high_pass(sample_rate_hz as f32),
|
||||
windows: Windows100ms::new(),
|
||||
count: 0,
|
||||
square_sum: Sum::zero(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Feed input samples for loudness analysis.
|
||||
///
|
||||
/// # Full scale
|
||||
///
|
||||
/// Full scale for the input samples is the interval [-1.0, 1.0]. If your
|
||||
/// input consists of signed integer samples, you can convert as follows:
|
||||
///
|
||||
/// ```ignore
|
||||
/// # let mut meter = bs1770::ChannelLoudnessMeter::new(44_100);
|
||||
/// # let bits_per_sample = 16_usize;
|
||||
/// # let samples = &[0_i16];
|
||||
/// // Note that the maximum amplitude is `1 << (bits_per_sample - 1)`,
|
||||
/// // one bit is the sign bit.
|
||||
/// let normalizer = 1.0 / (1_u64 << (bits_per_sample - 1)) as f32;
|
||||
/// meter.push(samples.iter().map(|&s| s as f32 * normalizer));
|
||||
/// ```
|
||||
///
|
||||
/// # Repeated calls
|
||||
///
|
||||
/// You can call `push` multiple times to feed multiple batches of samples.
|
||||
/// This is equivalent to feeding a single chained iterator. The leftover
|
||||
/// samples that did not fill a full 100ms window are not discarded:
|
||||
///
|
||||
/// ```ignore
|
||||
/// # use std::iter;
|
||||
/// # use bs1770::ChannelLoudnessMeter;
|
||||
/// let sample_rate_hz = 44_100;
|
||||
/// let samples_per_100ms = sample_rate_hz / 10;
|
||||
/// let mut meter = ChannelLoudnessMeter::new(sample_rate_hz);
|
||||
///
|
||||
/// meter.push(iter::repeat(0.0).take(samples_per_100ms as usize - 1));
|
||||
/// assert_eq!(meter.as_100ms_windows().len(), 0);
|
||||
///
|
||||
/// meter.push(iter::once(0.0));
|
||||
/// assert_eq!(meter.as_100ms_windows().len(), 1);
|
||||
/// ```
|
||||
pub fn push<I: Iterator<Item = f32>>(&mut self, samples: I) {
|
||||
let normalizer = 1.0 / self.samples_per_100ms as f32;
|
||||
|
||||
// LLVM, if you could go ahead and inline those apply calls, and then
|
||||
// unroll and vectorize the loop, that'd be terrific.
|
||||
for x in samples {
|
||||
let y = self.filter_stage1.apply(x);
|
||||
let z = self.filter_stage2.apply(y);
|
||||
|
||||
self.square_sum.add(z * z);
|
||||
self.count += 1;
|
||||
|
||||
// TODO: Should this branch be marked cold?
|
||||
if self.count == self.samples_per_100ms {
|
||||
let mean_squares = Power(self.square_sum.sum * normalizer);
|
||||
self.windows.inner.push(mean_squares);
|
||||
// We intentionally do not reset the residue. That way, leftover
|
||||
// energy from this window is not lost, so for the file overall,
|
||||
// the sum remains more accurate.
|
||||
self.square_sum.sum = 0.0;
|
||||
self.count = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Return a reference to the 100ms windows analyzed so far.
|
||||
pub fn as_100ms_windows(&self) -> Windows100ms<&[Power]> {
|
||||
self.windows.as_ref()
|
||||
}
|
||||
|
||||
/// Return all 100ms windows analyzed so far.
|
||||
pub fn into_100ms_windows(self) -> Windows100ms<Vec<Power>> {
|
||||
self.windows
|
||||
}
|
||||
}
|
||||
|
||||
/// Combine power for multiple channels by taking a weighted sum.
|
||||
///
|
||||
/// Note that BS.1770-4 defines power for a multi-channel signal as a weighted
|
||||
/// sum over channels which is not normalized. This means that a stereo signal
|
||||
/// is inherently louder than a mono signal. For a mono signal played back on
|
||||
/// stereo speakers, you should therefore still apply `reduce_stereo`, passing
|
||||
/// in the same signal for both channels.
|
||||
pub fn reduce_stereo(
|
||||
left: Windows100ms<&[Power]>,
|
||||
right: Windows100ms<&[Power]>,
|
||||
) -> Windows100ms<Vec<Power>> {
|
||||
assert_eq!(
|
||||
left.len(),
|
||||
right.len(),
|
||||
"Channels must have the same length."
|
||||
);
|
||||
let mut result = Vec::with_capacity(left.len());
|
||||
for (l, r) in left.inner.iter().zip(right.inner) {
|
||||
result.push(Power(l.0 + r.0));
|
||||
}
|
||||
Windows100ms { inner: result }
|
||||
}
|
||||
|
||||
/// In-place version of `reduce_stereo` that stores the result in the former left channel.
|
||||
pub fn reduce_stereo_in_place(left: Windows100ms<&mut [Power]>, right: Windows100ms<&[Power]>) {
|
||||
assert_eq!(
|
||||
left.len(),
|
||||
right.len(),
|
||||
"Channels must have the same length."
|
||||
);
|
||||
for (l, r) in left.inner.iter_mut().zip(right.inner) {
|
||||
l.0 += r.0;
|
||||
}
|
||||
}
|
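A standalone sketch of the point made in the `reduce_stereo` comment above: the channel sum is not normalized, so feeding the same mono signal into both channels doubles the per-window power, roughly a +3 dB increase in loudness. Plain f32 values stand in for the `Power` wrapper and are invented.

```rust
fn main() {
    // Per-window K-weighted powers of a mono signal (illustrative values).
    let mono_windows = [0.1_f32, 0.2, 0.15];
    // reduce_stereo with the same signal on both channels: powers simply add.
    let stereo: Vec<f32> = mono_windows.iter().map(|p| p + p).collect();
    let db_gain = 10.0 * (stereo[0] / mono_windows[0]).log10();
    println!("loudness increase: {db_gain:.2} dB"); // about 3.01 dB
}
```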
||||
|
||||
/// Perform gating and averaging for a BS.1770-4 integrated loudness measurement.
|
||||
///
|
||||
/// The integrated loudness measurement is not just the average power over the
|
||||
/// entire signal. BS.1770-4 defines two stages of gating that exclude
|
||||
/// parts of the signal, to ensure that silent parts do not contribute to the
|
||||
/// loudness measurement. This function performs that gating, and returns the
|
||||
/// average power over the windows that were not excluded.
|
||||
///
|
||||
/// The result of this function is the integrated loudness measurement.
|
||||
///
|
||||
/// When no signal remains after applying the gate, this function returns
|
||||
/// `None`. In particular, this happens when all of the signal is softer than
|
||||
/// -70 LKFS, including a signal that consists of pure silence.
|
||||
pub fn gated_mean(windows_100ms: Windows100ms<&[Power]>) -> Option<Power> {
|
||||
let mut gating_blocks = Vec::with_capacity(windows_100ms.len());
|
||||
|
||||
// Stage 1: an absolute threshold of -70 LKFS. (Equation 6, p.6.)
|
||||
let absolute_threshold = Power::from_lkfs(-70.0);
|
||||
|
||||
// Iterate over all 400ms windows.
|
||||
for window in windows_100ms.inner.windows(4) {
|
||||
// Note that the sum over channels has already been performed at this point.
|
||||
let gating_block_power = Power(0.25 * window.iter().map(|mean| mean.0).sum::<f32>());
|
||||
|
||||
if gating_block_power > absolute_threshold {
|
||||
gating_blocks.push(gating_block_power);
|
||||
}
|
||||
}
|
||||
|
||||
if gating_blocks.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Compute the loudness after applying the absolute gate, in order to
|
||||
// determine the threshold for the relative gate.
|
||||
let mut sum_power = Sum::zero();
|
||||
for &gating_block_power in &gating_blocks {
|
||||
sum_power.add(gating_block_power.0);
|
||||
}
|
||||
let absolute_gated_power = Power(sum_power.sum / (gating_blocks.len() as f32));
|
||||
|
||||
// Stage 2: Apply the relative gate.
|
||||
let relative_threshold = Power::from_lkfs(absolute_gated_power.loudness_lkfs() - 10.0);
|
||||
let mut sum_power = Sum::zero();
|
||||
let mut n_blocks = 0_usize;
|
||||
for &gating_block_power in &gating_blocks {
|
||||
if gating_block_power > relative_threshold {
|
||||
sum_power.add(gating_block_power.0);
|
||||
n_blocks += 1;
|
||||
}
|
||||
}
|
||||
|
||||
if n_blocks == 0 {
|
||||
return None;
|
||||
}
|
||||
|
||||
let relative_gated_power = Power(sum_power.sum / n_blocks as f32);
|
||||
Some(relative_gated_power)
|
||||
}
|
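A standalone sketch of the two gating stages described above, using plain f32 block powers instead of `Power`/`Windows100ms`. The thresholds follow the comments in `gated_mean` (-70 LKFS absolute, relative gate 10 LU below the absolutely-gated loudness); the helper names and block values are invented.

```rust
fn lkfs(power: f32) -> f32 {
    -0.691 + 10.0 * power.log10()
}

fn power(lkfs: f32) -> f32 {
    10.0_f32.powf((lkfs + 0.691) * 0.1)
}

fn main() {
    // Invented 400ms block powers at -20, -25, -80 and -22 LKFS.
    let blocks = [power(-20.0), power(-25.0), power(-80.0), power(-22.0)];
    // Stage 1: the absolute -70 LKFS gate drops the near-silent block.
    let abs_gate = power(-70.0);
    let stage1: Vec<f32> = blocks.iter().copied().filter(|&p| p > abs_gate).collect();
    let mean1 = stage1.iter().sum::<f32>() / stage1.len() as f32;
    // Stage 2: the relative gate sits 10 LU below the absolutely-gated loudness.
    let rel_gate = power(lkfs(mean1) - 10.0);
    let kept: Vec<f32> = stage1.into_iter().filter(|&p| p > rel_gate).collect();
    let integrated = lkfs(kept.iter().sum::<f32>() / kept.len() as f32);
    println!("integrated loudness: {integrated:.1} LKFS");
}
```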
@ -1,9 +1,6 @@
|
||||
pub mod audio;
|
||||
pub mod bs1770;
|
||||
pub mod coco_classes;
|
||||
pub mod imagenet;
|
||||
pub mod token_output_stream;
|
||||
pub mod wav;
|
||||
|
||||
use candle::utils::{cuda_is_available, metal_is_available};
|
||||
use candle::{Device, Result, Tensor};
|
||||
|
@ -40,7 +40,7 @@ impl TokenOutputStream {
|
||||
};
|
||||
self.tokens.push(token);
|
||||
let text = self.decode(&self.tokens[self.prev_index..])?;
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() {
|
||||
if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphabetic() {
|
||||
let text = text.split_at(prev_text.len());
|
||||
self.prev_index = self.current_index;
|
||||
self.current_index = self.tokens.len();
|
||||
|
@ -1,56 +0,0 @@
|
||||
use std::io::prelude::*;
|
||||
|
||||
pub trait Sample {
|
||||
fn to_i16(&self) -> i16;
|
||||
}
|
||||
|
||||
impl Sample for f32 {
|
||||
fn to_i16(&self) -> i16 {
|
||||
(self.clamp(-1.0, 1.0) * 32767.0) as i16
|
||||
}
|
||||
}
|
||||
|
||||
impl Sample for f64 {
|
||||
fn to_i16(&self) -> i16 {
|
||||
(self.clamp(-1.0, 1.0) * 32767.0) as i16
|
||||
}
|
||||
}
|
||||
|
||||
impl Sample for i16 {
|
||||
fn to_i16(&self) -> i16 {
|
||||
*self
|
||||
}
|
||||
}
|
||||
|
||||
pub fn write_pcm_as_wav<W: Write, S: Sample>(
|
||||
w: &mut W,
|
||||
samples: &[S],
|
||||
sample_rate: u32,
|
||||
) -> std::io::Result<()> {
|
||||
let len = 12u32; // header
|
||||
let len = len + 24u32; // fmt
|
||||
let len = len + samples.len() as u32 * 2 + 8; // data
|
||||
let n_channels = 1u16;
|
||||
let bytes_per_second = sample_rate * 2 * n_channels as u32;
|
||||
w.write_all(b"RIFF")?;
|
||||
w.write_all(&(len - 8).to_le_bytes())?; // total length minus 8 bytes
|
||||
w.write_all(b"WAVE")?;
|
||||
|
||||
// Format block
|
||||
w.write_all(b"fmt ")?;
|
||||
w.write_all(&16u32.to_le_bytes())?; // block len minus 8 bytes
|
||||
w.write_all(&1u16.to_le_bytes())?; // PCM
|
||||
w.write_all(&n_channels.to_le_bytes())?; // one channel
|
||||
w.write_all(&sample_rate.to_le_bytes())?;
|
||||
w.write_all(&bytes_per_second.to_le_bytes())?;
|
||||
w.write_all(&2u16.to_le_bytes())?; // 2 bytes of data per sample
|
||||
w.write_all(&16u16.to_le_bytes())?; // bits per sample
|
||||
|
||||
// Data block
|
||||
w.write_all(b"data")?;
|
||||
w.write_all(&(samples.len() as u32 * 2).to_le_bytes())?;
|
||||
for sample in samples.iter() {
|
||||
w.write_all(&sample.to_i16().to_le_bytes())?
|
||||
}
|
||||
Ok(())
|
||||
}
|
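A hedged usage sketch for the `write_pcm_as_wav` helper shown above, assuming it is in scope as `candle_examples::wav::write_pcm_as_wav` (where it lives on the main branch); the output file name and tone are illustrative.

```rust
use std::fs::File;

fn main() -> std::io::Result<()> {
    let sample_rate = 16_000u32;
    // One second of a 440 Hz sine, full scale in [-1, 1].
    let samples: Vec<f32> = (0..sample_rate)
        .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sample_rate as f32).sin())
        .collect();
    let mut file = File::create("tone.wav")?;
    candle_examples::wav::write_pcm_as_wav(&mut file, &samples, sample_rate)?;
    Ok(())
}
```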
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-flash-attn"
|
||||
version = "0.4.2"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Flash attention layer for the candle ML framework."
|
||||
@ -11,7 +11,7 @@ license = "MIT OR Apache-2.0"
|
||||
readme = "README.md"
|
||||
|
||||
[dependencies]
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.2" }
|
||||
candle = { path = "../candle-core", features = ["cuda"], package = "candle-core", version = "0.4.0" }
|
||||
half = { version = "2.3.1", features = ["num-traits"] }
|
||||
|
||||
[build-dependencies]
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-kernels"
|
||||
version = "0.4.2"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "CUDA kernels for Candle"
|
||||
|
@ -51,48 +51,6 @@ __device__ void conv1d(
|
||||
dst[dst_i] = static_cast<T>(d);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ void col2im1d(
|
||||
const size_t l_in,
|
||||
const size_t l_out,
|
||||
const size_t c_out,
|
||||
const size_t k_size,
|
||||
const size_t b_size,
|
||||
const size_t stride,
|
||||
const T *src,
|
||||
T *dst
|
||||
) {
|
||||
const size_t dst_i = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
// src: (b_size, l_in, c_out, k_size)
|
||||
// dst: (b_size, c_out, l_out)
|
||||
if (dst_i >= b_size * c_out * l_out) {
|
||||
return;
|
||||
}
|
||||
const size_t dst_s0 = c_out * l_out;
|
||||
const size_t dst_s1 = l_out;
|
||||
|
||||
// dst_idx = b_i * dst_s0 + c_i * dst_s1 + l_in_i * stride + k_i
|
||||
const size_t b_i = dst_i / dst_s0;
|
||||
const size_t dst_i2 = dst_i - b_i * dst_s0;
|
||||
const size_t c_i = dst_i2 / dst_s1;
|
||||
const size_t dst_i3 = dst_i2 - c_i * dst_s1; // l_in_i * stride + k_i
|
||||
|
||||
const size_t src_s0 = c_out * k_size * l_in;
|
||||
const size_t src_s1 = c_out * k_size;
|
||||
const size_t src_s2 = k_size;
|
||||
|
||||
T d = 0;
|
||||
for (size_t k_i = 0; k_i < min(dst_i3 + 1, k_size); ++k_i) {
|
||||
const size_t l_in_i_times_stride = dst_i3 - k_i;
|
||||
const size_t l_in_i = l_in_i_times_stride / stride;
|
||||
const size_t src_i = b_i * src_s0 + l_in_i * src_s1 + c_i * src_s2 + k_i;
|
||||
if (l_in_i * stride == l_in_i_times_stride && l_in_i < l_in) {
|
||||
d += src[src_i];
|
||||
}
|
||||
}
|
||||
dst[dst_i] = d;
|
||||
}
|
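For reference, a CPU-side Rust sketch of the indexing that the `col2im1d` kernel above performs, written as a scatter instead of the kernel's per-output gather; the shapes follow the comments in the kernel and the sample values are made up.

```rust
fn col2im1d(
    src: &[f32],
    b_size: usize,
    l_in: usize,
    c_out: usize,
    k_size: usize,
    stride: usize,
    l_out: usize,
) -> Vec<f32> {
    // src layout: (b_size, l_in, c_out, k_size); dst layout: (b_size, c_out, l_out).
    let mut dst = vec![0f32; b_size * c_out * l_out];
    for b in 0..b_size {
        for l in 0..l_in {
            for c in 0..c_out {
                for k in 0..k_size {
                    let dst_l = l * stride + k;
                    if dst_l < l_out {
                        let src_i = ((b * l_in + l) * c_out + c) * k_size + k;
                        dst[(b * c_out + c) * l_out + dst_l] += src[src_i];
                    }
                }
            }
        }
    }
    dst
}

fn main() {
    // b=1, l_in=3, c_out=1, k_size=2, stride=1 => l_out = (3 - 1) * 1 + 2 = 4.
    let src = [1f32, 2., 3., 4., 5., 6.];
    assert_eq!(col2im1d(&src, 1, 3, 1, 2, 1, 4), vec![1., 5., 9., 6.]);
}
```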
||||
|
||||
template <typename T>
|
||||
__device__ void im2col1d(
|
||||
const size_t dst_numel,
|
||||
@ -569,7 +527,7 @@ extern "C" __global__ void FN_NAME( \
|
||||
conv2d<TYPENAME, TYPEACC>(src_numel, w_out, h_out, stride, padding, dilation, info, src, kernel, dst); \
|
||||
} \
|
||||
|
||||
#define IM2COL1D_OP(TYPENAME, FN_NAME, FN_NAME2) \
|
||||
#define IM2COL1D_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
const size_t dst_numel, \
|
||||
const size_t l_out, \
|
||||
@ -583,18 +541,6 @@ extern "C" __global__ void FN_NAME( \
|
||||
) { \
|
||||
im2col1d<TYPENAME>(dst_numel, l_out, l_k, stride, padding, dilation, info, src, dst); \
|
||||
} \
|
||||
extern "C" __global__ void FN_NAME2( \
|
||||
const size_t l_in, \
|
||||
const size_t l_out, \
|
||||
const size_t c_out, \
|
||||
const size_t k_size, \
|
||||
const size_t b_size, \
|
||||
const size_t stride, \
|
||||
const TYPENAME *src, \
|
||||
TYPENAME *dst \
|
||||
) { \
|
||||
col2im1d<TYPENAME>(l_in, l_out, c_out, k_size, b_size, stride, src, dst); \
|
||||
} \
|
||||
|
||||
#define IM2COL_OP(TYPENAME, FN_NAME) \
|
||||
extern "C" __global__ void FN_NAME( \
|
||||
@ -696,7 +642,7 @@ AVG_POOL2D_OP(__nv_bfloat16, float, avg_pool2d_bf16)
|
||||
MAX_POOL2D_OP(__nv_bfloat16, max_pool2d_bf16)
|
||||
UPSAMPLE_NEAREST2D_OP(__nv_bfloat16, upsample_nearest2d_bf16)
|
||||
IM2COL_OP(__nv_bfloat16, im2col_bf16)
|
||||
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16, col2im1d_bf16)
|
||||
IM2COL1D_OP(__nv_bfloat16, im2col1d_bf16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
@ -708,7 +654,7 @@ AVG_POOL2D_OP(__half, float, avg_pool2d_f16)
|
||||
MAX_POOL2D_OP(__half, max_pool2d_f16)
|
||||
UPSAMPLE_NEAREST2D_OP(__half, upsample_nearest2d_f16)
|
||||
IM2COL_OP(__half, im2col_f16)
|
||||
IM2COL1D_OP(__half, im2col1d_f16, col2im1d_f16)
|
||||
IM2COL1D_OP(__half, im2col1d_f16)
|
||||
#endif
|
||||
|
||||
CONV1D_OP(float, float, conv1d_f32)
|
||||
@ -751,7 +697,7 @@ IM2COL_OP(double, im2col_f64)
|
||||
IM2COL_OP(uint8_t, im2col_u8)
|
||||
IM2COL_OP(uint32_t, im2col_u32)
|
||||
|
||||
IM2COL1D_OP(float, im2col1d_f32, col2im1d_f32)
|
||||
IM2COL1D_OP(double, im2col1d_f64, col2im1d_f64)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8, col2im1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32, col2im1d_u32)
|
||||
IM2COL1D_OP(float, im2col1d_f32)
|
||||
IM2COL1D_OP(double, im2col1d_f64)
|
||||
IM2COL1D_OP(uint8_t, im2col1d_u8)
|
||||
IM2COL1D_OP(uint32_t, im2col1d_u32)
|
||||
|
@ -10,39 +10,11 @@ __device__ void fill_with(T *buf, T value, const size_t numel) {
|
||||
extern "C" __global__ void fill_u8(uint8_t *buf, uint8_t value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
extern "C" __global__ void fill_u32(uint32_t *buf, uint32_t value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
extern "C" __global__ void fill_i64(int64_t *buf, int64_t value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
extern "C" __global__ void fill_f32(float *buf, float value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
extern "C" __global__ void fill_f64(double *buf, double value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
|
||||
template<typename T>
|
||||
__device__ void copy2d(const T *src, T *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) {
|
||||
uint32_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (idx >= d1 * d2) {
|
||||
return;
|
||||
}
|
||||
uint32_t idx1 = idx / d2;
|
||||
uint32_t idx2 = idx - d2 * idx1;
|
||||
dst[idx1 * dst_s + idx2] = src[idx1 * src_s + idx2];
|
||||
}
|
||||
|
||||
#define COPY2D_OP(TYPENAME, FNNAME) \
|
||||
extern "C" __global__ \
|
||||
void FNNAME(const TYPENAME *src, TYPENAME *dst, uint32_t d1, uint32_t d2, uint32_t src_s, uint32_t dst_s) { \
|
||||
copy2d(src, dst, d1, d2, src_s, dst_s); \
|
||||
} \
|
||||
|
||||
COPY2D_OP(float, copy2d_f32)
|
||||
COPY2D_OP(double, copy2d_f64)
|
||||
COPY2D_OP(uint8_t, copy2d_u8)
|
||||
COPY2D_OP(uint32_t, copy2d_u32)
|
||||
COPY2D_OP(int64_t, copy2d_i64)
|
||||
|
||||
#if __CUDA_ARCH__ >= 530
|
||||
extern "C" __global__ void fill_f16(__half *buf, __half value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
COPY2D_OP(__half, copy2d_f16)
|
||||
#endif
|
||||
|
||||
#if __CUDA_ARCH__ >= 800
|
||||
#include <cuda_bf16.h>
|
||||
extern "C" __global__ void fill_bf16(__nv_bfloat16 *buf, __nv_bfloat16 value, const size_t numel) { fill_with(buf, value, numel); }
|
||||
COPY2D_OP(__nv_bfloat16, copy2d_bf16)
|
||||
#endif
|
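A host-side Rust sketch of what the `copy2d` kernel above computes: a `d1 x d2` block copy where source and destination rows may use different strides. The index math mirrors the kernel; the buffers are illustrative.

```rust
fn copy2d(src: &[f32], dst: &mut [f32], d1: usize, d2: usize, src_s: usize, dst_s: usize) {
    for idx in 0..d1 * d2 {
        let i = idx / d2;
        let j = idx - i * d2;
        dst[i * dst_s + j] = src[i * src_s + j];
    }
}

fn main() {
    // Copy the left 2x2 block of a row-major 2x3 matrix into a 2x2 matrix.
    let src = [1f32, 2., 3., 4., 5., 6.];
    let mut dst = [0f32; 4];
    copy2d(&src, &mut dst, 2, 2, 3, 2);
    assert_eq!(dst, [1., 2., 4., 5.]);
}
```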
||||
|
@ -4,7 +4,6 @@ pub const CAST: &str = include_str!(concat!(env!("OUT_DIR"), "/cast.ptx"));
|
||||
pub const CONV: &str = include_str!(concat!(env!("OUT_DIR"), "/conv.ptx"));
|
||||
pub const FILL: &str = include_str!(concat!(env!("OUT_DIR"), "/fill.ptx"));
|
||||
pub const INDEXING: &str = include_str!(concat!(env!("OUT_DIR"), "/indexing.ptx"));
|
||||
pub const QUANTIZED: &str = include_str!(concat!(env!("OUT_DIR"), "/quantized.ptx"));
|
||||
pub const REDUCE: &str = include_str!(concat!(env!("OUT_DIR"), "/reduce.ptx"));
|
||||
pub const TERNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/ternary.ptx"));
|
||||
pub const UNARY: &str = include_str!(concat!(env!("OUT_DIR"), "/unary.ptx"));
|
||||
|
File diff suppressed because it is too large
2
candle-metal-kernels/.gitignore
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
src/compiled/
|
||||
|
@ -1,6 +1,6 @@
|
||||
[package]
|
||||
name = "candle-metal-kernels"
|
||||
version = "0.4.2"
|
||||
version = "0.4.0"
|
||||
edition = "2021"
|
||||
|
||||
description = "Metal kernels for Candle"
|
||||
|
45
candle-metal-kernels/build.rs
Normal file
@ -0,0 +1,45 @@
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let files: std::fs::ReadDir = std::fs::read_dir("src/").unwrap();
|
||||
for file in files {
|
||||
let file = file?;
|
||||
let path = file.path();
|
||||
if let Some(extension) = path.extension() {
|
||||
if extension == "metal" {
|
||||
build_kernel(&path)?;
|
||||
}
|
||||
println!("cargo:warning=output {:?}", path.file_stem());
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn build_kernel(path: &Path) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let stem = path
|
||||
.file_stem()
|
||||
.expect("expect real filename")
|
||||
.to_str()
|
||||
.expect("expect real stem");
|
||||
Command::new("xcrun")
|
||||
.args([
|
||||
"metal",
|
||||
"-c",
|
||||
path.as_os_str().to_str().expect("Expect a real filename"),
|
||||
"-I",
|
||||
"src/",
|
||||
"-o",
|
||||
&format!("src/compiled/{stem}.air"),
|
||||
])
|
||||
.output()?;
|
||||
Command::new("xcrun")
|
||||
.args([
|
||||
"metallib",
|
||||
&format!("src/compiled/{stem}.air"),
|
||||
"-o",
|
||||
&format!("src/compiled/{stem}.metallib"),
|
||||
])
|
||||
.output()?;
|
||||
Ok(())
|
||||
}
|
@ -89,7 +89,7 @@ kernel void FN_NAME( \
|
||||
return; \
|
||||
} \
|
||||
const TYPENAME x = input[id]; \
|
||||
output[id] = TYPENAME((x > 0)?x: mul * (exp(x) - 1)); \
|
||||
output[id] = TYPENAME((x > 0)?x: mul * exp(x - 1)); \
|
||||
} \
|
||||
kernel void FN_NAME##_strided( \
|
||||
constant size_t &dim, \
|
||||
|
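As an aside on the one-character difference above: the conventional ELU is `x` for positive inputs and `alpha * (exp(x) - 1)` otherwise, i.e. the parenthesised form. A tiny standalone Rust check:

```rust
fn elu(x: f32, alpha: f32) -> f32 {
    if x > 0.0 {
        x
    } else {
        alpha * (x.exp() - 1.0)
    }
}

fn main() {
    assert_eq!(elu(2.0, 1.0), 2.0);
    assert!((elu(-1.0, 1.0) + 0.6321).abs() < 1e-3); // exp(-1) - 1 is about -0.6321
}
```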
@ -72,60 +72,28 @@ kernel void FN_NAME_STRIDED( \
|
||||
output[tid] = static_cast<RIGHT_TYPENAME>(static_cast<IR_TYPENAME>(input[get_strided_index(tid, num_dims, dims, strides)])); \
|
||||
} \
|
||||
|
||||
// u32
|
||||
CAST(cast_u32_f32, cast_u32_f32_strided, uint32_t, float)
|
||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||
CAST(cast_u32_f16, cast_u32_f16_strided, uint32_t, half)
|
||||
#if __METAL_VERSION__ >= 220
|
||||
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
#endif
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
#endif
|
||||
|
||||
// u8
|
||||
CAST(cast_u32_u8, cast_u32_u8_strided, uint32_t, uint8_t)
|
||||
CAST(cast_u8_u32, cast_u8_u32_strided, uint8_t, uint32_t)
|
||||
CAST(cast_u8_f32, cast_u8_f32_strided, uint8_t, float)
|
||||
CAST(cast_u8_f16, cast_u8_f16_strided, uint8_t, half)
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
|
||||
#if __METAL_VERSION__ >= 220
|
||||
CAST(cast_u8_i64, cast_u8_i64_strided, uint8_t, int64_t)
|
||||
#endif
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
#endif
|
||||
|
||||
// f16
|
||||
CAST(cast_f16_f32, cast_f16_f32_strided, half, float)
|
||||
CAST(cast_f16_u8, cast_f16_u8_strided, half, uint8_t)
|
||||
CAST(cast_f16_u32, cast_f16_u32_strided, half, uint32_t)
|
||||
CAST(cast_f16_i64, cast_f16_i64_strided, half, int64_t)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
||||
|
||||
// i64
|
||||
CAST(cast_u32_i64, cast_u32_i64_strided, uint32_t, int64_t)
|
||||
CAST(cast_i64_f32, cast_i64_f32_strided, int64_t, float)
|
||||
CAST(cast_i64_u8, cast_i64_u8_strided, int64_t, uint8_t)
|
||||
CAST(cast_i64_u32, cast_i64_u32_strided, int64_t, uint32_t)
|
||||
CAST(cast_i64_f16, cast_i64_f16_strided, int64_t, half)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST_THROUGH(cast_i64_bf16, cast_i64_bf16_strided, int64_t, bfloat, float)
|
||||
#endif
|
||||
|
||||
// f32
|
||||
CAST(cast_f32_f16, cast_f32_f16_strided, float, half)
|
||||
CAST(cast_f32_u32, cast_f32_u32_strided, float, uint32_t)
|
||||
CAST(cast_f32_u8, cast_f32_u8_strided, float, uint8_t)
|
||||
CAST(cast_f32_i64, cast_f32_i64_strided, float, int64_t)
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
#endif
|
||||
|
||||
// bf16
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
CAST(cast_bf16_u32, cast_bf16_u32_strided, bfloat, uint32_t)
|
||||
CAST(cast_bf16_i64, cast_bf16_i64_strided, bfloat, int64_t)
|
||||
CAST(cast_bf16_f32, cast_bf16_f32_strided, bfloat, float)
|
||||
CAST(cast_u8_bf16, cast_u8_bf16_strided, uint8_t, bfloat)
|
||||
CAST(cast_u32_bf16, cast_u32_bf16_strided, uint32_t, bfloat)
|
||||
CAST(cast_f32_bf16, cast_f32_bf16_strided, float, bfloat)
|
||||
|
||||
CAST_THROUGH(cast_bf16_u8, cast_bf16_u8_strided, bfloat, uint8_t, float)
|
||||
CAST_THROUGH(cast_bf16_f16, cast_bf16_f16_strided, bfloat, half, float)
|
||||
#endif
|
||||
CAST_THROUGH(cast_f16_bf16, cast_f16_bf16_strided, half, bfloat, float)
|
||||
#endif
|
||||
|
@ -167,16 +167,11 @@ kernel void NAME( \
|
||||
|
||||
INDEX_OP(is_u32_f32, uint, float)
|
||||
INDEX_OP(is_u32_f16, uint, half)
|
||||
|
||||
GATHER_OP(gather_u32_f32, uint, float)
|
||||
GATHER_OP(gather_u32_f16, uint, half)
|
||||
SCATTER_ADD_OP(sa_u32_f32, uint, float)
|
||||
SCATTER_ADD_OP(sa_u32_f16, uint, half)
|
||||
|
||||
SCATTER_ADD_OP(sa_u32_f32, uint32_t, float)
|
||||
SCATTER_ADD_OP(sa_u8_f32, uint8_t, float)
|
||||
SCATTER_ADD_OP(sa_i64_f32, int64_t, float)
|
||||
SCATTER_ADD_OP(sa_u32_f16, uint32_t, half)
|
||||
SCATTER_ADD_OP(sa_u8_f16, uint8_t, half)
|
||||
SCATTER_ADD_OP(sa_i64_f16, int64_t, half)
|
||||
|
||||
#if defined(__HAVE_BFLOAT__)
|
||||
INDEX_OP(is_u32_bf16, uint32_t, bfloat)
|
||||
@ -185,10 +180,6 @@ INDEX_OP(is_u8_bf16, uint8_t, bfloat)
|
||||
INDEX_ADD_OP(ia_i64_bf16, int64_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u32_bf16, uint32_t, bfloat)
|
||||
INDEX_ADD_OP(ia_u8_bf16, uint8_t, bfloat)
|
||||
|
||||
SCATTER_ADD_OP(sa_u32_bf16, uint32_t, bfloat)
|
||||
SCATTER_ADD_OP(sa_u8_bf16, uint8_t, bfloat)
|
||||
SCATTER_ADD_OP(sa_i64_bf16, int64_t, bfloat)
|
||||
#endif
|
||||
|
||||
INDEX_ADD_OP(ia_u32_f16, uint32_t, half)
|
||||
|
@ -1,22 +1,22 @@
|
||||
use metal::{
|
||||
Buffer, CommandBufferRef, CompileOptions, ComputeCommandEncoderRef, ComputePipelineState,
|
||||
Device, Function, FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
Buffer, CommandBufferRef, ComputeCommandEncoderRef, ComputePipelineState, Device, Function,
|
||||
FunctionConstantValues, Library, MTLDataType, MTLSize, NSUInteger,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::c_void;
|
||||
use std::sync::RwLock;
|
||||
|
||||
const AFFINE: &str = include_str!("affine.metal");
|
||||
const INDEXING: &str = include_str!("indexing.metal");
|
||||
const UNARY: &str = include_str!("unary.metal");
|
||||
const BINARY: &str = include_str!("binary.metal");
|
||||
const TERNARY: &str = include_str!("ternary.metal");
|
||||
const CAST: &str = include_str!("cast.metal");
|
||||
const CONV: &str = include_str!("conv.metal");
|
||||
const REDUCE: &str = include_str!("reduce.metal");
|
||||
const RANDOM: &str = include_str!("random.metal");
|
||||
const AFFINE: &[u8] = include_bytes!("compiled/affine.metallib");
|
||||
const INDEXING: &[u8] = include_bytes!("compiled/indexing.metallib");
|
||||
const UNARY: &[u8] = include_bytes!("compiled/unary.metallib");
|
||||
const BINARY: &[u8] = include_bytes!("compiled/binary.metallib");
|
||||
const TERNARY: &[u8] = include_bytes!("compiled/ternary.metallib");
|
||||
const CAST: &[u8] = include_bytes!("compiled/cast.metallib");
|
||||
const CONV: &[u8] = include_bytes!("compiled/conv.metallib");
|
||||
const REDUCE: &[u8] = include_bytes!("compiled/reduce.metallib");
|
||||
const RANDOM: &[u8] = include_bytes!("compiled/random.metallib");
|
||||
const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
|
||||
const QUANTIZED: &str = include_str!("quantized.metal");
|
||||
const QUANTIZED: &[u8] = include_bytes!("compiled/quantized.metallib");
|
||||
|
||||
/// Most kernels apply similarly across the tensors
|
||||
/// This creates a strategy that uses the maximum amount of threads per threadgroup (capped at the
|
||||
@ -127,16 +127,6 @@ pub enum Source {
|
||||
Quantized,
|
||||
}
|
||||
|
||||
pub mod copy2d {
|
||||
pub struct Kernel(pub &'static str);
|
||||
pub const FLOAT: Kernel = Kernel("copy2d_f32");
|
||||
pub const HALF: Kernel = Kernel("copy2d_f16");
|
||||
pub const BFLOAT: Kernel = Kernel("copy2d_bf16");
|
||||
pub const I64: Kernel = Kernel("copy2d_i64");
|
||||
pub const U32: Kernel = Kernel("copy2d_u32");
|
||||
pub const U8: Kernel = Kernel("copy2d_u8");
|
||||
}
|
||||
|
||||
macro_rules! ops{
|
||||
($($name:ident),+) => {
|
||||
|
||||
@ -245,7 +235,7 @@ impl Kernels {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_library_source(&self, source: Source) -> &'static str {
|
||||
fn get_library_source(&self, source: Source) -> &'static [u8] {
|
||||
match source {
|
||||
Source::Affine => AFFINE,
|
||||
Source::Unary => UNARY,
|
||||
@ -257,7 +247,7 @@ impl Kernels {
|
||||
Source::Conv => CONV,
|
||||
Source::Random => RANDOM,
|
||||
Source::Quantized => QUANTIZED,
|
||||
Source::Mfa => panic!("Invalid lib"),
|
||||
Source::Mfa => MFA,
|
||||
}
|
||||
}
|
||||
|
||||
@ -272,22 +262,12 @@ impl Kernels {
|
||||
if let Some(lib) = libraries.get(&source) {
|
||||
Ok(lib.clone())
|
||||
} else {
|
||||
let lib = match source {
|
||||
Source::Mfa => {
|
||||
let source_data = MFA;
|
||||
device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?
|
||||
}
|
||||
source => {
|
||||
let source_content = self.get_library_source(source);
|
||||
device
|
||||
.new_library_with_source(source_content, &CompileOptions::new())
|
||||
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
|
||||
}
|
||||
};
|
||||
let source_data = self.get_library_source(source);
|
||||
let lib = device.new_library_with_data(source_data).map_err(|e| {
|
||||
MetalKernelError::LoadLibraryError(format!(
|
||||
"Candle metal requires macosx > 13.0 or higher, cannot load mfa: {e}"
|
||||
))
|
||||
})?;
|
||||
libraries.insert(source, lib.clone());
|
||||
Ok(lib)
|
||||
}
|
||||
@ -375,46 +355,6 @@ pub fn call_unary_contiguous(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_copy2d(
|
||||
device: &Device,
|
||||
command_buffer: &CommandBufferRef,
|
||||
kernels: &Kernels,
|
||||
name: copy2d::Kernel,
|
||||
input: &Buffer,
|
||||
output: &Buffer,
|
||||
d1: usize,
|
||||
d2: usize,
|
||||
src_s: usize,
|
||||
dst_s: usize,
|
||||
src_o_in_bytes: usize,
|
||||
dst_o_in_bytes: usize,
|
||||
) -> Result<(), MetalKernelError> {
|
||||
let pipeline = kernels.load_pipeline(device, Source::Unary, name.0)?;
|
||||
let encoder = command_buffer.new_compute_command_encoder();
|
||||
encoder.set_compute_pipeline_state(&pipeline);
|
||||
set_params!(
|
||||
encoder,
|
||||
(
|
||||
d1,
|
||||
d2,
|
||||
src_s,
|
||||
dst_s,
|
||||
(input, src_o_in_bytes),
|
||||
(output, dst_o_in_bytes)
|
||||
)
|
||||
);
|
||||
|
||||
let width: usize = d1 * d2;
|
||||
let (thread_group_count, thread_group_size) = linear_split(&pipeline, width);
|
||||
|
||||
encoder.use_resource(input, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(output, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn call_unary_strided(
|
||||
device: &Device,
|
||||
@ -1608,10 +1548,8 @@ pub fn call_random_uniform(
|
||||
|
||||
set_params!(encoder, (length, min, max, seed, buffer));
|
||||
|
||||
encoder.use_resource(
|
||||
seed,
|
||||
metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
|
||||
);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
@ -1641,10 +1579,8 @@ pub fn call_random_normal(
|
||||
|
||||
set_params!(encoder, (length, mean, stddev, seed, buffer));
|
||||
|
||||
encoder.use_resource(
|
||||
seed,
|
||||
metal::MTLResourceUsage::Read | metal::MTLResourceUsage::Write,
|
||||
);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Read);
|
||||
encoder.use_resource(seed, metal::MTLResourceUsage::Write);
|
||||
encoder.use_resource(buffer, metal::MTLResourceUsage::Write);
|
||||
encoder.dispatch_thread_groups(thread_group_count, thread_group_size);
|
||||
encoder.end_encoding();
|
||||
|
@ -123,20 +123,16 @@ template<typename T> METAL_FUNC void rand_uniform(
|
||||
return;
|
||||
}
|
||||
|
||||
// Evenly sized vectors need an offset when writing the mirror element.
|
||||
uint off = 1 - size % 2;
|
||||
float diff = abs(min - max);
|
||||
uint s = atomic_load_explicit(seed, memory_order_relaxed);
|
||||
HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
out[tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0 && off == 0, otherwise we will write to out[size].
|
||||
if (off == 0)
|
||||
return;
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - off - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
out[size - tid] = static_cast<T>(rng.rand() * diff + min);
|
||||
}
|
||||
|
||||
// Create Gaussian normal distribution using Box-Muller transform:
|
||||
@ -152,10 +148,7 @@ template<typename T> METAL_FUNC void normal(
|
||||
if (tid >= size) {
|
||||
return;
|
||||
}
|
||||
// Evenly sized vectors need an offset when writing the mirror element.
|
||||
uint off = 1 - size % 2;
|
||||
uint s = atomic_load_explicit(seed, memory_order_relaxed);
|
||||
HybridTaus rng = HybridTaus::init({ulong(s), tid, 1, 1});
|
||||
HybridTaus rng = HybridTaus::init({ulong(seed), tid, 1, 1});
|
||||
float u1 = rng.rand();
|
||||
float u2 = rng.rand();
|
||||
|
||||
@ -169,12 +162,11 @@ template<typename T> METAL_FUNC void normal(
|
||||
|
||||
if (tid == 0) {
|
||||
atomic_store_explicit(seed, uint(rng.rand() * UNIF01_NORM32), memory_order_relaxed);
|
||||
// Return early if tid == 0 && off == 0, otherwise we will write to out[size].
|
||||
if (off == 0)
|
||||
return;
|
||||
// Return early if tid == 0, otherwise we will write to out[size].
|
||||
return;
|
||||
}
|
||||
// Use symmetry to fill the other half of the array.
|
||||
out[size - off - tid] = static_cast<T>(z1);
|
||||
out[size - tid] = static_cast<T>(z1);
|
||||
}
|
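An illustrative Rust sketch of the mirror-write indexing used by the offset-based variant above, under the assumption of one thread per pair of output elements (the real dispatch may differ); it only checks that every slot is covered exactly once for both even and odd sizes.

```rust
fn mirror_fill(size: usize) -> Vec<u32> {
    // off = 1 - size % 2: even sizes need the offset so the mirror index stays in bounds.
    let off = 1 - size % 2;
    let mut writes = vec![0u32; size];
    for tid in 0..(size + 1) / 2 {
        writes[tid] += 1;
        if tid == 0 && off == 0 {
            // For odd sizes, the mirror of tid 0 would be out[size]; the kernel returns early here.
            continue;
        }
        writes[size - off - tid] += 1;
    }
    writes
}

fn main() {
    // Every slot is written exactly once, for both parities.
    assert!(mirror_fill(8).iter().all(|&w| w == 1));
    assert!(mirror_fill(9).iter().all(|&w| w == 1));
}
```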
||||
|
||||
#define UNIFORM_OP(NAME, T) \
|
||||
|
@ -292,7 +292,7 @@ fn binary_ops_bf16() {
|
||||
binary_op!(max, |x: bf16, y| x.max(y));
|
||||
}
|
||||
|
||||
fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
fn cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
let device = device();
|
||||
let kernels = Kernels::new();
|
||||
let command_queue = device.new_command_queue();
|
||||
@ -319,189 +319,107 @@ fn run_cast<T: Clone, U: Clone>(v: &[T], name: &'static str) -> Vec<U> {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_f32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn cast_u32_f32() {
|
||||
let v = vec![1u32, 2, 3];
|
||||
let results = cast(&v, "cast_u32_f32");
|
||||
let expected: Vec<_> = v.iter().map(|&v| v as f32).collect();
|
||||
assert_eq!(approx(results, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
assert_eq!(approx(expected, 4), vec![1.0f32, 2.0, 3.0]);
|
||||
|
||||
// f32 -> f16
|
||||
let results: Vec<half::f16> = run_cast(&v_f32, "cast_f32_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
let v = vec![1.0f32, 2.0, 3.0];
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results, vec![1.0f32, 2.0, 3.0]);
|
||||
|
||||
// f32 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_f32, "cast_f32_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// f32 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_f32, "cast_f32_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// f32 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_f32, "cast_f32_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// f32 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_f32, "cast_f32_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
let v = vec![1.0f32; 10_000];
|
||||
let input: Vec<f16> = v.iter().map(|v| f16::from_f32(*v)).collect();
|
||||
let results: Vec<f32> = cast(&input, "cast_f16_f32");
|
||||
assert_eq!(results.len(), 10_000);
|
||||
assert_eq!(&results[..10], vec![1.0f32; 10]);
|
||||
assert_eq!(results, vec![1.0f32; 10_000]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_f16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_bf16_u32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
// f16 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_f16, "cast_f16_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<u32> = cast(&input, "cast_bf16_u32");
|
||||
let expected: Vec<u32> = (1..=3).map(|v| v as u32).collect();
|
||||
|
||||
// f16 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_f16, "cast_f16_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// f16 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_f16, "cast_f16_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// f16 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_f16, "cast_f16_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// f16 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_f16, "cast_f16_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_bf16() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_bf16_f32() {
|
||||
let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();
|
||||
|
||||
// bf16 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_bf16, "cast_bf16_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<f32> = cast(&input, "cast_bf16_f32");
|
||||
let expected: Vec<f32> = (1..=3).map(|v| v as f32).collect();
|
||||
|
||||
// bf16 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_bf16, "cast_bf16_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// bf16 -> u32
|
||||
let results: Vec<u32> = run_cast(&v_bf16, "cast_bf16_u32");
|
||||
assert_eq!(results, v_u32);
|
||||
|
||||
// bf16 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_bf16, "cast_bf16_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// bf16 -> i64
|
||||
let results: Vec<i64> = run_cast(&v_bf16, "cast_bf16_i64");
|
||||
assert_eq!(results, v_i64);
|
||||
assert_eq!(output, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cast_u32() {
|
||||
let v_f64 = vec![1.0f64, 2.0, 3.0];
|
||||
let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
|
||||
let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
|
||||
let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
|
||||
let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
|
||||
let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
|
||||
let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
|
||||
fn it_cast_u8_bf16() {
|
||||
let input: Vec<u8> = (1..=3).map(|v| v as u8).collect();
|
||||
|
||||
// u32 -> f32
|
||||
let results: Vec<f32> = run_cast(&v_u32, "cast_u32_f32");
|
||||
assert_eq!(results, v_f32);
|
||||
let output: Vec<bf16> = cast(&input, "cast_u8_bf16");
|
||||
let expected: Vec<bf16> = input
|
||||
.iter()
|
||||
.map(|v| bf16::from_f32(*v as f32))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// u32 -> f16
|
||||
let results: Vec<f16> = run_cast(&v_u32, "cast_u32_f16");
|
||||
assert_eq!(results, v_f16);
|
||||
|
||||
// u32 -> bf16
|
||||
let results: Vec<bf16> = run_cast(&v_u32, "cast_u32_bf16");
|
||||
assert_eq!(results, v_bf16);
|
||||
|
||||
// u32 -> u8
|
||||
let results: Vec<u8> = run_cast(&v_u32, "cast_u32_u8");
|
||||
assert_eq!(results, v_u8);
|
||||
|
||||
// u32 -> i64
|
    let results: Vec<i64> = run_cast(&v_u32, "cast_u32_i64");
    assert_eq!(results, v_i64);
    assert_eq!(output, expected);
}

#[test]
fn cast_u8() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
fn it_cast_u32_bf16() {
    let input: Vec<u32> = (1..=3).map(|v| v as u32).collect();

    // u8 -> f32
    let results: Vec<f32> = run_cast(&v_u8, "cast_u8_f32");
    assert_eq!(results, v_f32);
    let output: Vec<bf16> = cast(&input, "cast_u32_bf16");
    let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();

    // u8 -> f16
    let results: Vec<f16> = run_cast(&v_u8, "cast_u8_f16");
    assert_eq!(results, v_f16);

    // u8 -> bf16
    let results: Vec<bf16> = run_cast(&v_u8, "cast_u8_bf16");
    assert_eq!(results, v_bf16);

    // u8 -> u32
    let results: Vec<u32> = run_cast(&v_u8, "cast_u8_u32");
    assert_eq!(results, v_u32);

    // u8 -> i64
    let results: Vec<i64> = run_cast(&v_u8, "cast_u8_i64");
    assert_eq!(results, v_i64);
    assert_eq!(output, expected);
}

#[test]
fn cast_i64() {
    let v_f64 = vec![1.0f64, 2.0, 3.0];
    let v_f32: Vec<f32> = v_f64.iter().map(|&v| v as f32).collect();
    let v_f16: Vec<f16> = v_f64.iter().map(|&v| f16::from_f32(v as f32)).collect();
    let v_bf16: Vec<bf16> = v_f64.iter().map(|&v| bf16::from_f32(v as f32)).collect();
    let v_u32: Vec<u32> = v_f64.iter().map(|&v| v as u32).collect();
    let v_u8: Vec<u8> = v_f64.iter().map(|&v| v as u8).collect();
    let v_i64: Vec<i64> = v_f64.iter().map(|&v| v as i64).collect();
fn it_cast_f32_bf16() {
    let input: Vec<f32> = (1..=3).map(|v| v as f32).collect();

    // i64 -> f32
    let results: Vec<f32> = run_cast(&v_i64, "cast_i64_f32");
    assert_eq!(results, v_f32);
    let output: Vec<bf16> = cast(&input, "cast_f32_bf16");
    let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(*v as f32)).collect();

    // i64 -> f16
    let results: Vec<f16> = run_cast(&v_i64, "cast_i64_f16");
    assert_eq!(results, v_f16);
    assert_eq!(output, expected);
}

    // i64 -> bf16
    let results: Vec<bf16> = run_cast(&v_i64, "cast_i64_bf16");
    assert_eq!(results, v_bf16);
#[test]
fn it_cast_bf16_u8() {
    let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();

    // i64 -> u32
    let results: Vec<u32> = run_cast(&v_i64, "cast_i64_u32");
    assert_eq!(results, v_u32);
    let output: Vec<u8> = cast(&input, "cast_bf16_u8");
    let expected: Vec<u8> = input.iter().map(|v| v.to_f32() as u8).collect();

    // i64 -> u8
    let results: Vec<u8> = run_cast(&v_i64, "cast_i64_u8");
    assert_eq!(results, v_u8);
    assert_eq!(output, expected);
}

#[test]
fn it_cast_bf16_f16() {
    let input: Vec<bf16> = (1..=3).map(|v| bf16::from_f32(v as f32)).collect();

    let output: Vec<f16> = cast(&input, "cast_bf16_f16");
    let expected: Vec<f16> = input.iter().map(|v| f16::from_f32(v.to_f32())).collect();

    assert_eq!(output, expected);
}

#[test]
fn it_cast_f16_bf16() {
    let input: Vec<f16> = (1..=3).map(|v| f16::from_f32(v as f32)).collect();

    let output: Vec<bf16> = cast(&input, "cast_f16_bf16");
    let expected: Vec<bf16> = input.iter().map(|v| bf16::from_f32(v.to_f32())).collect();

    assert_eq!(output, expected);
}

fn run_affine<T: Clone>(v: &[T], mul: f64, add: f64) -> Vec<T> {
@ -1148,107 +1066,3 @@ fn random() {
    validate_random!(f16);
    validate_random!(bf16);
}

fn run_scatter_add<T: Clone, I: Clone + std::fmt::Debug>(
    input: &[T],
    ids: &[I],
    shape: &[usize],
    dim: usize,
    name: &'static str,
) -> Vec<T> {
    let device = device();
    let kernels = Kernels::new();
    let command_queue = device.new_command_queue();
    let command_buffer = command_queue.new_command_buffer();
    let options = MTLResourceOptions::StorageModeManaged;
    let input_buffer = new_buffer(&device, input);
    let ids_buffer = new_buffer(&device, ids);
    let output = device.new_buffer(std::mem::size_of_val(input) as u64, options);
    call_scatter_add(
        &device,
        command_buffer,
        &kernels,
        name,
        shape,
        shape,
        dim,
        &input_buffer,
        0,
        &ids_buffer,
        0,
        &output,
    )
    .unwrap();
    command_buffer.commit();
    command_buffer.wait_until_completed();
    read_to_vec(&output, input.len())
}

#[test]
fn scatter_add() {
    let ids_u8 = [0u8, 0, 1, 0, 2, 2, 3, 3];
    let ids_u32 = [0u32, 0, 1, 0, 2, 2, 3, 3];
    let ids_i64 = [0i64, 0, 1, 0, 2, 2, 3, 3];

    let input_f32 = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
    let input_f16 = input_f32
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let input_bf16 = input_f32
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();

    let output_dim1_f32 = vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0];
    let output_dim1_f16 = output_dim1_f32
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let output_dim1_bf16 = output_dim1_f32
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();

    let output_dim2_f32 = vec![5.0, 3.0, 7.0, 0.0, 3.0, 2.0, 1.0, 3.0];
    let output_dim2_f16 = output_dim2_f32
        .iter()
        .map(|v| f16::from_f32(*v))
        .collect::<Vec<_>>();
    let output_dim2_bf16 = output_dim2_f32
        .iter()
        .map(|v| bf16::from_f32(*v))
        .collect::<Vec<_>>();

    for (shape, output_f32, output_f16, output_bf16) in [
        (vec![8], output_dim1_f32, output_dim1_f16, output_dim1_bf16),
        (
            vec![4, 2],
            output_dim2_f32,
            output_dim2_f16,
            output_dim2_bf16,
        ),
    ] {
        for results in [
            run_scatter_add(&input_f32, &ids_u8, &shape, 0, "sa_u8_f32"),
            run_scatter_add(&input_f32, &ids_u32, &shape, 0, "sa_u32_f32"),
            run_scatter_add(&input_f32, &ids_i64, &shape, 0, "sa_i64_f32"),
        ] {
            assert_eq!(results, output_f32);
        }
        for results in [
            run_scatter_add(&input_f16, &ids_u8, &shape, 0, "sa_u8_f16"),
            run_scatter_add(&input_f16, &ids_u32, &shape, 0, "sa_u32_f16"),
            run_scatter_add(&input_f16, &ids_i64, &shape, 0, "sa_i64_f16"),
        ] {
            assert_eq!(results, output_f16);
        }
        for results in [
            run_scatter_add(&input_bf16, &ids_u8, &shape, 0, "sa_u8_bf16"),
            run_scatter_add(&input_bf16, &ids_u32, &shape, 0, "sa_u32_bf16"),
            run_scatter_add(&input_bf16, &ids_i64, &shape, 0, "sa_i64_bf16"),
        ] {
            assert_eq!(results, output_bf16);
        }
    }
}
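For reference, the expected outputs in the test above follow from the usual scatter-add rule: each input value is accumulated into the output slot selected by its id along the scattered dimension, and slots that receive no id stay at zero. A minimal CPU-side sketch in plain Rust (not code from the repository; the helper name is illustrative):

// out[ids[i]] += input[i]; unused slots remain zero.
fn scatter_add_dim0(input: &[f32], ids: &[usize], out_len: usize) -> Vec<f32> {
    let mut out = vec![0.0f32; out_len];
    for (&v, &id) in input.iter().zip(ids.iter()) {
        out[id] += v;
    }
    out
}

fn main() {
    let input = [5.0f32, 1.0, 7.0, 2.0, 3.0, 2.0, 1.0, 3.0];
    let ids = [0usize, 0, 1, 0, 2, 2, 3, 3];
    // Matches output_dim1_f32 from the test above.
    assert_eq!(
        scatter_add_dim0(&input, &ids, 8),
        vec![8.0, 7.0, 5.0, 4.0, 0.0, 0.0, 0.0, 0.0]
    );
}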
@ -102,30 +102,6 @@ UNARY(NAME, half, NAME##_f16, NAME##_f16_strided);
#define BFLOAT_UNARY_OP(NAME) \
UNARY(NAME, bfloat, NAME##_bf16, NAME##_bf16_strided);

#define COPY2D(FN_NAME, TYPENAME) \
kernel void FN_NAME( \
    constant size_t &d1, \
    constant size_t &d2, \
    constant size_t &src_s, \
    constant size_t &dst_s, \
    device const TYPENAME *input, \
    device TYPENAME *output, \
    uint tid [[ thread_position_in_grid ]] \
) { \
    if (tid >= d1 * d2) { \
        return; \
    } \
    size_t idx1 = tid / d2; \
    size_t idx2 = tid - idx1 * d2; \
    size_t src_idx = idx1 * src_s + idx2; \
    size_t dst_idx = idx1 * dst_s + idx2; \
    output[dst_idx] = input[src_idx]; \
}

COPY2D(copy2d_f32, float)
COPY2D(copy2d_f16, half)
COPY2D(copy2d_u8, uint8_t)
COPY2D(copy2d_u32, uint32_t)

UNARY_OP(cos)
UNARY_OP(sin)
@ -152,7 +128,6 @@ UNARY(id, uint32_t, copy_u32, copy_u32_strided)

#if __METAL_VERSION__ >= 220
UNARY(id, int64_t, copy_i64, copy_i64_strided)
COPY2D(copy2d_i64, int64_t)
#endif

#if defined(__HAVE_BFLOAT__)
@ -176,6 +151,4 @@ BFLOAT_UNARY_OP(recip)
BFLOAT_UNARY_OP(relu)

UNARY(id, bfloat, copy_bf16, copy_bf16_strided)

COPY2D(copy2d_bf64, bfloat)
#endif
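The COPY2D macro above generates kernels that copy a d1 x d2 block between two buffers whose rows may have different strides; each thread handles one element. The same index arithmetic written as a plain Rust reference loop (a sketch for orientation only, not part of the kernels):

// Element (idx1, idx2) is read from input[idx1 * src_s + idx2]
// and written to output[idx1 * dst_s + idx2], mirroring the Metal kernel.
fn copy2d(d1: usize, d2: usize, src_s: usize, dst_s: usize, input: &[f32], output: &mut [f32]) {
    for tid in 0..d1 * d2 {
        let idx1 = tid / d2;
        let idx2 = tid - idx1 * d2;
        output[idx1 * dst_s + idx2] = input[idx1 * src_s + idx2];
    }
}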
@ -238,23 +238,6 @@ impl Benchmark for QMatMul {
    const ITERS: usize = 100;
}

struct Cat;
impl Benchmark for Cat {
    type PreProcessData = (Tensor, Tensor);
    type RunResult = Tensor;
    fn preprocess() -> Result<Self::PreProcessData> {
        let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?;
        let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?;
        Ok((lhs, rhs))
    }

    fn run_one(d: &Self::PreProcessData) -> Result<Self::RunResult> {
        Tensor::cat(&[&d.0, &d.1], 2)
    }

    const ITERS: usize = 1000;
}

struct Softmax;
impl Benchmark for Softmax {
    type PreProcessData = Tensor;
@ -312,7 +295,6 @@ enum Task {
    Qmatmul,
    Softmax,
    SoftmaxLastDim,
    Cat,
}

#[derive(Parser, Debug)]
@ -337,7 +319,6 @@ fn main() -> Result<()> {
        Task::Softmax => run::<Softmax>(args.iters)?,
        Task::SoftmaxLastDim => run::<SoftmaxLastDim>(args.iters)?,
        Task::Qmatmul => run::<QMatMul>(args.iters)?,
        Task::Cat => run::<Cat>(args.iters)?,
    }
    Ok(())
}
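The Cat benchmark concatenates a (1, 32, 2000, 128) tensor with a (1, 32, 1, 128) tensor along dimension 2, roughly the shape pattern of appending a single step to a KV-cache-like buffer. A hedged usage sketch of the same call (assuming the candle crate on the CPU device; not taken from the example itself):

use candle::{Device, Result, Tensor};

fn cat_along_dim2() -> Result<()> {
    let lhs = Tensor::randn(0f32, 1., (1, 32, 2000, 128), &Device::Cpu)?;
    let rhs = Tensor::randn(0f32, 1., (1, 32, 1, 128), &Device::Cpu)?;
    // Only the concatenation dimension grows; all other dims must match.
    let out = Tensor::cat(&[&lhs, &rhs], 2)?;
    assert_eq!(out.dims(), &[1, 32, 2001, 128]);
    Ok(())
}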
@ -5,7 +5,6 @@ use serde::Deserialize;
#[serde(rename_all = "lowercase")]
pub enum Activation {
    #[default]
    #[serde(alias = "gelu")]
    Gelu,
    #[serde(alias = "gelu_new")]
    NewGelu,
@ -20,8 +19,6 @@ pub enum Activation {
    HardSwish,
    Elu(f64),
    LeakyRelu(f64),
    #[serde(alias = "gelu_pytorch_tanh")]
    GeluPytorchTanh,
}

impl super::Module for Activation {
@ -41,7 +38,6 @@ impl super::Module for Activation {
            Self::HardSwish => xs * crate::ops::hard_sigmoid(xs)?,
            &Self::Elu(alpha) => xs.elu(alpha),
            &Self::LeakyRelu(negative_slope) => crate::ops::leaky_relu(xs, negative_slope),
            Self::GeluPytorchTanh => xs.gelu(),
        }
    }
}
@ -76,7 +76,7 @@ pub struct ConvTranspose1dConfig {
    pub output_padding: usize,
    pub stride: usize,
    pub dilation: usize,
    pub groups: usize,
    // TODO: support groups.
}

impl Default for ConvTranspose1dConfig {
@ -86,7 +86,6 @@ impl Default for ConvTranspose1dConfig {
            output_padding: 0,
            stride: 1,
            dilation: 1,
            groups: 1,
        }
    }
}
@ -110,14 +109,6 @@ impl ConvTranspose1d {
    pub fn config(&self) -> &ConvTranspose1dConfig {
        &self.config
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl crate::Module for ConvTranspose1d {
@ -128,13 +119,12 @@ impl crate::Module for ConvTranspose1d {
            self.config.output_padding,
            self.config.stride,
            self.config.dilation,
            self.config.groups,
        )?;
        match &self.bias {
            None => Ok(x),
            Some(bias) => {
                let b = bias.dims1()?;
                let bias = bias.reshape((1, b, 1))?;
                let bias = bias.reshape((1, b, 1, 1))?;
                Ok(x.broadcast_add(&bias)?)
            }
        }
@ -268,14 +258,6 @@ impl ConvTranspose2d {
    pub fn config(&self) -> &ConvTranspose2dConfig {
        &self.config
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl crate::Module for ConvTranspose2d {
@ -348,11 +330,7 @@ pub fn conv_transpose1d(
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints(
        (in_channels, out_channels / cfg.groups, kernel_size),
        "weight",
        init,
    )?;
    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
    let bs = vb.get_with_hints(out_channels, "bias", init)?;
    Ok(ConvTranspose1d::new(ws, Some(bs), cfg))
}
@ -369,11 +347,7 @@ pub fn conv_transpose1d_no_bias(
        lo: -bound,
        up: bound,
    };
    let ws = vb.get_with_hints(
        (in_channels, out_channels / cfg.groups, kernel_size),
        "weight",
        init,
    )?;
    let ws = vb.get_with_hints((in_channels, out_channels, kernel_size), "weight", init)?;
    Ok(ConvTranspose1d::new(ws, None, cfg))
}
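With the groups field added to ConvTranspose1dConfig, the weight tensor above is requested with shape (in_channels, out_channels / groups, kernel_size) instead of (in_channels, out_channels, kernel_size). A small shape check (an illustrative sketch; the helper is not part of the crate):

// Expected weight layout for a grouped 1-D transposed convolution.
fn conv_transpose1d_weight_shape(
    in_channels: usize,
    out_channels: usize,
    kernel_size: usize,
    groups: usize,
) -> (usize, usize, usize) {
    assert!(out_channels % groups == 0, "out_channels must be divisible by groups");
    (in_channels, out_channels / groups, kernel_size)
}

fn main() {
    // groups = 1 reduces to the ungrouped layout used before the change.
    assert_eq!(conv_transpose1d_weight_shape(16, 32, 3, 1), (16, 32, 3));
    assert_eq!(conv_transpose1d_weight_shape(16, 32, 3, 4), (16, 8, 3));
}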
@ -28,7 +28,7 @@ pub use func::{func, func_t, Func, FuncT};
pub use group_norm::{group_norm, GroupNorm};
pub use init::Init;
pub use layer_norm::{layer_norm, rms_norm, LayerNorm, LayerNormConfig, RmsNorm};
pub use linear::{linear, linear_b, linear_no_bias, Linear};
pub use linear::{linear, linear_no_bias, Linear};
pub use ops::Dropout;
pub use optim::{AdamW, Optimizer, ParamsAdamW, SGD};
pub use rnn::{gru, lstm, GRUConfig, LSTMConfig, GRU, LSTM, RNN};
@ -57,34 +57,21 @@ impl super::Module for Linear {
/// Create or initialize a new linear layer.
///
/// This uses some default names for weights and biases, namely `"weight"` and `"bias"`.
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
pub fn linear(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let bound = 1. / (in_dim as f64).sqrt();
    let init_bs = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
    let bs = vs.get_with_hints(out_dim, "bias", init_bs)?;
    Ok(Linear::new(ws, Some(bs)))
}

/// Create or initialize a new linear layer without biases.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vs: crate::VarBuilder) -> Result<Linear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let ws = vs.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    Ok(Linear::new(ws, None))
}

pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: crate::VarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}
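linear_b is a thin convenience wrapper that picks between linear and linear_no_bias based on the bias flag. A minimal usage sketch with a fresh VarMap-backed VarBuilder (assumed setup, not taken from the crate's docs):

use candle::{DType, Device, Result, Tensor};
use candle_nn::{linear_b, Module, VarBuilder, VarMap};

fn linear_b_example() -> Result<()> {
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, &Device::Cpu);
    // bias = false makes this equivalent to linear_no_bias(4, 2, vb).
    let layer = linear_b(4, 2, false, vb)?;
    let xs = Tensor::zeros((3, 4), DType::F32, &Device::Cpu)?;
    assert_eq!(layer.forward(&xs)?.dims(), &[3, 2]);
    Ok(())
}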
@ -74,7 +74,7 @@ pub fn dropout(xs: &Tensor, drop_p: f32) -> Result<Tensor> {
    xs * mask
}

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct Dropout {
    drop_p: f32,
}
@ -238,8 +238,7 @@ impl candle::CustomOp1 for SoftmaxLastDim {
            &output,
        )
        .unwrap();
        let newstorage =
            candle::MetalStorage::new(output, device.clone(), elem_count, storage.dtype());
        let newstorage = candle::MetalStorage::new(output, device.clone(), storage.dtype());
        Ok((newstorage, layout.shape().clone()))
    }
}
@ -197,7 +197,7 @@ impl RNN for LSTM {
    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
        let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
        Tensor::stack(&states, 1)
        Tensor::cat(&states, 1)
    }
}
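The states_to_tensor change above swaps Tensor::stack for Tensor::cat along dim 1. The two are not interchangeable: stack inserts a new axis at the given position, while cat extends an existing one. A small sketch of the difference (assuming the candle crate; shapes are illustrative):

use candle::{DType, Device, Result, Tensor};

fn stack_vs_cat() -> Result<()> {
    // Two hidden states of shape (batch = 2, hidden = 3).
    let a = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    let b = Tensor::zeros((2, 3), DType::F32, &Device::Cpu)?;
    // stack inserts a new axis at dim 1: (2, 2, 3).
    assert_eq!(Tensor::stack(&[&a, &b], 1)?.dims(), &[2, 2, 3]);
    // cat grows the existing dim 1: (2, 6).
    assert_eq!(Tensor::cat(&[&a, &b], 1)?.dims(), &[2, 6]);
    Ok(())
}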
@ -70,7 +70,7 @@ impl VarMap {
    ///
    /// If an error is returned, some of the variables might have already been set to their new
    /// values.
    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<str>, V: AsRef<Tensor>>(
    pub fn set<I: Iterator<Item = (K, V)>, K: AsRef<String>, V: AsRef<Tensor>>(
        &mut self,
        iter: I,
    ) -> Result<()> {
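The bound change on VarMap::set from K: AsRef<String> to K: AsRef<str> is an ergonomics fix: &str literals implement AsRef<str> but not AsRef<String>, so the wider bound lets callers pass plain string literals as keys, as the optimizer test further down does. A tiny stand-alone illustration of the bound (plain Rust, unrelated to candle types):

// With AsRef<str>, both &str literals and owned Strings are accepted as keys;
// with AsRef<String>, the bare "w" literal below would not compile.
fn key_len<K: AsRef<str>>(key: K) -> usize {
    key.as_ref().len()
}

fn main() {
    assert_eq!(key_len("w"), 1);
    assert_eq!(key_len(String::from("bias")), 4);
}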
@ -7,7 +7,7 @@ extern crate accelerate_src;
use candle::test_utils::{to_vec0_round, to_vec2_round};

use anyhow::Result;
use candle::{DType, Device, Tensor, Var};
use candle::{Device, Tensor, Var};
use candle_nn::{AdamW, Linear, Module, Optimizer, ParamsAdamW, SGD};

#[test]
@ -121,40 +121,3 @@ fn adamw_linear_regression() -> Result<()> {
    assert_eq!(to_vec0_round(b.as_tensor(), 4)?, 0.7873);
    Ok(())
}

#[test]
fn adamw_linear_regression_varmap() -> Result<()> {
    use candle_nn::Init::Const;

    // Similar as the previous test but using a VarMap.
    let w_gen = Tensor::new(&[[3f32, 1.]], &Device::Cpu)?;
    let b_gen = Tensor::new(-2f32, &Device::Cpu)?;
    let gen = Linear::new(w_gen, Some(b_gen));
    let sample_xs = Tensor::new(&[[2f32, 1.], [7., 4.], [-4., 12.], [5., 8.]], &Device::Cpu)?;
    let sample_ys = gen.forward(&sample_xs)?;

    let mut var_map = candle_nn::VarMap::new();

    let w = var_map.get((1, 2), "w", Const(0.), DType::F32, &Device::Cpu)?;
    let b = var_map.get((), "b", Const(0.), DType::F32, &Device::Cpu)?;
    let params = ParamsAdamW {
        lr: 0.1,
        ..Default::default()
    };
    let mut opt = AdamW::new(var_map.all_vars(), params)?;
    let lin = Linear::new(w, Some(b));
    for _step in 0..100 {
        let ys = lin.forward(&sample_xs)?;
        let loss = ys.sub(&sample_ys)?.sqr()?.sum_all()?;
        opt.backward_step(&loss)?;
    }
    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[2.7257, 0.7097]]);
    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 0.7873);

    var_map.set([("w", Tensor::zeros((1, 2), DType::F32, &Device::Cpu)?)].into_iter())?;
    var_map.set([("b", Tensor::ones((), DType::F32, &Device::Cpu)?)].into_iter())?;

    assert_eq!(to_vec2_round(lin.weight(), 4)?, &[[0., 0.]]);
    assert_eq!(to_vec0_round(lin.bias().unwrap(), 4)?, 1.);
    Ok(())
}
@ -1,6 +1,6 @@
[package]
name = "candle-onnx"
version = "0.4.2"
version = "0.4.0"
edition = "2021"

description = "ONNX support for Candle"
@ -10,8 +10,8 @@ categories = ["science"]
license = "MIT OR Apache-2.0"

[dependencies]
candle = { path = "../candle-core", package = "candle-core", version = "0.4.2" }
candle-nn = { path = "../candle-nn", version = "0.4.2" }
candle = { path = "../candle-core", package = "candle-core", version = "0.4.0" }
candle-nn = { path = "../candle-nn", version = "0.4.0" }
prost = "0.12.1"

[build-dependencies]
Some files were not shown because too many files have changed in this diff