mirror of
https://github.com/huggingface/candle.git
synced 2025-06-20 04:00:28 +00:00
Compare commits
13 Commits
fix-1.86
...
cuda-graph
Author | SHA1 | Date | |
---|---|---|---|
543b5b5898 | |||
c87f0fa5d6 | |||
eb478ece92 | |||
d339b01726 | |||
2f3bf42bcb | |||
e3370c6316 | |||
338f6a102e | |||
bc33df77e1 | |||
cf9d7bf24c | |||
9d31361c4f | |||
1bb68854d3 | |||
b2956857ef | |||
9076dee432 |
@ -42,7 +42,7 @@ clap = { workspace = true }
|
|||||||
criterion = { workspace = true }
|
criterion = { workspace = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = []
|
default = ["cuda"]
|
||||||
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
||||||
cudnn = ["cuda", "cudarc/cudnn"]
|
cudnn = ["cuda", "cudarc/cudnn"]
|
||||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||||
@ -56,3 +56,7 @@ harness = false
|
|||||||
[[example]]
|
[[example]]
|
||||||
name = "metal_basics"
|
name = "metal_basics"
|
||||||
required-features = ["metal"]
|
required-features = ["metal"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "cuda_basics"
|
||||||
|
required-features = ["cuda"]
|
||||||
|
@ -7,8 +7,79 @@ extern crate intel_mkl_src;
|
|||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use candle_core::{Device, Tensor};
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
|
const USE_CUDA_GRAPH: bool = true;
|
||||||
|
|
||||||
|
fn cuda_graph() -> Result<()> {
|
||||||
|
let device = Device::new_cuda_with_stream(0)?;
|
||||||
|
let cu_device = match &device {
|
||||||
|
Device::Cuda(dev) => dev,
|
||||||
|
_ => unreachable!(),
|
||||||
|
};
|
||||||
|
let cu_stream = cu_device.cuda_stream();
|
||||||
|
{
|
||||||
|
// load_ptx cannot be called while capturing the stream so we need this to happen
|
||||||
|
// beforehand.
|
||||||
|
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let _x = x.mul(&u)?.broadcast_add(&v)?;
|
||||||
|
let _x = x.affine(1., 0.5)?;
|
||||||
|
x.slice_set(&u, 0, 0)?;
|
||||||
|
device.synchronize()?;
|
||||||
|
}
|
||||||
|
if USE_CUDA_GRAPH {
|
||||||
|
cu_stream.begin_capture(
|
||||||
|
cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)?
|
||||||
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
|
for _i in 0..100 {
|
||||||
|
// x.slice_set(&u, 0, 0)?;
|
||||||
|
// x.broadcast_add(&v)?;
|
||||||
|
x = x.affine(1., 0.5)?;
|
||||||
|
// x = (&u + &x)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if USE_CUDA_GRAPH {
|
||||||
|
println!("capturing graph");
|
||||||
|
let cu_graph = match cu_stream.end_capture(
|
||||||
|
cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
|
||||||
|
)? {
|
||||||
|
None => anyhow::bail!("no graph captured"),
|
||||||
|
Some(cu_graph) => cu_graph,
|
||||||
|
};
|
||||||
|
println!("graph captured!");
|
||||||
|
for i in 1..100 {
|
||||||
|
println!("graph exec {i}");
|
||||||
|
cu_graph.launch()?;
|
||||||
|
println!("sync");
|
||||||
|
if let Err(err) = device.synchronize() {
|
||||||
|
println!("err: {err:?}")
|
||||||
|
}
|
||||||
|
println!("done syncing");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
device.synchronize()?;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
let device = Device::new_cuda(0)?;
|
cuda_graph()?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
fn _matmul() -> Result<()> {
|
||||||
|
let device = Device::new_cuda_with_stream(0)?;
|
||||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||||
.to_dtype(candle_core::DType::BF16)?;
|
.to_dtype(candle_core::DType::BF16)?;
|
||||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||||
|
@ -816,7 +816,7 @@ impl PthTensors {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `path` - Path to the pth file.
|
/// * `path` - Path to the pth file.
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||||
path: P,
|
path: P,
|
||||||
key: Option<&str>,
|
key: Option<&str>,
|
||||||
|
@ -73,7 +73,7 @@ fn dequantize_f32(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
let nb = (elem_count + 255) / 256;
|
let nb = elem_count.div_ceil(256);
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||||
@ -133,7 +133,7 @@ fn dequantize_f16(
|
|||||||
elem_count: usize,
|
elem_count: usize,
|
||||||
dev: &CudaDevice,
|
dev: &CudaDevice,
|
||||||
) -> Result<CudaStorage> {
|
) -> Result<CudaStorage> {
|
||||||
let nb = (elem_count + 255) / 256;
|
let nb = elem_count.div_ceil(256);
|
||||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||||
@ -278,8 +278,8 @@ fn mul_mat_vec_via_q8_1(
|
|||||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||||
let (nblocks, nwarps) = match b_size {
|
let (nblocks, nwarps) = match b_size {
|
||||||
1 => (nrows as u32, 4),
|
1 => (nrows as u32, 4),
|
||||||
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
2..=4 => ((nrows as u32).div_ceil(2), 4),
|
||||||
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
5..=8 => ((nrows as u32).div_ceil(2), 2),
|
||||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||||
};
|
};
|
||||||
let cfg = cudarc::driver::LaunchConfig {
|
let cfg = cudarc::driver::LaunchConfig {
|
||||||
|
@ -69,6 +69,7 @@ metal = ["candle/metal", "candle-nn/metal"]
|
|||||||
microphone = ["cpal", "rubato"]
|
microphone = ["cpal", "rubato"]
|
||||||
encodec = ["cpal", "symphonia", "rubato"]
|
encodec = ["cpal", "symphonia", "rubato"]
|
||||||
mimi = ["cpal", "symphonia", "rubato"]
|
mimi = ["cpal", "symphonia", "rubato"]
|
||||||
|
snac = ["cpal", "symphonia", "rubato"]
|
||||||
depth_anything_v2 = ["palette", "enterpolation"]
|
depth_anything_v2 = ["palette", "enterpolation"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
@ -107,6 +108,10 @@ required-features = ["candle-datasets"]
|
|||||||
name = "mimi"
|
name = "mimi"
|
||||||
required-features = ["mimi"]
|
required-features = ["mimi"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "snac"
|
||||||
|
required-features = ["snac"]
|
||||||
|
|
||||||
[[example]]
|
[[example]]
|
||||||
name = "encodec"
|
name = "encodec"
|
||||||
required-features = ["encodec"]
|
required-features = ["encodec"]
|
||||||
|
14
candle-examples/examples/csm/README.md
Normal file
14
candle-examples/examples/csm/README.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Conversational Speech Model (CSM)
|
||||||
|
|
||||||
|
CSM is a speech generation model from Sesame,
|
||||||
|
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
|
||||||
|
|
||||||
|
It can generate a conversational speech between two different speakers.
|
||||||
|
The speakers turn are delimited by the `|` character in the prompt.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example csm --features cuda -r -- \
|
||||||
|
--voices candle-examples/examples/csm/voices.safetensors \
|
||||||
|
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
|
||||||
|
```
|
||||||
|
|
@ -34,9 +34,18 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
|
|
||||||
#[arg(long, default_value = "[0]Hey how are you doing?")]
|
/// The prompt to be used for the generation, use a | to separate the speakers.
|
||||||
|
#[arg(long, default_value = "Hey how are you doing today?")]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
||||||
|
/// The voices to be used, in safetensors format.
|
||||||
|
#[arg(long)]
|
||||||
|
voices: String,
|
||||||
|
|
||||||
|
/// The output file using the wav format.
|
||||||
|
#[arg(long, default_value = "out.wav")]
|
||||||
|
out_file: String,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long, default_value_t = 0.7)]
|
#[arg(long, default_value_t = 0.7)]
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
@ -162,7 +171,7 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (mut model, device) = {
|
let (mut model, device) = {
|
||||||
let dtype = DType::F32;
|
let dtype = device.bf16_default_to_f32();
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(&config, vb)?;
|
||||||
(model, device)
|
(model, device)
|
||||||
@ -177,45 +186,58 @@ fn main() -> Result<()> {
|
|||||||
let cb = config.audio_num_codebooks;
|
let cb = config.audio_num_codebooks;
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
if args.prompt.ends_with(".safetensors") {
|
|
||||||
let prompt = candle::safetensors::load(args.prompt, &device)?;
|
let voices = candle::safetensors::load(args.voices, &device)?;
|
||||||
let mut tokens = prompt
|
let mut lp = candle_transformers::generation::LogitsProcessor::new(
|
||||||
.get("tokens")
|
args.seed,
|
||||||
.expect("no tokens in prompt")
|
Some(args.temperature),
|
||||||
.to_dtype(DType::U32)?;
|
None,
|
||||||
let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
|
);
|
||||||
println!("tokens:\n{tokens:?}");
|
let tokens = voices
|
||||||
println!("mask:\n{mask:?}");
|
.get("tokens")
|
||||||
let mut lp = candle_transformers::generation::LogitsProcessor::new(42, None, None);
|
.expect("no tokens in prompt")
|
||||||
let mut const_mask = vec![1u8; cb];
|
.to_dtype(DType::U32)?;
|
||||||
const_mask.push(0);
|
let mask = voices.get("mask").expect("no mask in prompt").clone();
|
||||||
let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
|
|
||||||
let mut pos = 0;
|
let mut pos = 0;
|
||||||
let mut all_tokens = vec![];
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
for i in 0.. {
|
pos += tokens.dim(1)?;
|
||||||
let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
|
||||||
|
let mut all_pcms = vec![];
|
||||||
|
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
|
||||||
|
println!("{prompt:?}");
|
||||||
|
let speaker_idx = turn_idx % 2;
|
||||||
|
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
|
||||||
|
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
|
||||||
|
|
||||||
|
let mut generated_tokens = vec![];
|
||||||
|
loop {
|
||||||
|
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
pos += tokens.dim(1)?;
|
pos += tokens.dim(1)?;
|
||||||
frame.push(0);
|
let is_done = frame.iter().all(|&x| x == 0);
|
||||||
if frame.iter().all(|&x| x == 0) {
|
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
|
||||||
|
print!("\rframe {pos}");
|
||||||
|
if is_done {
|
||||||
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
println!("frame {i} {pos}:\n{frame:?}");
|
generated_tokens.push(tokens.clone());
|
||||||
tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
|
|
||||||
all_tokens.push(tokens.clone());
|
|
||||||
mask = const_mask.clone();
|
|
||||||
}
|
}
|
||||||
let all_tokens = Tensor::cat(&all_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
println!();
|
||||||
println!("all_tokens:\n{all_tokens:?}");
|
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
||||||
let pcm = mimi_model.decode(&all_tokens)?;
|
let pcm = mimi_model.decode(&generated_tokens)?;
|
||||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||||
let pcm = pcm.to_vec1::<f32>()?;
|
all_pcms.push(pcm);
|
||||||
let mut output = std::fs::File::create("out.wav")?;
|
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
|
||||||
} else {
|
|
||||||
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
|
|
||||||
println!("{prompt:?}");
|
|
||||||
}
|
}
|
||||||
|
let pcm = Tensor::cat(&all_pcms, 0)?;
|
||||||
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
|
println!("writing output file {}", args.out_file);
|
||||||
|
let mut output = std::fs::File::create(args.out_file)?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
Binary file not shown.
@ -8,7 +8,7 @@ DistilBert is used to compute the sentence embeddings for a prompt. The model we
|
|||||||
are downloaded from the hub on the first run.
|
are downloaded from the hub on the first run.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||||
|
|
||||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||||
@ -20,3 +20,25 @@ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
|||||||
> Tensor[[1, 7, 768], f32]
|
> Tensor[[1, 7, 768], f32]
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Masked Token
|
||||||
|
|
||||||
|
DistilBert is used to compute the top K choices for a masked token.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
|
||||||
|
|
||||||
|
> Input: The capital of France is [MASK].
|
||||||
|
> Predictions for [MASK] at position 6:
|
||||||
|
> 1: marseille (probability: 12.14%)
|
||||||
|
> 2: paris (probability: 10.84%)
|
||||||
|
> 3: toulouse (probability: 8.57%)
|
||||||
|
> 4: lyon (probability: 7.61%)
|
||||||
|
> 5: montpellier (probability: 5.18%)
|
||||||
|
> 6: bordeaux (probability: 4.88%)
|
||||||
|
> 7: nantes (probability: 4.82%)
|
||||||
|
> 8: lille (probability: 4.07%)
|
||||||
|
> 9: strasbourg (probability: 3.12%)
|
||||||
|
> 10: cannes (probability: 3.04%)
|
||||||
|
|
||||||
|
```
|
@ -3,15 +3,48 @@ extern crate intel_mkl_src;
|
|||||||
|
|
||||||
#[cfg(feature = "accelerate")]
|
#[cfg(feature = "accelerate")]
|
||||||
extern crate accelerate_src;
|
extern crate accelerate_src;
|
||||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
use candle_transformers::models::distilbert::{
|
||||||
|
Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
|
||||||
|
};
|
||||||
|
|
||||||
use anyhow::{Error as E, Result};
|
use anyhow::{Context, Error as E, Result};
|
||||||
use candle::{Device, Tensor};
|
use candle::{Device, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
use clap::Parser;
|
use clap::{Parser, ValueEnum};
|
||||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||||
|
use std::path::PathBuf;
|
||||||
use tokenizers::Tokenizer;
|
use tokenizers::Tokenizer;
|
||||||
|
|
||||||
|
enum ModelType {
|
||||||
|
Masked(DistilBertForMaskedLM),
|
||||||
|
UnMasked(DistilBertModel),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ModelType {
|
||||||
|
fn device(&self) -> &Device {
|
||||||
|
match self {
|
||||||
|
ModelType::Masked(model) => &model.bert.device,
|
||||||
|
ModelType::UnMasked(model) => &model.device,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
match self {
|
||||||
|
ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||||
|
ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "distilbert")]
|
||||||
|
DistilBert,
|
||||||
|
|
||||||
|
#[value(name = "distilbertformaskedlm")]
|
||||||
|
DistilbertForMaskedLM,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Parser, Debug)]
|
#[derive(Parser, Debug)]
|
||||||
#[command(author, version, about, long_about = None)]
|
#[command(author, version, about, long_about = None)]
|
||||||
struct Args {
|
struct Args {
|
||||||
@ -23,10 +56,14 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
tracing: bool,
|
tracing: bool,
|
||||||
|
|
||||||
|
#[arg(long, default_value = "distilbert")]
|
||||||
|
model: Which,
|
||||||
|
|
||||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
model_id: Option<String>,
|
model_id: Option<String>,
|
||||||
|
|
||||||
|
/// Revision or branch
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
revision: Option<String>,
|
revision: Option<String>,
|
||||||
|
|
||||||
@ -42,94 +79,246 @@ struct Args {
|
|||||||
#[arg(long, default_value = "1")]
|
#[arg(long, default_value = "1")]
|
||||||
n: usize,
|
n: usize,
|
||||||
|
|
||||||
/// L2 normalization for embeddings.
|
/// Number of top predictions to show for each mask
|
||||||
#[arg(long, default_value = "true")]
|
#[arg(long, default_value = "5")]
|
||||||
normalize_embeddings: bool,
|
top_k: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Args {
|
impl Args {
|
||||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
|
||||||
let device = candle_examples::device(self.cpu)?;
|
let device = candle_examples::device(self.cpu)?;
|
||||||
|
|
||||||
|
let (model_id, revision) = self.resolve_model_and_revision();
|
||||||
|
let (config_path, tokenizer_path, weights_path) =
|
||||||
|
self.download_model_files(&model_id, &revision)?;
|
||||||
|
|
||||||
|
let config = std::fs::read_to_string(config_path)?;
|
||||||
|
let config: Config = serde_json::from_str(&config)?;
|
||||||
|
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let vb = self.load_variables(&weights_path, &device)?;
|
||||||
|
let model = self.create_model(&config, vb)?;
|
||||||
|
|
||||||
|
Ok((model, tokenizer))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_model_and_revision(&self) -> (String, String) {
|
||||||
let default_model = "distilbert-base-uncased".to_string();
|
let default_model = "distilbert-base-uncased".to_string();
|
||||||
let default_revision = "main".to_string();
|
let default_revision = "main".to_string();
|
||||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
|
||||||
|
match (self.model_id.clone(), self.revision.clone()) {
|
||||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
(Some(model_id), None) => (model_id, default_revision),
|
||||||
(None, Some(revision)) => (default_model, revision),
|
(None, Some(revision)) => (default_model, revision),
|
||||||
(None, None) => (default_model, default_revision),
|
(None, None) => (default_model, default_revision),
|
||||||
};
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
fn download_model_files(
|
||||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
&self,
|
||||||
let api = Api::new()?;
|
model_id: &str,
|
||||||
let api = api.repo(repo);
|
revision: &str,
|
||||||
let config = api.get("config.json")?;
|
) -> Result<(PathBuf, PathBuf, PathBuf)> {
|
||||||
let tokenizer = api.get("tokenizer.json")?;
|
let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
|
||||||
let weights = if self.use_pth {
|
let api = Api::new()?;
|
||||||
api.get("pytorch_model.bin")?
|
let api = api.repo(repo);
|
||||||
} else {
|
|
||||||
api.get("model.safetensors")?
|
|
||||||
};
|
|
||||||
(config, tokenizer, weights)
|
|
||||||
};
|
|
||||||
let config = std::fs::read_to_string(config_filename)?;
|
|
||||||
let config: Config = serde_json::from_str(&config)?;
|
|
||||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
|
||||||
|
|
||||||
let vb = if self.use_pth {
|
let config = api.get("config.json")?;
|
||||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
let tokenizer = api.get("tokenizer.json")?;
|
||||||
|
let weights = if self.use_pth {
|
||||||
|
api.get("pytorch_model.bin")?
|
||||||
} else {
|
} else {
|
||||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
api.get("model.safetensors")?
|
||||||
};
|
};
|
||||||
let model = DistilBertModel::load(vb, &config)?;
|
|
||||||
Ok((model, tokenizer))
|
Ok((config, tokenizer, weights))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
|
||||||
|
if self.use_pth {
|
||||||
|
Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
|
||||||
|
} else {
|
||||||
|
Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
||||||
|
match self.model {
|
||||||
|
Which::DistilbertForMaskedLM => {
|
||||||
|
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
|
||||||
|
}
|
||||||
|
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
fn main() -> Result<()> {
|
||||||
let mask: Vec<_> = (0..size)
|
let args = Args::parse();
|
||||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
let _guard = setup_tracing(&args);
|
||||||
.collect();
|
|
||||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
let (model, tokenizer) = args.build_model_and_tokenizer()?;
|
||||||
|
let device = model.device();
|
||||||
|
|
||||||
|
let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
|
||||||
|
let output = model.forward(&token_ids, &mask)?;
|
||||||
|
|
||||||
|
process_output(&model, &output, &token_ids, &tokenizer, &args)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn setup_tracing(args: &Args) -> Option<impl Drop> {
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
if args.tracing {
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
|
use tracing_subscriber::prelude::*;
|
||||||
|
|
||||||
let args = Args::parse();
|
|
||||||
let _guard = if args.tracing {
|
|
||||||
println!("tracing...");
|
println!("tracing...");
|
||||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||||
tracing_subscriber::registry().with(chrome_layer).init();
|
tracing_subscriber::registry().with(chrome_layer).init();
|
||||||
Some(guard)
|
Some(guard)
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
}
|
||||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
}
|
||||||
let device = &model.device;
|
|
||||||
|
|
||||||
let tokenizer = tokenizer
|
fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
|
||||||
|
let mut binding = tokenizer.clone();
|
||||||
|
let tokenizer_configured = binding
|
||||||
.with_padding(None)
|
.with_padding(None)
|
||||||
.with_truncation(None)
|
.with_truncation(None)
|
||||||
.map_err(E::msg)?;
|
.map_err(E::msg)?;
|
||||||
let tokens = tokenizer
|
|
||||||
.encode(args.prompt, true)
|
let tokens = tokenizer_configured
|
||||||
|
.encode(args.prompt.clone(), true)
|
||||||
.map_err(E::msg)?
|
.map_err(E::msg)?
|
||||||
.get_ids()
|
.get_ids()
|
||||||
.to_vec();
|
.to_vec();
|
||||||
|
|
||||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||||
let mask = get_mask(tokens.len(), device);
|
|
||||||
|
|
||||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
let mask = match args.model {
|
||||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
|
||||||
|
Which::DistilBert => attention_mask(tokens.len(), device)?,
|
||||||
|
};
|
||||||
|
|
||||||
let ys = model.forward(&token_ids, &mask)?;
|
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
|
||||||
println!("{ys}");
|
|
||||||
|
Ok((token_ids, mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn process_output(
|
||||||
|
model: &ModelType,
|
||||||
|
output: &Tensor,
|
||||||
|
token_ids: &Tensor,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
args: &Args,
|
||||||
|
) -> Result<()> {
|
||||||
|
match model {
|
||||||
|
ModelType::UnMasked(_) => {
|
||||||
|
println!("embeddings");
|
||||||
|
println!("{output}");
|
||||||
|
}
|
||||||
|
ModelType::Masked(_) => {
|
||||||
|
process_masked_output(output, token_ids, tokenizer, args)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
fn process_masked_output(
|
||||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
output: &Tensor,
|
||||||
|
token_ids: &Tensor,
|
||||||
|
tokenizer: &Tokenizer,
|
||||||
|
args: &Args,
|
||||||
|
) -> Result<()> {
|
||||||
|
let input_ids_vec = token_ids.to_vec2::<u32>()?;
|
||||||
|
let mask_token_id = tokenizer
|
||||||
|
.token_to_id("[MASK]")
|
||||||
|
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||||
|
|
||||||
|
println!("\nInput: {}", args.prompt);
|
||||||
|
|
||||||
|
for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
|
||||||
|
if token_id == mask_token_id {
|
||||||
|
println!("Predictions for [MASK] at position {}:", token_idx);
|
||||||
|
|
||||||
|
let pos_logits = output.get(0)?.get(token_idx)?;
|
||||||
|
let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
|
||||||
|
let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
|
||||||
|
|
||||||
|
let values = top_values.to_vec1::<f32>()?;
|
||||||
|
let indices = top_indices.to_vec1::<u32>()?;
|
||||||
|
|
||||||
|
for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
|
||||||
|
let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
|
||||||
|
println!(
|
||||||
|
" {}: {:15} (probability: {:.2}%)",
|
||||||
|
i + 1,
|
||||||
|
token,
|
||||||
|
prob * 100.0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
|
||||||
|
let n = tensor.dims().iter().product::<usize>();
|
||||||
|
let k = std::cmp::min(k, n);
|
||||||
|
|
||||||
|
let values = tensor.to_vec1::<f32>()?;
|
||||||
|
let mut value_indices: Vec<(f32, usize)> = values
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(idx, val)| (val, idx))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
|
||||||
|
let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
|
||||||
|
let top_k_indices: Vec<u32> = value_indices
|
||||||
|
.iter()
|
||||||
|
.take(k)
|
||||||
|
.map(|(_, idx)| *idx as u32)
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let device = tensor.device();
|
||||||
|
let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
|
||||||
|
let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
|
||||||
|
|
||||||
|
Ok((top_values, top_indices))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||||
|
let mask: Vec<_> = (0..size)
|
||||||
|
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||||
|
.collect();
|
||||||
|
Ok(Tensor::from_slice(&mask, (size, size), device)?)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
|
||||||
|
let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
|
||||||
|
let seq_len = tokens.get_attention_mask().to_vec().len();
|
||||||
|
|
||||||
|
let mask_token_id = tokenizer
|
||||||
|
.token_to_id("[MASK]")
|
||||||
|
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||||
|
|
||||||
|
let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
|
||||||
|
|
||||||
|
let ids = tokens.get_ids();
|
||||||
|
for _ in 0..seq_len {
|
||||||
|
for id in ids.iter() {
|
||||||
|
let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
|
||||||
|
attention_mask_vec.push(mask_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let shape = (1, 1, seq_len, seq_len);
|
||||||
|
let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
|
||||||
|
|
||||||
|
Ok(mask)
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_conv(&self) -> usize {
|
fn d_conv(&self) -> usize {
|
||||||
|
275
candle-examples/examples/snac/audio_io.rs
Normal file
275
candle-examples/examples/snac/audio_io.rs
Normal file
@ -0,0 +1,275 @@
|
|||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::sync::{Arc, Mutex};
|
||||||
|
|
||||||
|
pub const SAMPLE_RATE: usize = 24_000;
|
||||||
|
|
||||||
|
pub(crate) struct AudioOutputData_ {
|
||||||
|
resampled_data: std::collections::VecDeque<f32>,
|
||||||
|
resampler: rubato::FastFixedIn<f32>,
|
||||||
|
output_buffer: Vec<f32>,
|
||||||
|
input_buffer: Vec<f32>,
|
||||||
|
input_len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AudioOutputData_ {
|
||||||
|
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
||||||
|
use rubato::Resampler;
|
||||||
|
|
||||||
|
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
||||||
|
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
||||||
|
let resampler = rubato::FastFixedIn::new(
|
||||||
|
resample_ratio,
|
||||||
|
f64::max(resample_ratio, 1.0),
|
||||||
|
rubato::PolynomialDegree::Septic,
|
||||||
|
1024,
|
||||||
|
1,
|
||||||
|
)?;
|
||||||
|
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
||||||
|
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
||||||
|
Ok(Self {
|
||||||
|
resampled_data,
|
||||||
|
resampler,
|
||||||
|
input_buffer,
|
||||||
|
output_buffer,
|
||||||
|
input_len: 0,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
use rubato::Resampler;
|
||||||
|
self.output_buffer.fill(0.);
|
||||||
|
self.input_buffer.fill(0.);
|
||||||
|
self.resampler.reset();
|
||||||
|
self.resampled_data.clear();
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
||||||
|
let mut data = Vec::with_capacity(self.resampled_data.len());
|
||||||
|
while let Some(elem) = self.resampled_data.pop_back() {
|
||||||
|
data.push(elem);
|
||||||
|
}
|
||||||
|
data
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_empty(&self) -> bool {
|
||||||
|
self.resampled_data.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Assumes that the input buffer is large enough.
|
||||||
|
fn push_input_buffer(&mut self, samples: &[f32]) {
|
||||||
|
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
||||||
|
self.input_len += samples.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
||||||
|
use rubato::Resampler;
|
||||||
|
|
||||||
|
let mut pos_in = 0;
|
||||||
|
loop {
|
||||||
|
let rem = self.input_buffer.len() - self.input_len;
|
||||||
|
let pos_end = usize::min(pos_in + rem, samples.len());
|
||||||
|
self.push_input_buffer(&samples[pos_in..pos_end]);
|
||||||
|
pos_in = pos_end;
|
||||||
|
if self.input_len < self.input_buffer.len() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
let (_, out_len) = self.resampler.process_into_buffer(
|
||||||
|
&[&self.input_buffer],
|
||||||
|
&mut [&mut self.output_buffer],
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
for &elem in self.output_buffer[..out_len].iter() {
|
||||||
|
self.resampled_data.push_front(elem)
|
||||||
|
}
|
||||||
|
self.input_len = 0;
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
||||||
|
|
||||||
|
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||||
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||||
|
|
||||||
|
println!("Setup audio output stream!");
|
||||||
|
let host = cpal::default_host();
|
||||||
|
let device = host
|
||||||
|
.default_output_device()
|
||||||
|
.context("no output device available")?;
|
||||||
|
let mut supported_configs_range = device.supported_output_configs()?;
|
||||||
|
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
||||||
|
// On macOS, it's commonly the case that there are only stereo outputs.
|
||||||
|
None => device
|
||||||
|
.supported_output_configs()?
|
||||||
|
.next()
|
||||||
|
.context("no audio output available")?,
|
||||||
|
Some(config_range) => config_range,
|
||||||
|
};
|
||||||
|
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||||
|
config_range.min_sample_rate(),
|
||||||
|
config_range.max_sample_rate(),
|
||||||
|
);
|
||||||
|
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||||
|
let channels = config.channels as usize;
|
||||||
|
println!(
|
||||||
|
"cpal device: {} {} {config:?}",
|
||||||
|
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||||
|
config.sample_rate.0
|
||||||
|
);
|
||||||
|
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||||
|
SAMPLE_RATE,
|
||||||
|
config.sample_rate.0 as usize,
|
||||||
|
)?));
|
||||||
|
let ad = audio_data.clone();
|
||||||
|
let stream = device.build_output_stream(
|
||||||
|
&config,
|
||||||
|
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||||
|
data.fill(0.);
|
||||||
|
let mut ad = ad.lock().unwrap();
|
||||||
|
let mut last_elem = 0f32;
|
||||||
|
for (idx, elem) in data.iter_mut().enumerate() {
|
||||||
|
if idx % channels == 0 {
|
||||||
|
match ad.resampled_data.pop_back() {
|
||||||
|
None => break,
|
||||||
|
Some(v) => {
|
||||||
|
last_elem = v;
|
||||||
|
*elem = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
*elem = last_elem
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
move |err| eprintln!("cpal error: {err}"),
|
||||||
|
None, // None=blocking, Some(Duration)=timeout
|
||||||
|
)?;
|
||||||
|
stream.play()?;
|
||||||
|
Ok((stream, audio_data))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||||
|
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||||
|
|
||||||
|
println!("Setup audio input stream!");
|
||||||
|
let host = cpal::default_host();
|
||||||
|
let device = host
|
||||||
|
.default_input_device()
|
||||||
|
.context("no input device available")?;
|
||||||
|
let mut supported_configs_range = device.supported_input_configs()?;
|
||||||
|
let config_range = supported_configs_range
|
||||||
|
.find(|c| c.channels() == 1)
|
||||||
|
.context("no audio input available")?;
|
||||||
|
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||||
|
config_range.min_sample_rate(),
|
||||||
|
config_range.max_sample_rate(),
|
||||||
|
);
|
||||||
|
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||||
|
println!(
|
||||||
|
"cpal device: {} {} {config:?}",
|
||||||
|
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||||
|
config.sample_rate.0
|
||||||
|
);
|
||||||
|
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||||
|
config.sample_rate.0 as usize,
|
||||||
|
SAMPLE_RATE,
|
||||||
|
)?));
|
||||||
|
let ad = audio_data.clone();
|
||||||
|
let stream = device.build_input_stream(
|
||||||
|
&config,
|
||||||
|
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||||
|
let mut ad = ad.lock().unwrap();
|
||||||
|
if let Err(err) = ad.push_samples(data) {
|
||||||
|
eprintln!("error processing audio input {err:?}")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
move |err| eprintln!("cpal error: {err}"),
|
||||||
|
None, // None=blocking, Some(Duration)=timeout
|
||||||
|
)?;
|
||||||
|
stream.play()?;
|
||||||
|
Ok((stream, audio_data))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||||
|
where
|
||||||
|
T: symphonia::core::sample::Sample,
|
||||||
|
f32: symphonia::core::conv::FromSample<T>,
|
||||||
|
{
|
||||||
|
use symphonia::core::audio::Signal;
|
||||||
|
use symphonia::core::conv::FromSample;
|
||||||
|
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||||
|
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||||
|
|
||||||
|
let src = std::fs::File::open(path)?;
|
||||||
|
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||||
|
let hint = symphonia::core::probe::Hint::new();
|
||||||
|
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||||
|
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||||
|
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||||
|
let mut format = probed.format;
|
||||||
|
let track = format
|
||||||
|
.tracks()
|
||||||
|
.iter()
|
||||||
|
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||||
|
.expect("no supported audio tracks");
|
||||||
|
let mut decoder = symphonia::default::get_codecs()
|
||||||
|
.make(&track.codec_params, &Default::default())
|
||||||
|
.expect("unsupported codec");
|
||||||
|
let track_id = track.id;
|
||||||
|
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||||
|
let mut pcm_data = Vec::new();
|
||||||
|
while let Ok(packet) = format.next_packet() {
|
||||||
|
while !format.metadata().is_latest() {
|
||||||
|
format.metadata().pop();
|
||||||
|
}
|
||||||
|
if packet.track_id() != track_id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match decoder.decode(&packet)? {
|
||||||
|
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||||
|
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((pcm_data, sample_rate))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
||||||
|
use rubato::Resampler;
|
||||||
|
|
||||||
|
let mut pcm_out =
|
||||||
|
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||||
|
|
||||||
|
let mut resampler =
|
||||||
|
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
|
||||||
|
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||||
|
let mut pos_in = 0;
|
||||||
|
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||||
|
let (in_len, out_len) =
|
||||||
|
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
||||||
|
pos_in += in_len;
|
||||||
|
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||||
|
}
|
||||||
|
|
||||||
|
if pos_in < pcm_in.len() {
|
||||||
|
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
||||||
|
Some(&[&pcm_in[pos_in..]]),
|
||||||
|
&mut output_buffer,
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(pcm_out)
|
||||||
|
}
|
197
candle-examples/examples/snac/main.rs
Normal file
197
candle-examples/examples/snac/main.rs
Normal file
@ -0,0 +1,197 @@
|
|||||||
|
#[cfg(feature = "mkl")]
|
||||||
|
extern crate intel_mkl_src;
|
||||||
|
|
||||||
|
#[cfg(feature = "accelerate")]
|
||||||
|
extern crate accelerate_src;
|
||||||
|
|
||||||
|
use anyhow::Result;
|
||||||
|
use candle::{DType, IndexOp, Tensor};
|
||||||
|
use candle_nn::VarBuilder;
|
||||||
|
use candle_transformers::models::snac::{Config, Model};
|
||||||
|
use clap::{Parser, ValueEnum};
|
||||||
|
use hf_hub::api::sync::Api;
|
||||||
|
|
||||||
|
mod audio_io;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||||
|
enum Action {
|
||||||
|
AudioToAudio,
|
||||||
|
AudioToCode,
|
||||||
|
CodeToAudio,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
|
enum Which {
|
||||||
|
#[value(name = "24khz")]
|
||||||
|
S24khz,
|
||||||
|
#[value(name = "32khz")]
|
||||||
|
S32khz,
|
||||||
|
#[value(name = "44khz")]
|
||||||
|
S44khz,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Which {
|
||||||
|
fn sample_rate(&self) -> u32 {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => 24000,
|
||||||
|
Which::S32khz => 32000,
|
||||||
|
Which::S44khz => 44000,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn config_repo(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => "hubertsiuzdak/snac_24khz",
|
||||||
|
Which::S32khz => "hubertsiuzdak/snac_32khz",
|
||||||
|
Which::S44khz => "hubertsiuzdak/snac_44khz",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn model_file(&self) -> &'static str {
|
||||||
|
match self {
|
||||||
|
Which::S24khz => "snac_24khz.safetensors",
|
||||||
|
Which::S32khz => "snac_32khz.safetensors",
|
||||||
|
Which::S44khz => "snac_44khz.safetensors",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Parser, Debug)]
|
||||||
|
#[command(author, version, about, long_about = None)]
|
||||||
|
struct Args {
|
||||||
|
/// The action to be performed, specifies the format for the input and output data.
|
||||||
|
action: Action,
|
||||||
|
|
||||||
|
/// The input file, either an audio file or some snac tokens stored as safetensors.
|
||||||
|
in_file: String,
|
||||||
|
|
||||||
|
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
|
||||||
|
out_file: String,
|
||||||
|
|
||||||
|
/// The model size to use.
|
||||||
|
#[arg(long, default_value = "24khz")]
|
||||||
|
which: Which,
|
||||||
|
|
||||||
|
/// Run on CPU rather than on GPU.
|
||||||
|
#[arg(long)]
|
||||||
|
cpu: bool,
|
||||||
|
|
||||||
|
/// The model weight file, in safetensor format.
|
||||||
|
#[arg(long)]
|
||||||
|
model: Option<String>,
|
||||||
|
|
||||||
|
/// The config file, in safetensor format.
|
||||||
|
#[arg(long)]
|
||||||
|
config: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() -> Result<()> {
|
||||||
|
let args = Args::parse();
|
||||||
|
let device = candle_examples::device(args.cpu)?;
|
||||||
|
let model_sample_rate = args.which.sample_rate();
|
||||||
|
let config = match args.config {
|
||||||
|
Some(c) => std::path::PathBuf::from(c),
|
||||||
|
None => Api::new()?
|
||||||
|
.model(args.which.config_repo().to_string())
|
||||||
|
.get("config.json")?,
|
||||||
|
};
|
||||||
|
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
|
||||||
|
let model = match args.model {
|
||||||
|
Some(model) => std::path::PathBuf::from(model),
|
||||||
|
None => Api::new()?
|
||||||
|
.model("lmz/candle-snac".to_string())
|
||||||
|
.get(args.which.model_file())?,
|
||||||
|
};
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||||
|
let model = Model::new(&config, vb)?;
|
||||||
|
|
||||||
|
let codes = match args.action {
|
||||||
|
Action::CodeToAudio => {
|
||||||
|
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||||
|
let num_codebooks = model.num_codebooks();
|
||||||
|
(0..num_codebooks)
|
||||||
|
.map(|i| {
|
||||||
|
codes
|
||||||
|
.get(&format!("codes-{i}"))
|
||||||
|
.expect("no codes in input file")
|
||||||
|
.clone()
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
}
|
||||||
|
Action::AudioToCode | Action::AudioToAudio => {
|
||||||
|
let pcm = if args.in_file == "-" {
|
||||||
|
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
||||||
|
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
||||||
|
let mut pcms = vec![];
|
||||||
|
let stdin = std::thread::spawn(|| {
|
||||||
|
let mut s = String::new();
|
||||||
|
std::io::stdin().read_line(&mut s)
|
||||||
|
});
|
||||||
|
while !stdin.is_finished() {
|
||||||
|
let input = input_audio.lock().unwrap().take_all();
|
||||||
|
if input.is_empty() {
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
pcms.push(input)
|
||||||
|
}
|
||||||
|
drop(stream);
|
||||||
|
pcms.concat()
|
||||||
|
} else {
|
||||||
|
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
||||||
|
if sample_rate != model_sample_rate {
|
||||||
|
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
|
||||||
|
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
|
||||||
|
} else {
|
||||||
|
pcm
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let pcm_len = pcm.len();
|
||||||
|
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||||
|
println!("input pcm shape: {:?}", pcm.shape());
|
||||||
|
model.encode(&pcm)?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
for codes in codes.iter() {
|
||||||
|
println!("codes shape: {:?}", codes.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
match args.action {
|
||||||
|
Action::AudioToCode => {
|
||||||
|
let mut tensors = std::collections::HashMap::new();
|
||||||
|
for (i, codes) in codes.iter().enumerate() {
|
||||||
|
tensors.insert(format!("codes-{i}"), codes.clone());
|
||||||
|
}
|
||||||
|
candle::safetensors::save(&tensors, "codes.safetensors")?;
|
||||||
|
}
|
||||||
|
Action::AudioToAudio | Action::CodeToAudio => {
|
||||||
|
let codes = codes.iter().collect::<Vec<_>>();
|
||||||
|
let pcm = model.decode(&codes)?;
|
||||||
|
println!("output pcm shape: {:?}", pcm.shape());
|
||||||
|
let pcm = pcm.i(0)?.i(0)?;
|
||||||
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
|
||||||
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
|
if args.out_file == "-" {
|
||||||
|
let (stream, ad) = audio_io::setup_output_stream()?;
|
||||||
|
{
|
||||||
|
let mut ad = ad.lock().unwrap();
|
||||||
|
ad.push_samples(&pcm)?;
|
||||||
|
}
|
||||||
|
loop {
|
||||||
|
let ad = ad.lock().unwrap();
|
||||||
|
if ad.is_empty() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// That's very weird, calling thread::sleep here triggers the stream to stop
|
||||||
|
// playing (the callback doesn't seem to be called anymore).
|
||||||
|
// std::thread::sleep(std::time::Duration::from_millis(100));
|
||||||
|
}
|
||||||
|
drop(stream)
|
||||||
|
} else {
|
||||||
|
let mut output = std::fs::File::create(&args.out_file)?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
@ -504,8 +504,9 @@ impl BertModel {
|
|||||||
Some(attention_mask) => attention_mask.clone(),
|
Some(attention_mask) => attention_mask.clone(),
|
||||||
None => input_ids.ones_like()?,
|
None => input_ids.ones_like()?,
|
||||||
};
|
};
|
||||||
|
let dtype = embedding_output.dtype();
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||||
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
||||||
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
|
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||||
Ok(sequence_output)
|
Ok(sequence_output)
|
||||||
}
|
}
|
||||||
@ -519,8 +520,11 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
|||||||
};
|
};
|
||||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||||
// torch.finfo(dtype).min
|
// torch.finfo(dtype).min
|
||||||
(attention_mask.ones_like()? - &attention_mask)?
|
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
||||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
&Tensor::try_from(f32::MIN)?
|
||||||
|
.to_device(attention_mask.device())?
|
||||||
|
.to_dtype(dtype)?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
||||||
|
@ -514,8 +514,9 @@ impl ChineseClipTextTransformer {
|
|||||||
Some(attention_mask) => attention_mask.clone(),
|
Some(attention_mask) => attention_mask.clone(),
|
||||||
None => input_ids.ones_like()?,
|
None => input_ids.ones_like()?,
|
||||||
};
|
};
|
||||||
|
let dtype = embedding_output.dtype();
|
||||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||||
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
||||||
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||||
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
||||||
let pooled_output = match &self.pooler {
|
let pooled_output = match &self.pooler {
|
||||||
@ -535,6 +536,9 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
|||||||
};
|
};
|
||||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||||
// torch.finfo(dtype).min
|
// torch.finfo(dtype).min
|
||||||
(attention_mask.ones_like()? - &attention_mask)?
|
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
||||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
&Tensor::try_from(f32::MIN)?
|
||||||
|
.to_device(attention_mask.device())?
|
||||||
|
.to_dtype(dtype)?,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
@ -498,4 +498,36 @@ impl Model {
|
|||||||
}
|
}
|
||||||
Ok(all_samples)
|
Ok(all_samples)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn audio_tokens_and_mask(&self, mut frame: Vec<u32>) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut mask = vec![1u8; cb];
|
||||||
|
mask.push(0);
|
||||||
|
let mask = Tensor::from_vec(mask, (1, 1, cb + 1), device)?;
|
||||||
|
|
||||||
|
frame.push(0);
|
||||||
|
let tokens = Tensor::from_vec(frame, (1, 1, cb + 1), device)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn text_tokens_and_mask(&self, ids: &[u32]) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut tokens = vec![];
|
||||||
|
let mut mask = vec![];
|
||||||
|
for &v in ids.iter() {
|
||||||
|
let mut token = vec![0; cb];
|
||||||
|
token.push(v);
|
||||||
|
let token = Tensor::from_vec(token, (1, 1, cb + 1), device)?;
|
||||||
|
tokens.push(token);
|
||||||
|
let mut m = vec![0u8; cb];
|
||||||
|
m.push(1);
|
||||||
|
let m = Tensor::from_vec(m, (1, 1, cb + 1), device)?;
|
||||||
|
mask.push(m);
|
||||||
|
}
|
||||||
|
let tokens = Tensor::cat(&tokens, 1)?;
|
||||||
|
let mask = Tensor::cat(&mask, 1)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ impl EncoderBlock {
|
|||||||
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
||||||
let cfg1 = Conv1dConfig {
|
let cfg1 = Conv1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
||||||
@ -196,7 +196,7 @@ impl DecoderBlock {
|
|||||||
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
||||||
let cfg = ConvTranspose1dConfig {
|
let cfg = ConvTranspose1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
||||||
@ -330,6 +330,7 @@ impl ResidualVectorQuantizer {
|
|||||||
Ok(Self { quantizers })
|
Ok(Self { quantizers })
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::wrong_self_convention)]
|
||||||
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
||||||
let mut sum = None;
|
let mut sum = None;
|
||||||
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
||||||
|
@ -19,7 +19,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
enum HiddenAct {
|
pub enum HiddenAct {
|
||||||
Gelu,
|
Gelu,
|
||||||
Relu,
|
Relu,
|
||||||
}
|
}
|
||||||
@ -49,22 +49,22 @@ impl Module for HiddenActLayer {
|
|||||||
|
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||||
#[serde(rename_all = "lowercase")]
|
#[serde(rename_all = "lowercase")]
|
||||||
enum PositionEmbeddingType {
|
pub enum PositionEmbeddingType {
|
||||||
#[default]
|
#[default]
|
||||||
Absolute,
|
Absolute,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||||
pub struct Config {
|
pub struct Config {
|
||||||
vocab_size: usize,
|
pub vocab_size: usize,
|
||||||
dim: usize,
|
pub dim: usize,
|
||||||
n_layers: usize,
|
n_layers: usize,
|
||||||
n_heads: usize,
|
n_heads: usize,
|
||||||
hidden_dim: usize,
|
hidden_dim: usize,
|
||||||
activation: HiddenAct,
|
activation: HiddenAct,
|
||||||
max_position_embeddings: usize,
|
max_position_embeddings: usize,
|
||||||
initializer_range: f64,
|
initializer_range: f64,
|
||||||
pad_token_id: usize,
|
pub pad_token_id: usize,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
position_embedding_type: PositionEmbeddingType,
|
position_embedding_type: PositionEmbeddingType,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@ -345,3 +345,107 @@ impl DistilBertModel {
|
|||||||
Ok(sequence_output)
|
Ok(sequence_output)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct DistilBertPredictionHeadTransform {
|
||||||
|
dense: Linear,
|
||||||
|
activation: HiddenActLayer,
|
||||||
|
layer_norm: LayerNorm,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertPredictionHeadTransform {
|
||||||
|
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let dense = linear(config.dim, config.dim, vb.pp("vocab_transform"))?;
|
||||||
|
let activation = HiddenActLayer::new(config.activation);
|
||||||
|
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?;
|
||||||
|
Ok(Self {
|
||||||
|
dense,
|
||||||
|
activation,
|
||||||
|
layer_norm,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DistilBertPredictionHeadTransform {
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
let hidden_states = self
|
||||||
|
.activation
|
||||||
|
.forward(&self.dense.forward(hidden_states)?)?;
|
||||||
|
self.layer_norm.forward(&hidden_states)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
|
||||||
|
pub struct DistilBertLMPredictionHead {
|
||||||
|
transform: DistilBertPredictionHeadTransform,
|
||||||
|
decoder: Linear,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertLMPredictionHead {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;
|
||||||
|
|
||||||
|
// distil_bert_uncased uses the word embeddings for the vocab projector weight, but has a seperate vocab_projector bias
|
||||||
|
let vocab_projector_weight_vb = vb.pp("distilbert.embeddings.word_embeddings");
|
||||||
|
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
|
||||||
|
let ws = vocab_projector_weight_vb.get_with_hints(
|
||||||
|
(config.vocab_size, config.dim),
|
||||||
|
"weight",
|
||||||
|
init_ws,
|
||||||
|
)?;
|
||||||
|
let bound = 1. / (config.dim as f64).sqrt();
|
||||||
|
let init_bs = candle_nn::Init::Uniform {
|
||||||
|
lo: -bound,
|
||||||
|
up: bound,
|
||||||
|
};
|
||||||
|
|
||||||
|
let vocab_projector_bias_vb = vb.pp("vocab_projector");
|
||||||
|
let bs = vocab_projector_bias_vb.get_with_hints(config.vocab_size, "bias", init_bs)?;
|
||||||
|
|
||||||
|
let decoder = Linear::from_weights(ws, Some(bs));
|
||||||
|
|
||||||
|
Ok(Self { transform, decoder })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DistilBertLMPredictionHead {
|
||||||
|
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||||
|
self.decoder
|
||||||
|
.forward(&self.transform.forward(hidden_states)?)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
|
||||||
|
pub struct DistilBertOnlyMLMHead {
|
||||||
|
predictions: DistilBertLMPredictionHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertOnlyMLMHead {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let predictions = DistilBertLMPredictionHead::load(vb.clone(), config)?;
|
||||||
|
Ok(Self { predictions })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DistilBertOnlyMLMHead {
|
||||||
|
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
|
||||||
|
self.predictions.forward(sequence_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct DistilBertForMaskedLM {
|
||||||
|
pub bert: DistilBertModel,
|
||||||
|
cls: DistilBertOnlyMLMHead,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DistilBertForMaskedLM {
|
||||||
|
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||||
|
let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
|
||||||
|
let cls = DistilBertOnlyMLMHead::load(vb.clone(), config)?;
|
||||||
|
Ok(Self { bert, cls })
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||||
|
let sequence_output = self.bert.forward(input_ids, attention_mask)?;
|
||||||
|
self.cls.forward(&sequence_output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -141,6 +141,20 @@ pub fn conv1d_weight_norm(
|
|||||||
Ok(Conv1d::new(weight, Some(bias), config))
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn conv1d_weight_norm_no_bias(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: candle_nn::Conv1dConfig,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Conv1d> {
|
||||||
|
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||||
|
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||||
|
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||||
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
|
Ok(Conv1d::new(weight, None, config))
|
||||||
|
}
|
||||||
|
|
||||||
pub fn conv_transpose1d_weight_norm(
|
pub fn conv_transpose1d_weight_norm(
|
||||||
in_c: usize,
|
in_c: usize,
|
||||||
out_c: usize,
|
out_c: usize,
|
||||||
|
@ -6,8 +6,8 @@ pub fn get_noise(
|
|||||||
width: usize,
|
width: usize,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let height = (height + 15) / 16 * 2;
|
let height = height.div_ceil(16) * 2;
|
||||||
let width = (width + 15) / 16 * 2;
|
let width = width.div_ceil(16) * 2;
|
||||||
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
|
|||||||
|
|
||||||
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||||
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
||||||
let height = (height + 15) / 16;
|
let height = height.div_ceil(16);
|
||||||
let width = (width + 15) / 16;
|
let width = width.div_ceil(16);
|
||||||
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
||||||
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
||||||
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
||||||
|
@ -27,7 +27,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_inner(&self) -> usize {
|
fn d_inner(&self) -> usize {
|
||||||
|
@ -716,7 +716,7 @@ pub mod transformer {
|
|||||||
None => {
|
None => {
|
||||||
let hidden_dim = self.dim * 4;
|
let hidden_dim = self.dim * 4;
|
||||||
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
||||||
(n_hidden + 255) / 256 * 256
|
n_hidden.div_ceil(256) * 256
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,6 +104,7 @@ pub mod rwkv_v6;
|
|||||||
pub mod segformer;
|
pub mod segformer;
|
||||||
pub mod segment_anything;
|
pub mod segment_anything;
|
||||||
pub mod siglip;
|
pub mod siglip;
|
||||||
|
pub mod snac;
|
||||||
pub mod stable_diffusion;
|
pub mod stable_diffusion;
|
||||||
pub mod stable_lm;
|
pub mod stable_lm;
|
||||||
pub mod starcoder2;
|
pub mod starcoder2;
|
||||||
|
814
candle-transformers/src/models/snac.rs
Normal file
814
candle-transformers/src/models/snac.rs
Normal file
@ -0,0 +1,814 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
//! Implementation of the Multi-Scale Neural Audio Codec (SNAC)
|
||||||
|
//!
|
||||||
|
//! See: [SNAC](https://github.com/hubertsiuzdak/snac)
|
||||||
|
//!
|
||||||
|
/// Multi-Scale Neural Audio Codec (SNAC) compresses audio into discrete codes at a low bitrate.
|
||||||
|
/// For more information, read the paper: https://arxiv.org/abs/2410.14411
|
||||||
|
///
|
||||||
|
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||||
|
use candle_nn::{
|
||||||
|
linear_b, Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, LayerNorm, Linear,
|
||||||
|
VarBuilder,
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug, Clone)]
|
||||||
|
pub struct Config {
|
||||||
|
pub sampling_rate: usize,
|
||||||
|
pub encoder_dim: usize,
|
||||||
|
pub encoder_rates: Vec<usize>,
|
||||||
|
pub decoder_dim: usize,
|
||||||
|
pub decoder_rates: Vec<usize>,
|
||||||
|
pub attn_window_size: Option<usize>,
|
||||||
|
pub codebook_size: usize,
|
||||||
|
pub codebook_dim: usize,
|
||||||
|
pub vq_strides: Vec<usize>,
|
||||||
|
pub noise: bool,
|
||||||
|
pub depthwise: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equivalent to torch.repeat_interleave
|
||||||
|
pub fn repeat_interleave<D: candle::shape::Dim>(
|
||||||
|
img: &Tensor,
|
||||||
|
repeats: usize,
|
||||||
|
dim: D,
|
||||||
|
) -> Result<Tensor> {
|
||||||
|
if repeats == 1 {
|
||||||
|
return Ok(img.clone());
|
||||||
|
}
|
||||||
|
let dim = dim.to_index(img.shape(), "chunk")?;
|
||||||
|
let img = img.unsqueeze(dim + 1)?;
|
||||||
|
let mut dims = img.dims().to_vec();
|
||||||
|
dims[dim + 1] = repeats;
|
||||||
|
img.broadcast_as(dims)?.flatten(dim, dim + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn conv1d_weight_norm(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: candle_nn::Conv1dConfig,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Conv1d> {
|
||||||
|
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
|
||||||
|
let weight_v = {
|
||||||
|
let name = "parametrizations.weight.original1";
|
||||||
|
match vb.get((out_c, in_c, kernel_size), name) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||||
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
|
let bias = vb.get(out_c, "bias")?;
|
||||||
|
Ok(Conv1d::new(weight, Some(bias), config))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn conv1d_weight_norm_no_bias(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
config: candle_nn::Conv1dConfig,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Conv1d> {
|
||||||
|
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
|
||||||
|
let weight_v = {
|
||||||
|
let name = "parametrizations.weight.original1";
|
||||||
|
match vb.get((out_c, in_c, kernel_size), name) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||||
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
|
Ok(Conv1d::new(weight, None, config))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn conv_transpose1d_weight_norm(
|
||||||
|
in_c: usize,
|
||||||
|
out_c: usize,
|
||||||
|
kernel_size: usize,
|
||||||
|
bias: bool,
|
||||||
|
config: candle_nn::ConvTranspose1dConfig,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<ConvTranspose1d> {
|
||||||
|
let weight_g = vb.get((in_c, 1, 1), "parametrizations.weight.original0")?;
|
||||||
|
let weight_v = vb.get(
|
||||||
|
(in_c, out_c, kernel_size),
|
||||||
|
"parametrizations.weight.original1",
|
||||||
|
)?;
|
||||||
|
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||||
|
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||||
|
let bias = if bias {
|
||||||
|
Some(vb.get(out_c, "bias")?)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(ConvTranspose1d::new(weight, bias, config))
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/hubertsiuzdak/snac/blob/main/snac/attention.py
|
||||||
|
#[allow(unused)]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct SinusoidalEmbeddings {
|
||||||
|
inv_freq: Tensor,
|
||||||
|
scale: Tensor,
|
||||||
|
scale_base: f32,
|
||||||
|
use_xpos: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SinusoidalEmbeddings {
|
||||||
|
fn new(dim: usize, scale_base: f32, use_xpos: bool, dev: &Device) -> Result<Self> {
|
||||||
|
let inv_freq: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| 1f32 / 10_000f32.powf(i as f32 / dim as f32))
|
||||||
|
.collect();
|
||||||
|
let len = inv_freq.len();
|
||||||
|
let inv_freq = Tensor::from_vec(inv_freq, len, dev)?.to_dtype(DType::F32)?;
|
||||||
|
let scale: Vec<_> = (0..dim)
|
||||||
|
.step_by(2)
|
||||||
|
.map(|i| (i as f32 + 0.4 * dim as f32) / (1.4 * dim as f32))
|
||||||
|
.collect();
|
||||||
|
let scale = Tensor::from_vec(scale, len, dev)?.to_dtype(DType::F32)?;
|
||||||
|
Ok(Self {
|
||||||
|
inv_freq,
|
||||||
|
scale,
|
||||||
|
scale_base,
|
||||||
|
use_xpos,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(unused)]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct LocalMHA {
|
||||||
|
norm: LayerNorm,
|
||||||
|
to_qkv: Linear,
|
||||||
|
to_out: Linear,
|
||||||
|
num_heads: usize,
|
||||||
|
head_dim: usize,
|
||||||
|
rel_pos: Option<SinusoidalEmbeddings>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LocalMHA {
|
||||||
|
fn new(
|
||||||
|
dim: usize,
|
||||||
|
window_size: usize,
|
||||||
|
dim_head: usize,
|
||||||
|
use_rotary_pos_emb: bool,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
|
||||||
|
let to_qkv = linear_b(dim, dim * 3, false, vb.pp("to_qkv"))?;
|
||||||
|
let to_out = linear_b(dim, dim, false, vb.pp("to_out"))?;
|
||||||
|
let rel_pos = if use_rotary_pos_emb {
|
||||||
|
let rel_pos =
|
||||||
|
SinusoidalEmbeddings::new(dim_head, window_size as f32 / 2.0, false, vb.device())?;
|
||||||
|
Some(rel_pos)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
Ok(Self {
|
||||||
|
norm,
|
||||||
|
to_qkv,
|
||||||
|
to_out,
|
||||||
|
rel_pos,
|
||||||
|
num_heads: dim / dim_head,
|
||||||
|
head_dim: dim_head,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for LocalMHA {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b, c, t) = xs.dims3()?;
|
||||||
|
let residual = xs.clone();
|
||||||
|
let xs = xs.transpose(1, 2)?.apply(&self.norm)?;
|
||||||
|
let qkv = xs.apply(&self.to_qkv)?;
|
||||||
|
let q = qkv.narrow(D::Minus1, 0, c)?;
|
||||||
|
let k = qkv.narrow(D::Minus1, c, c)?;
|
||||||
|
let v = qkv.narrow(D::Minus1, 2 * c, c)?;
|
||||||
|
let q = q
|
||||||
|
.reshape((b, t, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let k = k
|
||||||
|
.reshape((b, t, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let v = v
|
||||||
|
.reshape((b, t, self.num_heads, self.head_dim))?
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.contiguous()?;
|
||||||
|
let (q, k) = match self.rel_pos {
|
||||||
|
Some(_) => todo!(),
|
||||||
|
None => (q, k),
|
||||||
|
};
|
||||||
|
let out = {
|
||||||
|
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||||
|
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||||
|
// Non-causal attention
|
||||||
|
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||||
|
attn_weights.matmul(&v)?
|
||||||
|
};
|
||||||
|
let out = out
|
||||||
|
.transpose(1, 2)?
|
||||||
|
.reshape((b, t, self.num_heads * self.head_dim))?
|
||||||
|
.apply(&self.to_out)?;
|
||||||
|
out.transpose(1, 2)? + residual
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct Snake1d {
|
||||||
|
alpha: Tensor,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Snake1d {
|
||||||
|
pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let alpha = vb.get((1, channels, 1), "alpha")?;
|
||||||
|
Ok(Self { alpha })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for Snake1d {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let xs_shape = xs.shape();
|
||||||
|
let xs = xs.flatten_from(2)?;
|
||||||
|
let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
|
||||||
|
let sin = (&sin * &sin)?;
|
||||||
|
(xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct ResidualUnit {
|
||||||
|
snake1: Snake1d,
|
||||||
|
conv1: Conv1d,
|
||||||
|
snake2: Snake1d,
|
||||||
|
conv2: Conv1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResidualUnit {
|
||||||
|
fn new(
|
||||||
|
dim: usize,
|
||||||
|
dilation: usize,
|
||||||
|
kernel: usize,
|
||||||
|
groups: usize,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let pad = ((kernel - 1) * dilation) / 2;
|
||||||
|
let vb = vb.pp("block");
|
||||||
|
let snake1 = Snake1d::new(dim, vb.pp(0))?;
|
||||||
|
let cfg1 = Conv1dConfig {
|
||||||
|
dilation,
|
||||||
|
padding: pad,
|
||||||
|
groups,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let conv1 = conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
|
||||||
|
let snake2 = Snake1d::new(dim, vb.pp(2))?;
|
||||||
|
let conv2 = conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
|
||||||
|
Ok(Self {
|
||||||
|
snake1,
|
||||||
|
conv1,
|
||||||
|
snake2,
|
||||||
|
conv2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for ResidualUnit {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let ys = xs
|
||||||
|
.apply(&self.snake1)?
|
||||||
|
.apply(&self.conv1)?
|
||||||
|
.apply(&self.snake2)?
|
||||||
|
.apply(&self.conv2)?;
|
||||||
|
let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
|
||||||
|
if pad > 0 {
|
||||||
|
&ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
|
||||||
|
} else {
|
||||||
|
ys + xs
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct NoiseBlock {
|
||||||
|
linear: Conv1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl NoiseBlock {
|
||||||
|
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let linear = conv1d_weight_norm_no_bias(dim, dim, 1, Default::default(), vb.pp("linear"))?;
|
||||||
|
Ok(Self { linear })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for NoiseBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let (b, _c, t) = xs.dims3()?;
|
||||||
|
let noise = Tensor::randn(0f32, 1f32, (b, 1, t), xs.device())?;
|
||||||
|
let h = xs.apply(&self.linear)?;
|
||||||
|
let n = noise.broadcast_mul(&h)?;
|
||||||
|
let xs = (xs + n)?;
|
||||||
|
Ok(xs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct DecoderBlock {
|
||||||
|
snake1: Snake1d,
|
||||||
|
conv_tr1: ConvTranspose1d,
|
||||||
|
noise: Option<NoiseBlock>,
|
||||||
|
res1: ResidualUnit,
|
||||||
|
res2: ResidualUnit,
|
||||||
|
res3: ResidualUnit,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DecoderBlock {
|
||||||
|
fn new(
|
||||||
|
in_dim: usize,
|
||||||
|
out_dim: usize,
|
||||||
|
stride: usize,
|
||||||
|
noise: bool,
|
||||||
|
groups: usize,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let vb = vb.pp("block");
|
||||||
|
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
||||||
|
let cfg = ConvTranspose1dConfig {
|
||||||
|
stride,
|
||||||
|
padding: stride.div_ceil(2),
|
||||||
|
output_padding: stride % 2,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let conv_tr1 =
|
||||||
|
conv_transpose1d_weight_norm(in_dim, out_dim, 2 * stride, true, cfg, vb.pp(1))?;
|
||||||
|
let (n, noise) = if noise {
|
||||||
|
let noise = NoiseBlock::new(out_dim, vb.pp(2))?;
|
||||||
|
(1, Some(noise))
|
||||||
|
} else {
|
||||||
|
(0, None)
|
||||||
|
};
|
||||||
|
let res1 = ResidualUnit::new(out_dim, 1, 7, groups, vb.pp(2 + n))?;
|
||||||
|
let res2 = ResidualUnit::new(out_dim, 3, 7, groups, vb.pp(3 + n))?;
|
||||||
|
let res3 = ResidualUnit::new(out_dim, 9, 7, groups, vb.pp(4 + n))?;
|
||||||
|
Ok(Self {
|
||||||
|
snake1,
|
||||||
|
conv_tr1,
|
||||||
|
noise,
|
||||||
|
res1,
|
||||||
|
res2,
|
||||||
|
res3,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Module for DecoderBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.snake1)?
|
||||||
|
.apply(&self.conv_tr1)?
|
||||||
|
.apply(&self.noise.as_ref())?
|
||||||
|
.apply(&self.res1)?
|
||||||
|
.apply(&self.res2)?
|
||||||
|
.apply(&self.res3)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct EncoderBlock {
|
||||||
|
res1: ResidualUnit,
|
||||||
|
res2: ResidualUnit,
|
||||||
|
res3: ResidualUnit,
|
||||||
|
snake1: Snake1d,
|
||||||
|
conv1: Conv1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EncoderBlock {
|
||||||
|
fn new(
|
||||||
|
out_dim: usize,
|
||||||
|
in_dim: Option<usize>,
|
||||||
|
stride: usize,
|
||||||
|
groups: usize,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let vb = vb.pp("block");
|
||||||
|
let in_dim = in_dim.unwrap_or(out_dim / 2);
|
||||||
|
let res1 = ResidualUnit::new(in_dim, 1, 7, groups, vb.pp(0))?;
|
||||||
|
let res2 = ResidualUnit::new(in_dim, 3, 7, groups, vb.pp(1))?;
|
||||||
|
let res3 = ResidualUnit::new(in_dim, 9, 7, groups, vb.pp(2))?;
|
||||||
|
let snake1 = Snake1d::new(in_dim, vb.pp(3))?;
|
||||||
|
let cfg1 = Conv1dConfig {
|
||||||
|
stride,
|
||||||
|
padding: stride.div_ceil(2),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let conv1 = conv1d_weight_norm(in_dim, out_dim, 2 * stride, cfg1, vb.pp(4))?;
|
||||||
|
Ok(Self {
|
||||||
|
res1,
|
||||||
|
res2,
|
||||||
|
res3,
|
||||||
|
snake1,
|
||||||
|
conv1,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl candle::Module for EncoderBlock {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
xs.apply(&self.res1)?
|
||||||
|
.apply(&self.res2)?
|
||||||
|
.apply(&self.res3)?
|
||||||
|
.apply(&self.snake1)?
|
||||||
|
.apply(&self.conv1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Encoder {
|
||||||
|
conv1: Conv1d,
|
||||||
|
blocks: Vec<EncoderBlock>,
|
||||||
|
local_mha: Option<LocalMHA>,
|
||||||
|
conv2: Conv1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl candle::Module for Encoder {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = xs.apply(&self.conv1)?;
|
||||||
|
for block in self.blocks.iter() {
|
||||||
|
xs = xs.apply(block)?
|
||||||
|
}
|
||||||
|
xs.apply(&self.conv2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Encoder {
|
||||||
|
fn new(
|
||||||
|
mut d_model: usize,
|
||||||
|
strides: &[usize],
|
||||||
|
depthwise: bool,
|
||||||
|
attn_window_size: Option<usize>,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let vb = vb.pp("block");
|
||||||
|
let mut idx = 0;
|
||||||
|
let cfg1 = Conv1dConfig {
|
||||||
|
padding: 3,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let conv1 = conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
let mut blocks = Vec::with_capacity(strides.len());
|
||||||
|
for &stride in strides.iter() {
|
||||||
|
d_model *= 2;
|
||||||
|
let groups = if depthwise { d_model / 2 } else { 1 };
|
||||||
|
let block = EncoderBlock::new(d_model, None, stride, groups, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
blocks.push(block)
|
||||||
|
}
|
||||||
|
let local_mha = match attn_window_size {
|
||||||
|
Some(w) => {
|
||||||
|
let mha = LocalMHA::new(d_model, w, 64, true, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
Some(mha)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
let groups = if depthwise { d_model } else { 1 };
|
||||||
|
let cfg2 = Conv1dConfig {
|
||||||
|
padding: 3,
|
||||||
|
groups,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let conv2 = conv1d_weight_norm(d_model, d_model, 7, cfg2, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
Ok(Self {
|
||||||
|
conv1,
|
||||||
|
blocks,
|
||||||
|
local_mha,
|
||||||
|
conv2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
enum ConvInit {
|
||||||
|
Depthwise(Conv1d, Conv1d),
|
||||||
|
Standard(Conv1d),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Decoder {
|
||||||
|
conv1: ConvInit,
|
||||||
|
local_mha: Option<LocalMHA>,
|
||||||
|
blocks: Vec<DecoderBlock>,
|
||||||
|
snake1: Snake1d,
|
||||||
|
conv2: Conv1d,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Decoder {
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn new(
|
||||||
|
in_c: usize,
|
||||||
|
mut channels: usize,
|
||||||
|
rates: &[usize],
|
||||||
|
noise: bool,
|
||||||
|
depthwise: bool,
|
||||||
|
attn_window_size: Option<usize>,
|
||||||
|
d_out: usize,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let vb = vb.pp("model");
|
||||||
|
let mut idx = 0;
|
||||||
|
let pad3 = Conv1dConfig {
|
||||||
|
padding: 3,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let conv1 = if depthwise {
|
||||||
|
let cfg1 = Conv1dConfig {
|
||||||
|
padding: 3,
|
||||||
|
groups: in_c,
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
let conv1 = conv1d_weight_norm(in_c, in_c, 7, cfg1, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
let conv2 = conv1d_weight_norm(in_c, channels, 1, Default::default(), vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
ConvInit::Depthwise(conv1, conv2)
|
||||||
|
} else {
|
||||||
|
let conv1 = conv1d_weight_norm(in_c, channels, 7, pad3, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
ConvInit::Standard(conv1)
|
||||||
|
};
|
||||||
|
let mut blocks = Vec::with_capacity(rates.len());
|
||||||
|
let local_mha = match attn_window_size {
|
||||||
|
Some(w) => {
|
||||||
|
let mha = LocalMHA::new(channels, w, 64, true, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
Some(mha)
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
for stride in rates.iter() {
|
||||||
|
let groups = if depthwise { channels / 2 } else { 1 };
|
||||||
|
let block =
|
||||||
|
DecoderBlock::new(channels, channels / 2, *stride, noise, groups, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
channels /= 2;
|
||||||
|
blocks.push(block)
|
||||||
|
}
|
||||||
|
let snake1 = Snake1d::new(channels, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
let conv2 = conv1d_weight_norm(channels, d_out, 7, pad3, vb.pp(idx))?;
|
||||||
|
idx += 1;
|
||||||
|
Ok(Self {
|
||||||
|
conv1,
|
||||||
|
local_mha,
|
||||||
|
blocks,
|
||||||
|
snake1,
|
||||||
|
conv2,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl candle::Module for Decoder {
|
||||||
|
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||||
|
let mut xs = match &self.conv1 {
|
||||||
|
ConvInit::Standard(c) => xs.apply(c)?,
|
||||||
|
ConvInit::Depthwise(c1, c2) => xs.apply(c1)?.apply(c2)?,
|
||||||
|
};
|
||||||
|
for block in self.blocks.iter() {
|
||||||
|
xs = xs.apply(block)?
|
||||||
|
}
|
||||||
|
xs.apply(&self.snake1)?.apply(&self.conv2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize(v: &Tensor) -> Result<Tensor> {
|
||||||
|
v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/hubertsiuzdak/snac/blob/main/snac/vq.py
|
||||||
|
#[allow(unused)]
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
struct VectorQuantizer {
|
||||||
|
in_proj: Conv1d,
|
||||||
|
out_proj: Conv1d,
|
||||||
|
codebook: candle_nn::Embedding,
|
||||||
|
stride: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl VectorQuantizer {
|
||||||
|
fn new(
|
||||||
|
in_dim: usize,
|
||||||
|
cb_size: usize,
|
||||||
|
cb_dim: usize,
|
||||||
|
stride: usize,
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let in_proj = conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
|
||||||
|
let out_proj =
|
||||||
|
conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
|
||||||
|
let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
|
||||||
|
Ok(Self {
|
||||||
|
in_proj,
|
||||||
|
out_proj,
|
||||||
|
codebook,
|
||||||
|
stride,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_latents(&self, latents: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||||
|
let (b, d, t) = latents.dims3()?;
|
||||||
|
let encodings = latents.transpose(1, 2)?.reshape((b * t, d))?;
|
||||||
|
let encodings = normalize(&encodings)?;
|
||||||
|
let codebook = normalize(self.codebook.embeddings())?;
|
||||||
|
let dist = (encodings
|
||||||
|
.sqr()?
|
||||||
|
.sum_keepdim(1)?
|
||||||
|
.broadcast_sub(&encodings.matmul(&codebook.t()?)?)?
|
||||||
|
* 2.0)?
|
||||||
|
.broadcast_add(&codebook.sqr()?.sum_keepdim(1)?.t()?)?;
|
||||||
|
let indices = dist.argmin(1)?.reshape((b, ()))?;
|
||||||
|
let z_q = self.decode_code(&indices)?;
|
||||||
|
Ok((z_q, indices))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self, z: &Tensor) -> Result<(Tensor, Tensor)> {
|
||||||
|
let z = if self.stride > 1 {
|
||||||
|
let (b, c, t) = z.dims3()?;
|
||||||
|
z.reshape((b, c, 1, t))?
|
||||||
|
.avg_pool2d((1, self.stride))?
|
||||||
|
.squeeze(2)?
|
||||||
|
} else {
|
||||||
|
z.clone()
|
||||||
|
};
|
||||||
|
let z_e = z.apply(&self.in_proj)?;
|
||||||
|
let (z_q, indices) = self.decode_latents(&z_e)?;
|
||||||
|
let z_q = z_q.apply(&self.out_proj)?;
|
||||||
|
let z_q = if self.stride > 1 {
|
||||||
|
repeat_interleave(&z_q, self.stride, D::Minus1)?
|
||||||
|
} else {
|
||||||
|
z_q
|
||||||
|
};
|
||||||
|
Ok((z_q, indices))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
|
||||||
|
embed_id.apply(&self.codebook)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
|
||||||
|
self.embed_code(embed_id)?.transpose(1, 2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct ResidualVectorQuantizer {
|
||||||
|
quantizers: Vec<VectorQuantizer>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ResidualVectorQuantizer {
|
||||||
|
fn new(
|
||||||
|
input_dim: usize,
|
||||||
|
cb_size: usize,
|
||||||
|
cb_dim: usize,
|
||||||
|
vq_strides: &[usize],
|
||||||
|
vb: VarBuilder,
|
||||||
|
) -> Result<Self> {
|
||||||
|
let vb = &vb.pp("quantizers");
|
||||||
|
let quantizers = vq_strides
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, stride)| VectorQuantizer::new(input_dim, cb_size, cb_dim, *stride, vb.pp(i)))
|
||||||
|
.collect::<Result<Vec<_>>>()?;
|
||||||
|
Ok(Self { quantizers })
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encode(&self, z: &Tensor) -> Result<(Tensor, Vec<Tensor>)> {
|
||||||
|
let mut residual = z.clone();
|
||||||
|
let mut z_q = z.zeros_like()?;
|
||||||
|
let mut codes = Vec::with_capacity(self.quantizers.len());
|
||||||
|
for quantizer in self.quantizers.iter() {
|
||||||
|
let (z_q_i, indices_i) = quantizer.encode(&residual)?;
|
||||||
|
z_q = (z_q + &z_q_i)?;
|
||||||
|
residual = (residual - &z_q_i)?;
|
||||||
|
codes.push(indices_i)
|
||||||
|
}
|
||||||
|
Ok((z_q, codes))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::wrong_self_convention)]
|
||||||
|
fn from_codes(&self, codes: &[&Tensor]) -> Result<Tensor> {
|
||||||
|
let mut sum = None;
|
||||||
|
for (quantizer, codes) in self.quantizers.iter().zip(codes.iter()) {
|
||||||
|
let z_p_i = quantizer.decode_code(codes)?;
|
||||||
|
let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
|
||||||
|
let z_q_i = repeat_interleave(&z_q_i, quantizer.stride, D::Minus1)?;
|
||||||
|
let s = match sum {
|
||||||
|
None => z_q_i,
|
||||||
|
Some(s) => (s + z_q_i)?,
|
||||||
|
};
|
||||||
|
sum = Some(s)
|
||||||
|
}
|
||||||
|
match sum {
|
||||||
|
Some(s) => Ok(s),
|
||||||
|
None => candle::bail!("empty codebooks"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gcd(mut a: usize, mut b: usize) -> usize {
|
||||||
|
while b != 0 {
|
||||||
|
let t = b;
|
||||||
|
b = a % b;
|
||||||
|
a = t;
|
||||||
|
}
|
||||||
|
a
|
||||||
|
}
|
||||||
|
|
||||||
|
fn lcm(a: usize, b: usize) -> usize {
|
||||||
|
a / gcd(a, b) * b
|
||||||
|
}
|
||||||
|
|
||||||
|
// https://github.com/hubertsiuzdak/snac/blob/main/snac/snac.py
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct Model {
|
||||||
|
pub encoder: Encoder,
|
||||||
|
pub quantizer: ResidualVectorQuantizer,
|
||||||
|
pub decoder: Decoder,
|
||||||
|
pub hop_length: usize,
|
||||||
|
pub config: Config,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Model {
|
||||||
|
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
|
||||||
|
let encoder = Encoder::new(
|
||||||
|
cfg.encoder_dim,
|
||||||
|
&cfg.encoder_rates,
|
||||||
|
cfg.depthwise,
|
||||||
|
cfg.attn_window_size,
|
||||||
|
vb.pp("encoder"),
|
||||||
|
)?;
|
||||||
|
let latent_dim = cfg.encoder_dim * 2usize.pow(cfg.encoder_rates.len() as u32);
|
||||||
|
let quantizer = ResidualVectorQuantizer::new(
|
||||||
|
latent_dim,
|
||||||
|
cfg.codebook_size,
|
||||||
|
cfg.codebook_dim,
|
||||||
|
&cfg.vq_strides,
|
||||||
|
vb.pp("quantizer"),
|
||||||
|
)?;
|
||||||
|
let decoder = Decoder::new(
|
||||||
|
latent_dim,
|
||||||
|
cfg.decoder_dim,
|
||||||
|
&cfg.decoder_rates,
|
||||||
|
cfg.noise,
|
||||||
|
cfg.depthwise,
|
||||||
|
cfg.attn_window_size,
|
||||||
|
/* d_out */ 1,
|
||||||
|
vb.pp("decoder"),
|
||||||
|
)?;
|
||||||
|
let hop_length = cfg.encoder_rates.iter().product::<usize>();
|
||||||
|
Ok(Self {
|
||||||
|
encoder,
|
||||||
|
decoder,
|
||||||
|
quantizer,
|
||||||
|
config: cfg.clone(),
|
||||||
|
hop_length,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn preprocess(&self, audio_data: &Tensor) -> Result<Tensor> {
|
||||||
|
let len = audio_data.dim(D::Minus1)?;
|
||||||
|
let lcm = lcm(
|
||||||
|
self.config.vq_strides[0],
|
||||||
|
self.config.attn_window_size.unwrap_or(1),
|
||||||
|
);
|
||||||
|
let pad_to = self.hop_length * lcm;
|
||||||
|
let right_pad = len.div_ceil(pad_to) * pad_to - len;
|
||||||
|
let audio_data = audio_data.pad_with_zeros(D::Minus1, 0, right_pad)?;
|
||||||
|
Ok(audio_data)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn encode(&self, audio_data: &Tensor) -> Result<Vec<Tensor>> {
|
||||||
|
let audio_data = self.preprocess(audio_data)?;
|
||||||
|
let z = self.encoder.forward(&audio_data)?;
|
||||||
|
let (_, codes) = self.quantizer.encode(&z)?;
|
||||||
|
Ok(codes)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn decode(&self, audio_codes: &[&Tensor]) -> Result<Tensor> {
|
||||||
|
let audio_values = self.quantizer.from_codes(audio_codes)?;
|
||||||
|
audio_values.apply(&self.decoder)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn config(&self) -> &Config {
|
||||||
|
&self.config
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn num_codebooks(&self) -> usize {
|
||||||
|
self.quantizer.quantizers.len()
|
||||||
|
}
|
||||||
|
}
|
@ -198,7 +198,7 @@ pub fn log_mel_spectrogram_<T: Float>(
|
|||||||
let samples = {
|
let samples = {
|
||||||
let mut samples_padded = samples.to_vec();
|
let mut samples_padded = samples.to_vec();
|
||||||
let to_add = n_len * fft_step - samples.len();
|
let to_add = n_len * fft_step - samples.len();
|
||||||
samples_padded.extend(std::iter::repeat(zero).take(to_add));
|
samples_padded.extend(std::iter::repeat_n(zero, to_add));
|
||||||
samples_padded
|
samples_padded
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
|||||||
let samples = {
|
let samples = {
|
||||||
let mut samples_padded = samples.to_vec();
|
let mut samples_padded = samples.to_vec();
|
||||||
let to_add = n_len * fft_step - samples.len();
|
let to_add = n_len * fft_step - samples.len();
|
||||||
samples_padded.extend(std::iter::repeat(zero).take(to_add));
|
samples_padded.extend(std::iter::repeat_n(zero, to_add));
|
||||||
samples_padded
|
samples_padded
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user