mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 02:58:50 +00:00
Compare commits
11 Commits
0.9.0-alph
...
cuda-graph
Author | SHA1 | Date | |
---|---|---|---|
543b5b5898 | |||
c87f0fa5d6 | |||
eb478ece92 | |||
d339b01726 | |||
2f3bf42bcb | |||
e3370c6316 | |||
338f6a102e | |||
bc33df77e1 | |||
1bb68854d3 | |||
b2956857ef | |||
9076dee432 |
@ -42,7 +42,7 @@ clap = { workspace = true }
|
||||
criterion = { workspace = true }
|
||||
|
||||
[features]
|
||||
default = []
|
||||
default = ["cuda"]
|
||||
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
|
||||
cudnn = ["cuda", "cudarc/cudnn"]
|
||||
mkl = ["dep:libc", "dep:intel-mkl-src"]
|
||||
@ -56,3 +56,7 @@ harness = false
|
||||
[[example]]
|
||||
name = "metal_basics"
|
||||
required-features = ["metal"]
|
||||
|
||||
[[example]]
|
||||
name = "cuda_basics"
|
||||
required-features = ["cuda"]
|
||||
|
@ -7,8 +7,79 @@ extern crate intel_mkl_src;
|
||||
use anyhow::Result;
|
||||
use candle_core::{Device, Tensor};
|
||||
|
||||
const USE_CUDA_GRAPH: bool = true;
|
||||
|
||||
fn cuda_graph() -> Result<()> {
|
||||
let device = Device::new_cuda_with_stream(0)?;
|
||||
let cu_device = match &device {
|
||||
Device::Cuda(dev) => dev,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let cu_stream = cu_device.cuda_stream();
|
||||
{
|
||||
// load_ptx cannot be called while capturing the stream so we need this to happen
|
||||
// beforehand.
|
||||
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let v = Tensor::zeros(4096, candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let _x = x.mul(&u)?.broadcast_add(&v)?;
|
||||
let _x = x.affine(1., 0.5)?;
|
||||
x.slice_set(&u, 0, 0)?;
|
||||
device.synchronize()?;
|
||||
}
|
||||
if USE_CUDA_GRAPH {
|
||||
cu_stream.begin_capture(
|
||||
cudarc::driver::sys::CUstreamCaptureMode::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL,
|
||||
)?;
|
||||
}
|
||||
{
|
||||
let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
for _i in 0..100 {
|
||||
// x.slice_set(&u, 0, 0)?;
|
||||
// x.broadcast_add(&v)?;
|
||||
x = x.affine(1., 0.5)?;
|
||||
// x = (&u + &x)?;
|
||||
}
|
||||
}
|
||||
if USE_CUDA_GRAPH {
|
||||
println!("capturing graph");
|
||||
let cu_graph = match cu_stream.end_capture(
|
||||
cudarc::driver::sys::CUgraphInstantiate_flags::CUDA_GRAPH_INSTANTIATE_FLAG_USE_NODE_PRIORITY
|
||||
)? {
|
||||
None => anyhow::bail!("no graph captured"),
|
||||
Some(cu_graph) => cu_graph,
|
||||
};
|
||||
println!("graph captured!");
|
||||
for i in 1..100 {
|
||||
println!("graph exec {i}");
|
||||
cu_graph.launch()?;
|
||||
println!("sync");
|
||||
if let Err(err) = device.synchronize() {
|
||||
println!("err: {err:?}")
|
||||
}
|
||||
println!("done syncing");
|
||||
}
|
||||
} else {
|
||||
device.synchronize()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let device = Device::new_cuda(0)?;
|
||||
cuda_graph()?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
fn _matmul() -> Result<()> {
|
||||
let device = Device::new_cuda_with_stream(0)?;
|
||||
let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)?
|
||||
.to_dtype(candle_core::DType::BF16)?;
|
||||
candle_core::cuda::set_gemm_reduced_precision_f32(false);
|
||||
|
@ -73,7 +73,7 @@ fn dequantize_f32(
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let nb = elem_count.div_ceil(256);
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f32", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f32", false, 32, nb),
|
||||
@ -133,7 +133,7 @@ fn dequantize_f16(
|
||||
elem_count: usize,
|
||||
dev: &CudaDevice,
|
||||
) -> Result<CudaStorage> {
|
||||
let nb = (elem_count + 255) / 256;
|
||||
let nb = elem_count.div_ceil(256);
|
||||
let (kernel_name, is_k, block_dim, num_blocks) = match dtype {
|
||||
GgmlDType::Q4_0 => ("dequantize_block_q4_0_f16", false, 32, nb),
|
||||
GgmlDType::Q4_1 => ("dequantize_block_q4_1_f16", false, 32, nb),
|
||||
@ -278,8 +278,8 @@ fn mul_mat_vec_via_q8_1(
|
||||
// https://github.com/ggerganov/llama.cpp/blob/facb8b56f8fd3bb10a693bf0943ae9d69d0828ef/ggml-cuda/mmvq.cu#L98
|
||||
let (nblocks, nwarps) = match b_size {
|
||||
1 => (nrows as u32, 4),
|
||||
2..=4 => ((nrows as u32 + 1) / 2, 4),
|
||||
5..=8 => ((nrows as u32 + 1) / 2, 2),
|
||||
2..=4 => ((nrows as u32).div_ceil(2), 4),
|
||||
5..=8 => ((nrows as u32).div_ceil(2), 2),
|
||||
_ => crate::bail!("unexpected bsize {b_size}"),
|
||||
};
|
||||
let cfg = cudarc::driver::LaunchConfig {
|
||||
|
@ -69,6 +69,7 @@ metal = ["candle/metal", "candle-nn/metal"]
|
||||
microphone = ["cpal", "rubato"]
|
||||
encodec = ["cpal", "symphonia", "rubato"]
|
||||
mimi = ["cpal", "symphonia", "rubato"]
|
||||
snac = ["cpal", "symphonia", "rubato"]
|
||||
depth_anything_v2 = ["palette", "enterpolation"]
|
||||
|
||||
[[example]]
|
||||
@ -107,6 +108,10 @@ required-features = ["candle-datasets"]
|
||||
name = "mimi"
|
||||
required-features = ["mimi"]
|
||||
|
||||
[[example]]
|
||||
name = "snac"
|
||||
required-features = ["snac"]
|
||||
|
||||
[[example]]
|
||||
name = "encodec"
|
||||
required-features = ["encodec"]
|
||||
|
@ -8,7 +8,7 @@ The speakers turn are delimited by the `|` character in the prompt.
|
||||
|
||||
```bash
|
||||
cargo run --example csm --features cuda -r -- \
|
||||
--voices voices.safetensors \
|
||||
--voices candle-examples/examples/csm/voices.safetensors \
|
||||
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
|
||||
```
|
||||
|
||||
|
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
BIN
candle-examples/examples/csm/voices.safetensors
Normal file
Binary file not shown.
@ -8,7 +8,7 @@ DistilBert is used to compute the sentence embeddings for a prompt. The model we
|
||||
are downloaded from the hub on the first run.
|
||||
|
||||
```bash
|
||||
cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
$ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
|
||||
> [[[ 0.5109, 0.1280, -0.2635, ..., 0.3462, -1.0434, 0.1441],
|
||||
> [ 0.1735, 0.0818, -0.5549, ..., 0.3472, -0.8264, -0.0244],
|
||||
@ -20,3 +20,25 @@ cargo run --example distilbert --release -- --prompt "Here is a test sentence"
|
||||
> Tensor[[1, 7, 768], f32]
|
||||
|
||||
```
|
||||
|
||||
## Masked Token
|
||||
|
||||
DistilBert is used to compute the top K choices for a masked token.
|
||||
|
||||
```bash
|
||||
$ cargo run --example distilbert -- --prompt "The capital of France is [MASK]." --top-k 10
|
||||
|
||||
> Input: The capital of France is [MASK].
|
||||
> Predictions for [MASK] at position 6:
|
||||
> 1: marseille (probability: 12.14%)
|
||||
> 2: paris (probability: 10.84%)
|
||||
> 3: toulouse (probability: 8.57%)
|
||||
> 4: lyon (probability: 7.61%)
|
||||
> 5: montpellier (probability: 5.18%)
|
||||
> 6: bordeaux (probability: 4.88%)
|
||||
> 7: nantes (probability: 4.82%)
|
||||
> 8: lille (probability: 4.07%)
|
||||
> 9: strasbourg (probability: 3.12%)
|
||||
> 10: cannes (probability: 3.04%)
|
||||
|
||||
```
|
@ -3,15 +3,48 @@ extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
use candle_transformers::models::distilbert::{Config, DistilBertModel, DTYPE};
|
||||
use candle_transformers::models::distilbert::{
|
||||
Config, DistilBertForMaskedLM, DistilBertModel, DTYPE,
|
||||
};
|
||||
|
||||
use anyhow::{Error as E, Result};
|
||||
use anyhow::{Context, Error as E, Result};
|
||||
use candle::{Device, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use clap::Parser;
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::{api::sync::Api, Repo, RepoType};
|
||||
use std::path::PathBuf;
|
||||
use tokenizers::Tokenizer;
|
||||
|
||||
enum ModelType {
|
||||
Masked(DistilBertForMaskedLM),
|
||||
UnMasked(DistilBertModel),
|
||||
}
|
||||
|
||||
impl ModelType {
|
||||
fn device(&self) -> &Device {
|
||||
match self {
|
||||
ModelType::Masked(model) => &model.bert.device,
|
||||
ModelType::UnMasked(model) => &model.device,
|
||||
}
|
||||
}
|
||||
|
||||
fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
match self {
|
||||
ModelType::Masked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||
ModelType::UnMasked(model) => Ok(model.forward(input_ids, attention_mask)?),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "distilbert")]
|
||||
DistilBert,
|
||||
|
||||
#[value(name = "distilbertformaskedlm")]
|
||||
DistilbertForMaskedLM,
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
@ -23,10 +56,14 @@ struct Args {
|
||||
#[arg(long)]
|
||||
tracing: bool,
|
||||
|
||||
#[arg(long, default_value = "distilbert")]
|
||||
model: Which,
|
||||
|
||||
/// The model to use, check out available models: https://huggingface.co/models?library=sentence-transformers&sort=trending
|
||||
#[arg(long)]
|
||||
model_id: Option<String>,
|
||||
|
||||
/// Revision or branch
|
||||
#[arg(long)]
|
||||
revision: Option<String>,
|
||||
|
||||
@ -42,94 +79,246 @@ struct Args {
|
||||
#[arg(long, default_value = "1")]
|
||||
n: usize,
|
||||
|
||||
/// L2 normalization for embeddings.
|
||||
#[arg(long, default_value = "true")]
|
||||
normalize_embeddings: bool,
|
||||
/// Number of top predictions to show for each mask
|
||||
#[arg(long, default_value = "5")]
|
||||
top_k: usize,
|
||||
}
|
||||
|
||||
impl Args {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(DistilBertModel, Tokenizer)> {
|
||||
fn build_model_and_tokenizer(&self) -> Result<(ModelType, Tokenizer)> {
|
||||
let device = candle_examples::device(self.cpu)?;
|
||||
|
||||
let (model_id, revision) = self.resolve_model_and_revision();
|
||||
let (config_path, tokenizer_path, weights_path) =
|
||||
self.download_model_files(&model_id, &revision)?;
|
||||
|
||||
let config = std::fs::read_to_string(config_path)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(E::msg)?;
|
||||
|
||||
let vb = self.load_variables(&weights_path, &device)?;
|
||||
let model = self.create_model(&config, vb)?;
|
||||
|
||||
Ok((model, tokenizer))
|
||||
}
|
||||
|
||||
fn resolve_model_and_revision(&self) -> (String, String) {
|
||||
let default_model = "distilbert-base-uncased".to_string();
|
||||
let default_revision = "main".to_string();
|
||||
let (model_id, revision) = match (self.model_id.to_owned(), self.revision.to_owned()) {
|
||||
|
||||
match (self.model_id.clone(), self.revision.clone()) {
|
||||
(Some(model_id), Some(revision)) => (model_id, revision),
|
||||
(Some(model_id), None) => (model_id, "main".to_string()),
|
||||
(Some(model_id), None) => (model_id, default_revision),
|
||||
(None, Some(revision)) => (default_model, revision),
|
||||
(None, None) => (default_model, default_revision),
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
let repo = Repo::with_revision(model_id, RepoType::Model, revision);
|
||||
let (config_filename, tokenizer_filename, weights_filename) = {
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
let config = api.get("config.json")?;
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
let weights = if self.use_pth {
|
||||
api.get("pytorch_model.bin")?
|
||||
} else {
|
||||
api.get("model.safetensors")?
|
||||
};
|
||||
(config, tokenizer, weights)
|
||||
};
|
||||
let config = std::fs::read_to_string(config_filename)?;
|
||||
let config: Config = serde_json::from_str(&config)?;
|
||||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
|
||||
fn download_model_files(
|
||||
&self,
|
||||
model_id: &str,
|
||||
revision: &str,
|
||||
) -> Result<(PathBuf, PathBuf, PathBuf)> {
|
||||
let repo = Repo::with_revision(model_id.to_string(), RepoType::Model, revision.to_string());
|
||||
let api = Api::new()?;
|
||||
let api = api.repo(repo);
|
||||
|
||||
let vb = if self.use_pth {
|
||||
VarBuilder::from_pth(&weights_filename, DTYPE, &device)?
|
||||
let config = api.get("config.json")?;
|
||||
let tokenizer = api.get("tokenizer.json")?;
|
||||
let weights = if self.use_pth {
|
||||
api.get("pytorch_model.bin")?
|
||||
} else {
|
||||
unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? }
|
||||
api.get("model.safetensors")?
|
||||
};
|
||||
let model = DistilBertModel::load(vb, &config)?;
|
||||
Ok((model, tokenizer))
|
||||
|
||||
Ok((config, tokenizer, weights))
|
||||
}
|
||||
|
||||
fn load_variables(&self, weights_path: &PathBuf, device: &Device) -> Result<VarBuilder> {
|
||||
if self.use_pth {
|
||||
Ok(VarBuilder::from_pth(weights_path, DTYPE, device)?)
|
||||
} else {
|
||||
Ok(unsafe { VarBuilder::from_mmaped_safetensors(&[weights_path], DTYPE, device)? })
|
||||
}
|
||||
}
|
||||
|
||||
fn create_model(&self, config: &Config, vb: VarBuilder) -> Result<ModelType> {
|
||||
match self.model {
|
||||
Which::DistilbertForMaskedLM => {
|
||||
Ok(ModelType::Masked(DistilBertForMaskedLM::load(vb, config)?))
|
||||
}
|
||||
Which::DistilBert => Ok(ModelType::UnMasked(DistilBertModel::load(vb, config)?)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Tensor {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device).unwrap()
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let _guard = setup_tracing(&args);
|
||||
|
||||
let (model, tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = model.device();
|
||||
|
||||
let (token_ids, mask) = prepare_inputs(&args, &tokenizer, device)?;
|
||||
let output = model.forward(&token_ids, &mask)?;
|
||||
|
||||
process_output(&model, &output, &token_ids, &tokenizer, &args)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
fn setup_tracing(args: &Args) -> Option<impl Drop> {
|
||||
if args.tracing {
|
||||
use tracing_chrome::ChromeLayerBuilder;
|
||||
use tracing_subscriber::prelude::*;
|
||||
|
||||
let args = Args::parse();
|
||||
let _guard = if args.tracing {
|
||||
println!("tracing...");
|
||||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
|
||||
tracing_subscriber::registry().with(chrome_layer).init();
|
||||
Some(guard)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let (model, mut tokenizer) = args.build_model_and_tokenizer()?;
|
||||
let device = &model.device;
|
||||
}
|
||||
}
|
||||
|
||||
let tokenizer = tokenizer
|
||||
fn prepare_inputs(args: &Args, tokenizer: &Tokenizer, device: &Device) -> Result<(Tensor, Tensor)> {
|
||||
let mut binding = tokenizer.clone();
|
||||
let tokenizer_configured = binding
|
||||
.with_padding(None)
|
||||
.with_truncation(None)
|
||||
.map_err(E::msg)?;
|
||||
let tokens = tokenizer
|
||||
.encode(args.prompt, true)
|
||||
|
||||
let tokens = tokenizer_configured
|
||||
.encode(args.prompt.clone(), true)
|
||||
.map_err(E::msg)?
|
||||
.get_ids()
|
||||
.to_vec();
|
||||
|
||||
let token_ids = Tensor::new(&tokens[..], device)?.unsqueeze(0)?;
|
||||
let mask = get_mask(tokens.len(), device);
|
||||
|
||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>());
|
||||
println!("mask: {:?}", mask.to_vec2::<u8>());
|
||||
let mask = match args.model {
|
||||
Which::DistilbertForMaskedLM => attention_mask_maskedlm(tokenizer, &args.prompt, device)?,
|
||||
Which::DistilBert => attention_mask(tokens.len(), device)?,
|
||||
};
|
||||
|
||||
let ys = model.forward(&token_ids, &mask)?;
|
||||
println!("{ys}");
|
||||
println!("token_ids: {:?}", token_ids.to_vec2::<u32>()?);
|
||||
|
||||
Ok((token_ids, mask))
|
||||
}
|
||||
|
||||
fn process_output(
|
||||
model: &ModelType,
|
||||
output: &Tensor,
|
||||
token_ids: &Tensor,
|
||||
tokenizer: &Tokenizer,
|
||||
args: &Args,
|
||||
) -> Result<()> {
|
||||
match model {
|
||||
ModelType::UnMasked(_) => {
|
||||
println!("embeddings");
|
||||
println!("{output}");
|
||||
}
|
||||
ModelType::Masked(_) => {
|
||||
process_masked_output(output, token_ids, tokenizer, args)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn normalize_l2(v: &Tensor) -> Result<Tensor> {
|
||||
Ok(v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)?)
|
||||
fn process_masked_output(
|
||||
output: &Tensor,
|
||||
token_ids: &Tensor,
|
||||
tokenizer: &Tokenizer,
|
||||
args: &Args,
|
||||
) -> Result<()> {
|
||||
let input_ids_vec = token_ids.to_vec2::<u32>()?;
|
||||
let mask_token_id = tokenizer
|
||||
.token_to_id("[MASK]")
|
||||
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||
|
||||
println!("\nInput: {}", args.prompt);
|
||||
|
||||
for (token_idx, &token_id) in input_ids_vec[0].iter().enumerate() {
|
||||
if token_id == mask_token_id {
|
||||
println!("Predictions for [MASK] at position {}:", token_idx);
|
||||
|
||||
let pos_logits = output.get(0)?.get(token_idx)?;
|
||||
let probs = candle_nn::ops::softmax(&pos_logits, 0)?;
|
||||
let (top_values, top_indices) = get_top_k(&probs, args.top_k)?;
|
||||
|
||||
let values = top_values.to_vec1::<f32>()?;
|
||||
let indices = top_indices.to_vec1::<u32>()?;
|
||||
|
||||
for (i, (&token_id, &prob)) in indices.iter().zip(values.iter()).enumerate() {
|
||||
let token = tokenizer.decode(&[token_id], false).map_err(E::msg)?;
|
||||
println!(
|
||||
" {}: {:15} (probability: {:.2}%)",
|
||||
i + 1,
|
||||
token,
|
||||
prob * 100.0
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_top_k(tensor: &Tensor, k: usize) -> Result<(Tensor, Tensor)> {
|
||||
let n = tensor.dims().iter().product::<usize>();
|
||||
let k = std::cmp::min(k, n);
|
||||
|
||||
let values = tensor.to_vec1::<f32>()?;
|
||||
let mut value_indices: Vec<(f32, usize)> = values
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, val)| (val, idx))
|
||||
.collect();
|
||||
|
||||
value_indices.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
|
||||
|
||||
let top_k_values: Vec<f32> = value_indices.iter().take(k).map(|(val, _)| *val).collect();
|
||||
let top_k_indices: Vec<u32> = value_indices
|
||||
.iter()
|
||||
.take(k)
|
||||
.map(|(_, idx)| *idx as u32)
|
||||
.collect();
|
||||
|
||||
let device = tensor.device();
|
||||
let top_values = Tensor::from_vec(top_k_values, (k,), device)?;
|
||||
let top_indices = Tensor::from_vec(top_k_indices, (k,), device)?;
|
||||
|
||||
Ok((top_values, top_indices))
|
||||
}
|
||||
|
||||
fn attention_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.collect();
|
||||
Ok(Tensor::from_slice(&mask, (size, size), device)?)
|
||||
}
|
||||
|
||||
fn attention_mask_maskedlm(tokenizer: &Tokenizer, input: &str, device: &Device) -> Result<Tensor> {
|
||||
let tokens = tokenizer.encode(input, true).map_err(E::msg)?;
|
||||
let seq_len = tokens.get_attention_mask().to_vec().len();
|
||||
|
||||
let mask_token_id = tokenizer
|
||||
.token_to_id("[MASK]")
|
||||
.context("Mask token, \"[MASK]\", not found in tokenizer.")?;
|
||||
|
||||
let mut attention_mask_vec = Vec::with_capacity(seq_len * seq_len);
|
||||
|
||||
let ids = tokens.get_ids();
|
||||
for _ in 0..seq_len {
|
||||
for id in ids.iter() {
|
||||
let mask_value = if id == &mask_token_id { 1u8 } else { 0u8 };
|
||||
attention_mask_vec.push(mask_value);
|
||||
}
|
||||
}
|
||||
|
||||
let shape = (1, 1, seq_len, seq_len);
|
||||
let mask = Tensor::from_vec(attention_mask_vec, shape, device)?;
|
||||
|
||||
Ok(mask)
|
||||
}
|
||||
|
275
candle-examples/examples/snac/audio_io.rs
Normal file
275
candle-examples/examples/snac/audio_io.rs
Normal file
@ -0,0 +1,275 @@
|
||||
use anyhow::{Context, Result};
|
||||
use std::sync::{Arc, Mutex};
|
||||
|
||||
pub const SAMPLE_RATE: usize = 24_000;
|
||||
|
||||
pub(crate) struct AudioOutputData_ {
|
||||
resampled_data: std::collections::VecDeque<f32>,
|
||||
resampler: rubato::FastFixedIn<f32>,
|
||||
output_buffer: Vec<f32>,
|
||||
input_buffer: Vec<f32>,
|
||||
input_len: usize,
|
||||
}
|
||||
|
||||
impl AudioOutputData_ {
|
||||
pub(crate) fn new(input_sample_rate: usize, output_sample_rate: usize) -> Result<Self> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let resampled_data = std::collections::VecDeque::with_capacity(output_sample_rate * 10);
|
||||
let resample_ratio = output_sample_rate as f64 / input_sample_rate as f64;
|
||||
let resampler = rubato::FastFixedIn::new(
|
||||
resample_ratio,
|
||||
f64::max(resample_ratio, 1.0),
|
||||
rubato::PolynomialDegree::Septic,
|
||||
1024,
|
||||
1,
|
||||
)?;
|
||||
let input_buffer = resampler.input_buffer_allocate(true).remove(0);
|
||||
let output_buffer = resampler.output_buffer_allocate(true).remove(0);
|
||||
Ok(Self {
|
||||
resampled_data,
|
||||
resampler,
|
||||
input_buffer,
|
||||
output_buffer,
|
||||
input_len: 0,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
use rubato::Resampler;
|
||||
self.output_buffer.fill(0.);
|
||||
self.input_buffer.fill(0.);
|
||||
self.resampler.reset();
|
||||
self.resampled_data.clear();
|
||||
}
|
||||
|
||||
pub(crate) fn take_all(&mut self) -> Vec<f32> {
|
||||
let mut data = Vec::with_capacity(self.resampled_data.len());
|
||||
while let Some(elem) = self.resampled_data.pop_back() {
|
||||
data.push(elem);
|
||||
}
|
||||
data
|
||||
}
|
||||
|
||||
pub(crate) fn is_empty(&self) -> bool {
|
||||
self.resampled_data.is_empty()
|
||||
}
|
||||
|
||||
// Assumes that the input buffer is large enough.
|
||||
fn push_input_buffer(&mut self, samples: &[f32]) {
|
||||
self.input_buffer[self.input_len..self.input_len + samples.len()].copy_from_slice(samples);
|
||||
self.input_len += samples.len()
|
||||
}
|
||||
|
||||
pub(crate) fn push_samples(&mut self, samples: &[f32]) -> Result<()> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pos_in = 0;
|
||||
loop {
|
||||
let rem = self.input_buffer.len() - self.input_len;
|
||||
let pos_end = usize::min(pos_in + rem, samples.len());
|
||||
self.push_input_buffer(&samples[pos_in..pos_end]);
|
||||
pos_in = pos_end;
|
||||
if self.input_len < self.input_buffer.len() {
|
||||
break;
|
||||
}
|
||||
let (_, out_len) = self.resampler.process_into_buffer(
|
||||
&[&self.input_buffer],
|
||||
&mut [&mut self.output_buffer],
|
||||
None,
|
||||
)?;
|
||||
for &elem in self.output_buffer[..out_len].iter() {
|
||||
self.resampled_data.push_front(elem)
|
||||
}
|
||||
self.input_len = 0;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
type AudioOutputData = Arc<Mutex<AudioOutputData_>>;
|
||||
|
||||
pub(crate) fn setup_output_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio output stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_output_device()
|
||||
.context("no output device available")?;
|
||||
let mut supported_configs_range = device.supported_output_configs()?;
|
||||
let config_range = match supported_configs_range.find(|c| c.channels() == 1) {
|
||||
// On macOS, it's commonly the case that there are only stereo outputs.
|
||||
None => device
|
||||
.supported_output_configs()?
|
||||
.next()
|
||||
.context("no audio output available")?,
|
||||
Some(config_range) => config_range,
|
||||
};
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
let channels = config.channels as usize;
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
SAMPLE_RATE,
|
||||
config.sample_rate.0 as usize,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_output_stream(
|
||||
&config,
|
||||
move |data: &mut [f32], _: &cpal::OutputCallbackInfo| {
|
||||
data.fill(0.);
|
||||
let mut ad = ad.lock().unwrap();
|
||||
let mut last_elem = 0f32;
|
||||
for (idx, elem) in data.iter_mut().enumerate() {
|
||||
if idx % channels == 0 {
|
||||
match ad.resampled_data.pop_back() {
|
||||
None => break,
|
||||
Some(v) => {
|
||||
last_elem = v;
|
||||
*elem = v
|
||||
}
|
||||
}
|
||||
} else {
|
||||
*elem = last_elem
|
||||
}
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
pub(crate) fn setup_input_stream() -> Result<(cpal::Stream, AudioOutputData)> {
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
|
||||
println!("Setup audio input stream!");
|
||||
let host = cpal::default_host();
|
||||
let device = host
|
||||
.default_input_device()
|
||||
.context("no input device available")?;
|
||||
let mut supported_configs_range = device.supported_input_configs()?;
|
||||
let config_range = supported_configs_range
|
||||
.find(|c| c.channels() == 1)
|
||||
.context("no audio input available")?;
|
||||
let sample_rate = cpal::SampleRate(SAMPLE_RATE as u32).clamp(
|
||||
config_range.min_sample_rate(),
|
||||
config_range.max_sample_rate(),
|
||||
);
|
||||
let config: cpal::StreamConfig = config_range.with_sample_rate(sample_rate).into();
|
||||
println!(
|
||||
"cpal device: {} {} {config:?}",
|
||||
device.name().unwrap_or_else(|_| "unk".to_string()),
|
||||
config.sample_rate.0
|
||||
);
|
||||
let audio_data = Arc::new(Mutex::new(AudioOutputData_::new(
|
||||
config.sample_rate.0 as usize,
|
||||
SAMPLE_RATE,
|
||||
)?));
|
||||
let ad = audio_data.clone();
|
||||
let stream = device.build_input_stream(
|
||||
&config,
|
||||
move |data: &[f32], _: &cpal::InputCallbackInfo| {
|
||||
let mut ad = ad.lock().unwrap();
|
||||
if let Err(err) = ad.push_samples(data) {
|
||||
eprintln!("error processing audio input {err:?}")
|
||||
}
|
||||
},
|
||||
move |err| eprintln!("cpal error: {err}"),
|
||||
None, // None=blocking, Some(Duration)=timeout
|
||||
)?;
|
||||
stream.play()?;
|
||||
Ok((stream, audio_data))
|
||||
}
|
||||
|
||||
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||
where
|
||||
T: symphonia::core::sample::Sample,
|
||||
f32: symphonia::core::conv::FromSample<T>,
|
||||
{
|
||||
use symphonia::core::audio::Signal;
|
||||
use symphonia::core::conv::FromSample;
|
||||
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||
}
|
||||
|
||||
pub(crate) fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> Result<(Vec<f32>, u32)> {
|
||||
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||
|
||||
let src = std::fs::File::open(path)?;
|
||||
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||
let hint = symphonia::core::probe::Hint::new();
|
||||
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||
let mut format = probed.format;
|
||||
let track = format
|
||||
.tracks()
|
||||
.iter()
|
||||
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||
.expect("no supported audio tracks");
|
||||
let mut decoder = symphonia::default::get_codecs()
|
||||
.make(&track.codec_params, &Default::default())
|
||||
.expect("unsupported codec");
|
||||
let track_id = track.id;
|
||||
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||
let mut pcm_data = Vec::new();
|
||||
while let Ok(packet) = format.next_packet() {
|
||||
while !format.metadata().is_latest() {
|
||||
format.metadata().pop();
|
||||
}
|
||||
if packet.track_id() != track_id {
|
||||
continue;
|
||||
}
|
||||
match decoder.decode(&packet)? {
|
||||
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||
}
|
||||
}
|
||||
Ok((pcm_data, sample_rate))
|
||||
}
|
||||
|
||||
pub(crate) fn resample(pcm_in: &[f32], sr_in: u32, sr_out: u32) -> Result<Vec<f32>> {
|
||||
use rubato::Resampler;
|
||||
|
||||
let mut pcm_out =
|
||||
Vec::with_capacity((pcm_in.len() as f64 * sr_out as f64 / sr_in as f64) as usize + 1024);
|
||||
|
||||
let mut resampler =
|
||||
rubato::FftFixedInOut::<f32>::new(sr_in as usize, sr_out as usize, 1024, 1)?;
|
||||
let mut output_buffer = resampler.output_buffer_allocate(true);
|
||||
let mut pos_in = 0;
|
||||
while pos_in + resampler.input_frames_next() < pcm_in.len() {
|
||||
let (in_len, out_len) =
|
||||
resampler.process_into_buffer(&[&pcm_in[pos_in..]], &mut output_buffer, None)?;
|
||||
pos_in += in_len;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
if pos_in < pcm_in.len() {
|
||||
let (_in_len, out_len) = resampler.process_partial_into_buffer(
|
||||
Some(&[&pcm_in[pos_in..]]),
|
||||
&mut output_buffer,
|
||||
None,
|
||||
)?;
|
||||
pcm_out.extend_from_slice(&output_buffer[0][..out_len]);
|
||||
}
|
||||
|
||||
Ok(pcm_out)
|
||||
}
|
197
candle-examples/examples/snac/main.rs
Normal file
197
candle-examples/examples/snac/main.rs
Normal file
@ -0,0 +1,197 @@
|
||||
#[cfg(feature = "mkl")]
|
||||
extern crate intel_mkl_src;
|
||||
|
||||
#[cfg(feature = "accelerate")]
|
||||
extern crate accelerate_src;
|
||||
|
||||
use anyhow::Result;
|
||||
use candle::{DType, IndexOp, Tensor};
|
||||
use candle_nn::VarBuilder;
|
||||
use candle_transformers::models::snac::{Config, Model};
|
||||
use clap::{Parser, ValueEnum};
|
||||
use hf_hub::api::sync::Api;
|
||||
|
||||
mod audio_io;
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, ValueEnum)]
|
||||
enum Action {
|
||||
AudioToAudio,
|
||||
AudioToCode,
|
||||
CodeToAudio,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||
enum Which {
|
||||
#[value(name = "24khz")]
|
||||
S24khz,
|
||||
#[value(name = "32khz")]
|
||||
S32khz,
|
||||
#[value(name = "44khz")]
|
||||
S44khz,
|
||||
}
|
||||
|
||||
impl Which {
|
||||
fn sample_rate(&self) -> u32 {
|
||||
match self {
|
||||
Which::S24khz => 24000,
|
||||
Which::S32khz => 32000,
|
||||
Which::S44khz => 44000,
|
||||
}
|
||||
}
|
||||
|
||||
fn config_repo(&self) -> &'static str {
|
||||
match self {
|
||||
Which::S24khz => "hubertsiuzdak/snac_24khz",
|
||||
Which::S32khz => "hubertsiuzdak/snac_32khz",
|
||||
Which::S44khz => "hubertsiuzdak/snac_44khz",
|
||||
}
|
||||
}
|
||||
|
||||
fn model_file(&self) -> &'static str {
|
||||
match self {
|
||||
Which::S24khz => "snac_24khz.safetensors",
|
||||
Which::S32khz => "snac_32khz.safetensors",
|
||||
Which::S44khz => "snac_44khz.safetensors",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Parser, Debug)]
|
||||
#[command(author, version, about, long_about = None)]
|
||||
struct Args {
|
||||
/// The action to be performed, specifies the format for the input and output data.
|
||||
action: Action,
|
||||
|
||||
/// The input file, either an audio file or some snac tokens stored as safetensors.
|
||||
in_file: String,
|
||||
|
||||
/// The output file, either a wave audio file or some snac tokens stored as safetensors.
|
||||
out_file: String,
|
||||
|
||||
/// The model size to use.
|
||||
#[arg(long, default_value = "24khz")]
|
||||
which: Which,
|
||||
|
||||
/// Run on CPU rather than on GPU.
|
||||
#[arg(long)]
|
||||
cpu: bool,
|
||||
|
||||
/// The model weight file, in safetensor format.
|
||||
#[arg(long)]
|
||||
model: Option<String>,
|
||||
|
||||
/// The config file, in safetensor format.
|
||||
#[arg(long)]
|
||||
config: Option<String>,
|
||||
}
|
||||
|
||||
fn main() -> Result<()> {
|
||||
let args = Args::parse();
|
||||
let device = candle_examples::device(args.cpu)?;
|
||||
let model_sample_rate = args.which.sample_rate();
|
||||
let config = match args.config {
|
||||
Some(c) => std::path::PathBuf::from(c),
|
||||
None => Api::new()?
|
||||
.model(args.which.config_repo().to_string())
|
||||
.get("config.json")?,
|
||||
};
|
||||
let config: Config = serde_json::from_slice(&std::fs::read(config)?)?;
|
||||
let model = match args.model {
|
||||
Some(model) => std::path::PathBuf::from(model),
|
||||
None => Api::new()?
|
||||
.model("lmz/candle-snac".to_string())
|
||||
.get(args.which.model_file())?,
|
||||
};
|
||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
|
||||
let model = Model::new(&config, vb)?;
|
||||
|
||||
let codes = match args.action {
|
||||
Action::CodeToAudio => {
|
||||
let codes = candle::safetensors::load(args.in_file, &device)?;
|
||||
let num_codebooks = model.num_codebooks();
|
||||
(0..num_codebooks)
|
||||
.map(|i| {
|
||||
codes
|
||||
.get(&format!("codes-{i}"))
|
||||
.expect("no codes in input file")
|
||||
.clone()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
Action::AudioToCode | Action::AudioToAudio => {
|
||||
let pcm = if args.in_file == "-" {
|
||||
println!(">>>> RECORDING AUDIO, PRESS ENTER ONCE DONE <<<<");
|
||||
let (stream, input_audio) = audio_io::setup_input_stream()?;
|
||||
let mut pcms = vec![];
|
||||
let stdin = std::thread::spawn(|| {
|
||||
let mut s = String::new();
|
||||
std::io::stdin().read_line(&mut s)
|
||||
});
|
||||
while !stdin.is_finished() {
|
||||
let input = input_audio.lock().unwrap().take_all();
|
||||
if input.is_empty() {
|
||||
std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
continue;
|
||||
}
|
||||
pcms.push(input)
|
||||
}
|
||||
drop(stream);
|
||||
pcms.concat()
|
||||
} else {
|
||||
let (pcm, sample_rate) = audio_io::pcm_decode(args.in_file)?;
|
||||
if sample_rate != model_sample_rate {
|
||||
println!("WARNING: snac uses a {model_sample_rate} sample rate, input uses {sample_rate}, resampling...");
|
||||
audio_io::resample(&pcm, sample_rate, model_sample_rate)?
|
||||
} else {
|
||||
pcm
|
||||
}
|
||||
};
|
||||
let pcm_len = pcm.len();
|
||||
let pcm = Tensor::from_vec(pcm, (1, 1, pcm_len), &device)?;
|
||||
println!("input pcm shape: {:?}", pcm.shape());
|
||||
model.encode(&pcm)?
|
||||
}
|
||||
};
|
||||
for codes in codes.iter() {
|
||||
println!("codes shape: {:?}", codes.shape());
|
||||
}
|
||||
|
||||
match args.action {
|
||||
Action::AudioToCode => {
|
||||
let mut tensors = std::collections::HashMap::new();
|
||||
for (i, codes) in codes.iter().enumerate() {
|
||||
tensors.insert(format!("codes-{i}"), codes.clone());
|
||||
}
|
||||
candle::safetensors::save(&tensors, "codes.safetensors")?;
|
||||
}
|
||||
Action::AudioToAudio | Action::CodeToAudio => {
|
||||
let codes = codes.iter().collect::<Vec<_>>();
|
||||
let pcm = model.decode(&codes)?;
|
||||
println!("output pcm shape: {:?}", pcm.shape());
|
||||
let pcm = pcm.i(0)?.i(0)?;
|
||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, model_sample_rate, true)?;
|
||||
let pcm = pcm.to_vec1::<f32>()?;
|
||||
if args.out_file == "-" {
|
||||
let (stream, ad) = audio_io::setup_output_stream()?;
|
||||
{
|
||||
let mut ad = ad.lock().unwrap();
|
||||
ad.push_samples(&pcm)?;
|
||||
}
|
||||
loop {
|
||||
let ad = ad.lock().unwrap();
|
||||
if ad.is_empty() {
|
||||
break;
|
||||
}
|
||||
// That's very weird, calling thread::sleep here triggers the stream to stop
|
||||
// playing (the callback doesn't seem to be called anymore).
|
||||
// std::thread::sleep(std::time::Duration::from_millis(100));
|
||||
}
|
||||
drop(stream)
|
||||
} else {
|
||||
let mut output = std::fs::File::create(&args.out_file)?;
|
||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, model_sample_rate)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
@ -504,8 +504,9 @@ impl BertModel {
|
||||
Some(attention_mask) => attention_mask.clone(),
|
||||
None => input_ids.ones_like()?,
|
||||
};
|
||||
let dtype = embedding_output.dtype();
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
||||
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
||||
let sequence_output = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||
Ok(sequence_output)
|
||||
}
|
||||
@ -519,8 +520,11 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
||||
};
|
||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||
// torch.finfo(dtype).min
|
||||
(attention_mask.ones_like()? - &attention_mask)?
|
||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
||||
&Tensor::try_from(f32::MIN)?
|
||||
.to_device(attention_mask.device())?
|
||||
.to_dtype(dtype)?,
|
||||
)
|
||||
}
|
||||
|
||||
//https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L752-L766
|
||||
|
@ -514,8 +514,9 @@ impl ChineseClipTextTransformer {
|
||||
Some(attention_mask) => attention_mask.clone(),
|
||||
None => input_ids.ones_like()?,
|
||||
};
|
||||
let dtype = embedding_output.dtype();
|
||||
// https://github.com/huggingface/transformers/blob/6eedfa6dd15dc1e22a55ae036f681914e5a0d9a1/src/transformers/models/bert/modeling_bert.py#L995
|
||||
let attention_mask = get_extended_attention_mask(&attention_mask, DType::F32)?;
|
||||
let attention_mask = get_extended_attention_mask(&attention_mask, dtype)?;
|
||||
let encoder_outputs = self.encoder.forward(&embedding_output, &attention_mask)?;
|
||||
let encoder_output = encoder_outputs.i((.., 0, ..))?;
|
||||
let pooled_output = match &self.pooler {
|
||||
@ -535,6 +536,9 @@ fn get_extended_attention_mask(attention_mask: &Tensor, dtype: DType) -> Result<
|
||||
};
|
||||
let attention_mask = attention_mask.to_dtype(dtype)?;
|
||||
// torch.finfo(dtype).min
|
||||
(attention_mask.ones_like()? - &attention_mask)?
|
||||
.broadcast_mul(&Tensor::try_from(f32::MIN)?.to_device(attention_mask.device())?)
|
||||
(attention_mask.ones_like()? - &attention_mask)?.broadcast_mul(
|
||||
&Tensor::try_from(f32::MIN)?
|
||||
.to_device(attention_mask.device())?
|
||||
.to_dtype(dtype)?,
|
||||
)
|
||||
}
|
||||
|
@ -330,6 +330,7 @@ impl ResidualVectorQuantizer {
|
||||
Ok(Self { quantizers })
|
||||
}
|
||||
|
||||
#[allow(clippy::wrong_self_convention)]
|
||||
pub fn from_codes(&self, codes: &Tensor) -> Result<Tensor> {
|
||||
let mut sum = None;
|
||||
for (idx, quantizer) in self.quantizers.iter().enumerate() {
|
||||
|
@ -19,7 +19,7 @@ fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor>
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum HiddenAct {
|
||||
pub enum HiddenAct {
|
||||
Gelu,
|
||||
Relu,
|
||||
}
|
||||
@ -49,22 +49,22 @@ impl Module for HiddenActLayer {
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum PositionEmbeddingType {
|
||||
pub enum PositionEmbeddingType {
|
||||
#[default]
|
||||
Absolute,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Deserialize)]
|
||||
pub struct Config {
|
||||
vocab_size: usize,
|
||||
dim: usize,
|
||||
pub vocab_size: usize,
|
||||
pub dim: usize,
|
||||
n_layers: usize,
|
||||
n_heads: usize,
|
||||
hidden_dim: usize,
|
||||
activation: HiddenAct,
|
||||
max_position_embeddings: usize,
|
||||
initializer_range: f64,
|
||||
pad_token_id: usize,
|
||||
pub pad_token_id: usize,
|
||||
#[serde(default)]
|
||||
position_embedding_type: PositionEmbeddingType,
|
||||
#[serde(default)]
|
||||
@ -345,3 +345,107 @@ impl DistilBertModel {
|
||||
Ok(sequence_output)
|
||||
}
|
||||
}
|
||||
|
||||
struct DistilBertPredictionHeadTransform {
|
||||
dense: Linear,
|
||||
activation: HiddenActLayer,
|
||||
layer_norm: LayerNorm,
|
||||
}
|
||||
|
||||
impl DistilBertPredictionHeadTransform {
|
||||
fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let dense = linear(config.dim, config.dim, vb.pp("vocab_transform"))?;
|
||||
let activation = HiddenActLayer::new(config.activation);
|
||||
let layer_norm = layer_norm(config.dim, 1e-12, vb.pp("vocab_layer_norm"))?;
|
||||
Ok(Self {
|
||||
dense,
|
||||
activation,
|
||||
layer_norm,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DistilBertPredictionHeadTransform {
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
let hidden_states = self
|
||||
.activation
|
||||
.forward(&self.dense.forward(hidden_states)?)?;
|
||||
self.layer_norm.forward(&hidden_states)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L769C1-L790C1
|
||||
pub struct DistilBertLMPredictionHead {
|
||||
transform: DistilBertPredictionHeadTransform,
|
||||
decoder: Linear,
|
||||
}
|
||||
|
||||
impl DistilBertLMPredictionHead {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let transform = DistilBertPredictionHeadTransform::load(vb.clone(), config)?;
|
||||
|
||||
// distil_bert_uncased uses the word embeddings for the vocab projector weight, but has a seperate vocab_projector bias
|
||||
let vocab_projector_weight_vb = vb.pp("distilbert.embeddings.word_embeddings");
|
||||
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
|
||||
let ws = vocab_projector_weight_vb.get_with_hints(
|
||||
(config.vocab_size, config.dim),
|
||||
"weight",
|
||||
init_ws,
|
||||
)?;
|
||||
let bound = 1. / (config.dim as f64).sqrt();
|
||||
let init_bs = candle_nn::Init::Uniform {
|
||||
lo: -bound,
|
||||
up: bound,
|
||||
};
|
||||
|
||||
let vocab_projector_bias_vb = vb.pp("vocab_projector");
|
||||
let bs = vocab_projector_bias_vb.get_with_hints(config.vocab_size, "bias", init_bs)?;
|
||||
|
||||
let decoder = Linear::from_weights(ws, Some(bs));
|
||||
|
||||
Ok(Self { transform, decoder })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DistilBertLMPredictionHead {
|
||||
fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
|
||||
self.decoder
|
||||
.forward(&self.transform.forward(hidden_states)?)
|
||||
}
|
||||
}
|
||||
|
||||
// https://github.com/huggingface/transformers/blob/1bd604d11c405dfb8b78bda4062d88fc75c17de0/src/transformers/models/bert/modeling_bert.py#L792
|
||||
pub struct DistilBertOnlyMLMHead {
|
||||
predictions: DistilBertLMPredictionHead,
|
||||
}
|
||||
|
||||
impl DistilBertOnlyMLMHead {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let predictions = DistilBertLMPredictionHead::load(vb.clone(), config)?;
|
||||
Ok(Self { predictions })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for DistilBertOnlyMLMHead {
|
||||
fn forward(&self, sequence_output: &Tensor) -> Result<Tensor> {
|
||||
self.predictions.forward(sequence_output)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct DistilBertForMaskedLM {
|
||||
pub bert: DistilBertModel,
|
||||
cls: DistilBertOnlyMLMHead,
|
||||
}
|
||||
|
||||
impl DistilBertForMaskedLM {
|
||||
pub fn load(vb: VarBuilder, config: &Config) -> Result<Self> {
|
||||
let bert = DistilBertModel::load(vb.pp("distilbert"), config)?;
|
||||
let cls = DistilBertOnlyMLMHead::load(vb.clone(), config)?;
|
||||
Ok(Self { bert, cls })
|
||||
}
|
||||
|
||||
pub fn forward(&self, input_ids: &Tensor, attention_mask: &Tensor) -> Result<Tensor> {
|
||||
let sequence_output = self.bert.forward(input_ids, attention_mask)?;
|
||||
self.cls.forward(&sequence_output)
|
||||
}
|
||||
}
|
||||
|
@ -141,6 +141,20 @@ pub fn conv1d_weight_norm(
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
pub fn conv1d_weight_norm_no_bias(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), "weight_g")?;
|
||||
let weight_v = vb.get((out_c, in_c, kernel_size), "weight_v")?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
Ok(Conv1d::new(weight, None, config))
|
||||
}
|
||||
|
||||
pub fn conv_transpose1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
|
@ -104,6 +104,7 @@ pub mod rwkv_v6;
|
||||
pub mod segformer;
|
||||
pub mod segment_anything;
|
||||
pub mod siglip;
|
||||
pub mod snac;
|
||||
pub mod stable_diffusion;
|
||||
pub mod stable_lm;
|
||||
pub mod starcoder2;
|
||||
|
814
candle-transformers/src/models/snac.rs
Normal file
814
candle-transformers/src/models/snac.rs
Normal file
@ -0,0 +1,814 @@
|
||||
#![allow(unused)]
|
||||
//! Implementation of the Multi-Scale Neural Audio Codec (SNAC)
|
||||
//!
|
||||
//! See: [SNAC](https://github.com/hubertsiuzdak/snac)
|
||||
//!
|
||||
/// Multi-Scale Neural Audio Codec (SNAC) compresses audio into discrete codes at a low bitrate.
|
||||
/// For more information, read the paper: https://arxiv.org/abs/2410.14411
|
||||
///
|
||||
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
|
||||
use candle_nn::{
|
||||
linear_b, Conv1d, Conv1dConfig, ConvTranspose1d, ConvTranspose1dConfig, LayerNorm, Linear,
|
||||
VarBuilder,
|
||||
};
|
||||
|
||||
#[derive(serde::Deserialize, Debug, Clone)]
|
||||
pub struct Config {
|
||||
pub sampling_rate: usize,
|
||||
pub encoder_dim: usize,
|
||||
pub encoder_rates: Vec<usize>,
|
||||
pub decoder_dim: usize,
|
||||
pub decoder_rates: Vec<usize>,
|
||||
pub attn_window_size: Option<usize>,
|
||||
pub codebook_size: usize,
|
||||
pub codebook_dim: usize,
|
||||
pub vq_strides: Vec<usize>,
|
||||
pub noise: bool,
|
||||
pub depthwise: bool,
|
||||
}
|
||||
|
||||
// Equivalent to torch.repeat_interleave
|
||||
pub fn repeat_interleave<D: candle::shape::Dim>(
|
||||
img: &Tensor,
|
||||
repeats: usize,
|
||||
dim: D,
|
||||
) -> Result<Tensor> {
|
||||
if repeats == 1 {
|
||||
return Ok(img.clone());
|
||||
}
|
||||
let dim = dim.to_index(img.shape(), "chunk")?;
|
||||
let img = img.unsqueeze(dim + 1)?;
|
||||
let mut dims = img.dims().to_vec();
|
||||
dims[dim + 1] = repeats;
|
||||
img.broadcast_as(dims)?.flatten(dim, dim + 1)
|
||||
}
|
||||
|
||||
pub fn conv1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
|
||||
let weight_v = {
|
||||
let name = "parametrizations.weight.original1";
|
||||
match vb.get((out_c, in_c, kernel_size), name) {
|
||||
Ok(v) => v,
|
||||
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
|
||||
}
|
||||
};
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
let bias = vb.get(out_c, "bias")?;
|
||||
Ok(Conv1d::new(weight, Some(bias), config))
|
||||
}
|
||||
|
||||
pub fn conv1d_weight_norm_no_bias(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
config: candle_nn::Conv1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Conv1d> {
|
||||
let weight_g = vb.get((out_c, 1, 1), "parametrizations.weight.original0")?;
|
||||
let weight_v = {
|
||||
let name = "parametrizations.weight.original1";
|
||||
match vb.get((out_c, in_c, kernel_size), name) {
|
||||
Ok(v) => v,
|
||||
Err(_) => vb.get((out_c, 1, kernel_size), name)?,
|
||||
}
|
||||
};
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
Ok(Conv1d::new(weight, None, config))
|
||||
}
|
||||
|
||||
pub fn conv_transpose1d_weight_norm(
|
||||
in_c: usize,
|
||||
out_c: usize,
|
||||
kernel_size: usize,
|
||||
bias: bool,
|
||||
config: candle_nn::ConvTranspose1dConfig,
|
||||
vb: VarBuilder,
|
||||
) -> Result<ConvTranspose1d> {
|
||||
let weight_g = vb.get((in_c, 1, 1), "parametrizations.weight.original0")?;
|
||||
let weight_v = vb.get(
|
||||
(in_c, out_c, kernel_size),
|
||||
"parametrizations.weight.original1",
|
||||
)?;
|
||||
let norm_v = weight_v.sqr()?.sum_keepdim((1, 2))?.sqrt()?;
|
||||
let weight = weight_v.broadcast_mul(&weight_g)?.broadcast_div(&norm_v)?;
|
||||
let bias = if bias {
|
||||
Some(vb.get(out_c, "bias")?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(ConvTranspose1d::new(weight, bias, config))
|
||||
}
|
||||
|
||||
// https://github.com/hubertsiuzdak/snac/blob/main/snac/attention.py
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct SinusoidalEmbeddings {
|
||||
inv_freq: Tensor,
|
||||
scale: Tensor,
|
||||
scale_base: f32,
|
||||
use_xpos: bool,
|
||||
}
|
||||
|
||||
impl SinusoidalEmbeddings {
|
||||
fn new(dim: usize, scale_base: f32, use_xpos: bool, dev: &Device) -> Result<Self> {
|
||||
let inv_freq: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| 1f32 / 10_000f32.powf(i as f32 / dim as f32))
|
||||
.collect();
|
||||
let len = inv_freq.len();
|
||||
let inv_freq = Tensor::from_vec(inv_freq, len, dev)?.to_dtype(DType::F32)?;
|
||||
let scale: Vec<_> = (0..dim)
|
||||
.step_by(2)
|
||||
.map(|i| (i as f32 + 0.4 * dim as f32) / (1.4 * dim as f32))
|
||||
.collect();
|
||||
let scale = Tensor::from_vec(scale, len, dev)?.to_dtype(DType::F32)?;
|
||||
Ok(Self {
|
||||
inv_freq,
|
||||
scale,
|
||||
scale_base,
|
||||
use_xpos,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
#[derive(Debug, Clone)]
|
||||
struct LocalMHA {
|
||||
norm: LayerNorm,
|
||||
to_qkv: Linear,
|
||||
to_out: Linear,
|
||||
num_heads: usize,
|
||||
head_dim: usize,
|
||||
rel_pos: Option<SinusoidalEmbeddings>,
|
||||
}
|
||||
|
||||
impl LocalMHA {
|
||||
fn new(
|
||||
dim: usize,
|
||||
window_size: usize,
|
||||
dim_head: usize,
|
||||
use_rotary_pos_emb: bool,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
|
||||
let to_qkv = linear_b(dim, dim * 3, false, vb.pp("to_qkv"))?;
|
||||
let to_out = linear_b(dim, dim, false, vb.pp("to_out"))?;
|
||||
let rel_pos = if use_rotary_pos_emb {
|
||||
let rel_pos =
|
||||
SinusoidalEmbeddings::new(dim_head, window_size as f32 / 2.0, false, vb.device())?;
|
||||
Some(rel_pos)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
norm,
|
||||
to_qkv,
|
||||
to_out,
|
||||
rel_pos,
|
||||
num_heads: dim / dim_head,
|
||||
head_dim: dim_head,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for LocalMHA {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b, c, t) = xs.dims3()?;
|
||||
let residual = xs.clone();
|
||||
let xs = xs.transpose(1, 2)?.apply(&self.norm)?;
|
||||
let qkv = xs.apply(&self.to_qkv)?;
|
||||
let q = qkv.narrow(D::Minus1, 0, c)?;
|
||||
let k = qkv.narrow(D::Minus1, c, c)?;
|
||||
let v = qkv.narrow(D::Minus1, 2 * c, c)?;
|
||||
let q = q
|
||||
.reshape((b, t, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let k = k
|
||||
.reshape((b, t, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let v = v
|
||||
.reshape((b, t, self.num_heads, self.head_dim))?
|
||||
.transpose(1, 2)?
|
||||
.contiguous()?;
|
||||
let (q, k) = match self.rel_pos {
|
||||
Some(_) => todo!(),
|
||||
None => (q, k),
|
||||
};
|
||||
let out = {
|
||||
let scale = 1f64 / f64::sqrt(self.head_dim as f64);
|
||||
let attn_weights = (q.matmul(&k.transpose(2, 3)?)? * scale)?;
|
||||
// Non-causal attention
|
||||
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
|
||||
attn_weights.matmul(&v)?
|
||||
};
|
||||
let out = out
|
||||
.transpose(1, 2)?
|
||||
.reshape((b, t, self.num_heads * self.head_dim))?
|
||||
.apply(&self.to_out)?;
|
||||
out.transpose(1, 2)? + residual
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct Snake1d {
|
||||
alpha: Tensor,
|
||||
}
|
||||
|
||||
impl Snake1d {
|
||||
pub fn new(channels: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let alpha = vb.get((1, channels, 1), "alpha")?;
|
||||
Ok(Self { alpha })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for Snake1d {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let xs_shape = xs.shape();
|
||||
let xs = xs.flatten_from(2)?;
|
||||
let sin = self.alpha.broadcast_mul(&xs)?.sin()?;
|
||||
let sin = (&sin * &sin)?;
|
||||
(xs + (&self.alpha + 1e-9)?.recip()?.broadcast_mul(&sin)?)?.reshape(xs_shape)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct ResidualUnit {
|
||||
snake1: Snake1d,
|
||||
conv1: Conv1d,
|
||||
snake2: Snake1d,
|
||||
conv2: Conv1d,
|
||||
}
|
||||
|
||||
impl ResidualUnit {
|
||||
fn new(
|
||||
dim: usize,
|
||||
dilation: usize,
|
||||
kernel: usize,
|
||||
groups: usize,
|
||||
vb: VarBuilder,
|
||||
) -> Result<Self> {
|
||||
let pad = ((kernel - 1) * dilation) / 2;
|
||||
let vb = vb.pp("block");
|
||||
let snake1 = Snake1d::new(dim, vb.pp(0))?;
|
||||
let cfg1 = Conv1dConfig {
|
||||
dilation,
|
||||
padding: pad,
|
||||
groups,
|
||||
..Default::default()
|
||||
};
|
||||
let conv1 = conv1d_weight_norm(dim, dim, 7, cfg1, vb.pp(1))?;
|
||||
let snake2 = Snake1d::new(dim, vb.pp(2))?;
|
||||
let conv2 = conv1d_weight_norm(dim, dim, 1, Default::default(), vb.pp(3))?;
|
||||
Ok(Self {
|
||||
snake1,
|
||||
conv1,
|
||||
snake2,
|
||||
conv2,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for ResidualUnit {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let ys = xs
|
||||
.apply(&self.snake1)?
|
||||
.apply(&self.conv1)?
|
||||
.apply(&self.snake2)?
|
||||
.apply(&self.conv2)?;
|
||||
let pad = (xs.dim(D::Minus1)? - ys.dim(D::Minus1)?) / 2;
|
||||
if pad > 0 {
|
||||
&ys + xs.narrow(D::Minus1, pad, ys.dim(D::Minus1)?)
|
||||
} else {
|
||||
ys + xs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct NoiseBlock {
|
||||
linear: Conv1d,
|
||||
}
|
||||
|
||||
impl NoiseBlock {
|
||||
fn new(dim: usize, vb: VarBuilder) -> Result<Self> {
|
||||
let linear = conv1d_weight_norm_no_bias(dim, dim, 1, Default::default(), vb.pp("linear"))?;
|
||||
Ok(Self { linear })
|
||||
}
|
||||
}
|
||||
|
||||
impl Module for NoiseBlock {
|
||||
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
|
||||
let (b, _c, t) = xs.dims3()?;
|
||||
let noise = Tensor::randn(0f32, 1f32, (b, 1, t), xs.device())?;
|
||||
let h = xs.apply(&self.linear)?;
|
||||
let n = noise.broadcast_mul(&h)?;
|
||||
let xs = (xs + n)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
struct DecoderBlock {
    snake1: Snake1d,
    conv_tr1: ConvTranspose1d,
    noise: Option<NoiseBlock>,
    res1: ResidualUnit,
    res2: ResidualUnit,
    res3: ResidualUnit,
}

impl DecoderBlock {
    fn new(
        in_dim: usize,
        out_dim: usize,
        stride: usize,
        noise: bool,
        groups: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("block");
        let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
        let cfg = ConvTranspose1dConfig {
            stride,
            padding: stride.div_ceil(2),
            output_padding: stride % 2,
            ..Default::default()
        };
        let conv_tr1 =
            conv_transpose1d_weight_norm(in_dim, out_dim, 2 * stride, true, cfg, vb.pp(1))?;
        let (n, noise) = if noise {
            let noise = NoiseBlock::new(out_dim, vb.pp(2))?;
            (1, Some(noise))
        } else {
            (0, None)
        };
        let res1 = ResidualUnit::new(out_dim, 1, 7, groups, vb.pp(2 + n))?;
        let res2 = ResidualUnit::new(out_dim, 3, 7, groups, vb.pp(3 + n))?;
        let res3 = ResidualUnit::new(out_dim, 9, 7, groups, vb.pp(4 + n))?;
        Ok(Self {
            snake1,
            conv_tr1,
            noise,
            res1,
            res2,
            res3,
        })
    }
}

impl Module for DecoderBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.snake1)?
            .apply(&self.conv_tr1)?
            .apply(&self.noise.as_ref())?
            .apply(&self.res1)?
            .apply(&self.res2)?
            .apply(&self.res3)
    }
}

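/// Downsampling stage of the encoder, mirroring `DecoderBlock`: three residual
/// units with dilations 1, 3 and 9, then snake and a strided conv mapping
/// `in_dim` (defaults to `out_dim / 2`) to `out_dim` channels.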
#[derive(Debug, Clone)]
struct EncoderBlock {
    res1: ResidualUnit,
    res2: ResidualUnit,
    res3: ResidualUnit,
    snake1: Snake1d,
    conv1: Conv1d,
}

impl EncoderBlock {
    fn new(
        out_dim: usize,
        in_dim: Option<usize>,
        stride: usize,
        groups: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("block");
        let in_dim = in_dim.unwrap_or(out_dim / 2);
        let res1 = ResidualUnit::new(in_dim, 1, 7, groups, vb.pp(0))?;
        let res2 = ResidualUnit::new(in_dim, 3, 7, groups, vb.pp(1))?;
        let res3 = ResidualUnit::new(in_dim, 9, 7, groups, vb.pp(2))?;
        let snake1 = Snake1d::new(in_dim, vb.pp(3))?;
        let cfg1 = Conv1dConfig {
            stride,
            padding: stride.div_ceil(2),
            ..Default::default()
        };
        let conv1 = conv1d_weight_norm(in_dim, out_dim, 2 * stride, cfg1, vb.pp(4))?;
        Ok(Self {
            res1,
            res2,
            res3,
            snake1,
            conv1,
        })
    }
}

impl candle::Module for EncoderBlock {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.apply(&self.res1)?
            .apply(&self.res2)?
            .apply(&self.res3)?
            .apply(&self.snake1)?
            .apply(&self.conv1)
    }
}

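/// Full encoder: a 7-tap input conv, one `EncoderBlock` per stride (doubling
/// the channel count at every stage), an optional `LocalMHA` layer, and a
/// final 7-tap (optionally depthwise) conv.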
#[derive(Debug, Clone)]
pub struct Encoder {
    conv1: Conv1d,
    blocks: Vec<EncoderBlock>,
    local_mha: Option<LocalMHA>,
    conv2: Conv1d,
}

impl candle::Module for Encoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = xs.apply(&self.conv1)?;
        for block in self.blocks.iter() {
            xs = xs.apply(block)?
        }
        xs.apply(&self.conv2)
    }
}

impl Encoder {
    fn new(
        mut d_model: usize,
        strides: &[usize],
        depthwise: bool,
        attn_window_size: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("block");
        let mut idx = 0;
        let cfg1 = Conv1dConfig {
            padding: 3,
            ..Default::default()
        };
        let conv1 = conv1d_weight_norm(1, d_model, 7, cfg1, vb.pp(idx))?;
        idx += 1;
        let mut blocks = Vec::with_capacity(strides.len());
        for &stride in strides.iter() {
            d_model *= 2;
            let groups = if depthwise { d_model / 2 } else { 1 };
            let block = EncoderBlock::new(d_model, None, stride, groups, vb.pp(idx))?;
            idx += 1;
            blocks.push(block)
        }
        let local_mha = match attn_window_size {
            Some(w) => {
                let mha = LocalMHA::new(d_model, w, 64, true, vb.pp(idx))?;
                idx += 1;
                Some(mha)
            }
            None => None,
        };
        let groups = if depthwise { d_model } else { 1 };
        let cfg2 = Conv1dConfig {
            padding: 3,
            groups,
            ..Default::default()
        };
        let conv2 = conv1d_weight_norm(d_model, d_model, 7, cfg2, vb.pp(idx))?;
        idx += 1;
        Ok(Self {
            conv1,
            blocks,
            local_mha,
            conv2,
        })
    }
}

#[derive(Debug, Clone)]
enum ConvInit {
    Depthwise(Conv1d, Conv1d),
    Standard(Conv1d),
}

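/// Full decoder: an input conv (a depthwise 7-tap conv followed by a 1x1 conv,
/// or a single standard conv), an optional `LocalMHA`, one `DecoderBlock` per
/// upsampling rate (halving the channel count at every stage), then a final
/// snake and 7-tap conv down to `d_out` channels.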
#[derive(Debug, Clone)]
pub struct Decoder {
    conv1: ConvInit,
    local_mha: Option<LocalMHA>,
    blocks: Vec<DecoderBlock>,
    snake1: Snake1d,
    conv2: Conv1d,
}

impl Decoder {
    #[allow(clippy::too_many_arguments)]
    fn new(
        in_c: usize,
        mut channels: usize,
        rates: &[usize],
        noise: bool,
        depthwise: bool,
        attn_window_size: Option<usize>,
        d_out: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("model");
        let mut idx = 0;
        let pad3 = Conv1dConfig {
            padding: 3,
            ..Default::default()
        };
        let conv1 = if depthwise {
            let cfg1 = Conv1dConfig {
                padding: 3,
                groups: in_c,
                ..Default::default()
            };
            let conv1 = conv1d_weight_norm(in_c, in_c, 7, cfg1, vb.pp(idx))?;
            idx += 1;
            let conv2 = conv1d_weight_norm(in_c, channels, 1, Default::default(), vb.pp(idx))?;
            idx += 1;
            ConvInit::Depthwise(conv1, conv2)
        } else {
            let conv1 = conv1d_weight_norm(in_c, channels, 7, pad3, vb.pp(idx))?;
            idx += 1;
            ConvInit::Standard(conv1)
        };
        let mut blocks = Vec::with_capacity(rates.len());
        let local_mha = match attn_window_size {
            Some(w) => {
                let mha = LocalMHA::new(channels, w, 64, true, vb.pp(idx))?;
                idx += 1;
                Some(mha)
            }
            None => None,
        };
        for stride in rates.iter() {
            let groups = if depthwise { channels / 2 } else { 1 };
            let block =
                DecoderBlock::new(channels, channels / 2, *stride, noise, groups, vb.pp(idx))?;
            idx += 1;
            channels /= 2;
            blocks.push(block)
        }
        let snake1 = Snake1d::new(channels, vb.pp(idx))?;
        idx += 1;
        let conv2 = conv1d_weight_norm(channels, d_out, 7, pad3, vb.pp(idx))?;
        idx += 1;
        Ok(Self {
            conv1,
            local_mha,
            blocks,
            snake1,
            conv2,
        })
    }
}

impl candle::Module for Decoder {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        let mut xs = match &self.conv1 {
            ConvInit::Standard(c) => xs.apply(c)?,
            ConvInit::Depthwise(c1, c2) => xs.apply(c1)?.apply(c2)?,
        };
        for block in self.blocks.iter() {
            xs = xs.apply(block)?
        }
        xs.apply(&self.snake1)?.apply(&self.conv2)
    }
}

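/// L2-normalizes each row (along dim 1), so that the codebook lookup below
/// reduces to a cosine-similarity search.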
fn normalize(v: &Tensor) -> Result<Tensor> {
    v.broadcast_div(&v.sqr()?.sum_keepdim(1)?.sqrt()?)
}

// https://github.com/hubertsiuzdak/snac/blob/main/snac/vq.py
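/// A single vector-quantization stage: the input is optionally average-pooled
/// by `stride`, projected to the codebook dimension, matched against the
/// nearest codebook entry (on L2-normalized vectors), projected back, and
/// repeated along time to undo the pooling.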
#[allow(unused)]
#[derive(Clone, Debug)]
struct VectorQuantizer {
    in_proj: Conv1d,
    out_proj: Conv1d,
    codebook: candle_nn::Embedding,
    stride: usize,
}

impl VectorQuantizer {
    fn new(
        in_dim: usize,
        cb_size: usize,
        cb_dim: usize,
        stride: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let in_proj = conv1d_weight_norm(in_dim, cb_dim, 1, Default::default(), vb.pp("in_proj"))?;
        let out_proj =
            conv1d_weight_norm(cb_dim, in_dim, 1, Default::default(), vb.pp("out_proj"))?;
        let codebook = candle_nn::embedding(cb_size, cb_dim, vb.pp("codebook"))?;
        Ok(Self {
            in_proj,
            out_proj,
            codebook,
            stride,
        })
    }

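    // Nearest-neighbour lookup on L2-normalized latents and codebook entries.
    // The distance expanded below differs from ||e - c||^2 only by a per-row
    // constant (an extra ||e||^2 term), which leaves the argmin unchanged.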
    fn decode_latents(&self, latents: &Tensor) -> Result<(Tensor, Tensor)> {
        let (b, d, t) = latents.dims3()?;
        let encodings = latents.transpose(1, 2)?.reshape((b * t, d))?;
        let encodings = normalize(&encodings)?;
        let codebook = normalize(self.codebook.embeddings())?;
        let dist = (encodings
            .sqr()?
            .sum_keepdim(1)?
            .broadcast_sub(&encodings.matmul(&codebook.t()?)?)?
            * 2.0)?
            .broadcast_add(&codebook.sqr()?.sum_keepdim(1)?.t()?)?;
        let indices = dist.argmin(1)?.reshape((b, ()))?;
        let z_q = self.decode_code(&indices)?;
        Ok((z_q, indices))
    }

    fn encode(&self, z: &Tensor) -> Result<(Tensor, Tensor)> {
        let z = if self.stride > 1 {
            let (b, c, t) = z.dims3()?;
            z.reshape((b, c, 1, t))?
                .avg_pool2d((1, self.stride))?
                .squeeze(2)?
        } else {
            z.clone()
        };
        let z_e = z.apply(&self.in_proj)?;
        let (z_q, indices) = self.decode_latents(&z_e)?;
        let z_q = z_q.apply(&self.out_proj)?;
        let z_q = if self.stride > 1 {
            repeat_interleave(&z_q, self.stride, D::Minus1)?
        } else {
            z_q
        };
        Ok((z_q, indices))
    }

    fn embed_code(&self, embed_id: &Tensor) -> Result<Tensor> {
        embed_id.apply(&self.codebook)
    }

    fn decode_code(&self, embed_id: &Tensor) -> Result<Tensor> {
        self.embed_code(embed_id)?.transpose(1, 2)
    }
}

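/// Residual vector quantization: each stage quantizes the residual left over
/// by the previous stages, and `from_codes` reconstructs the latent by summing
/// the per-stage decodings.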
#[derive(Clone, Debug)]
pub struct ResidualVectorQuantizer {
    quantizers: Vec<VectorQuantizer>,
}

impl ResidualVectorQuantizer {
    fn new(
        input_dim: usize,
        cb_size: usize,
        cb_dim: usize,
        vq_strides: &[usize],
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = &vb.pp("quantizers");
        let quantizers = vq_strides
            .iter()
            .enumerate()
            .map(|(i, stride)| VectorQuantizer::new(input_dim, cb_size, cb_dim, *stride, vb.pp(i)))
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { quantizers })
    }

    fn encode(&self, z: &Tensor) -> Result<(Tensor, Vec<Tensor>)> {
        let mut residual = z.clone();
        let mut z_q = z.zeros_like()?;
        let mut codes = Vec::with_capacity(self.quantizers.len());
        for quantizer in self.quantizers.iter() {
            let (z_q_i, indices_i) = quantizer.encode(&residual)?;
            z_q = (z_q + &z_q_i)?;
            residual = (residual - &z_q_i)?;
            codes.push(indices_i)
        }
        Ok((z_q, codes))
    }

    #[allow(clippy::wrong_self_convention)]
    fn from_codes(&self, codes: &[&Tensor]) -> Result<Tensor> {
        let mut sum = None;
        for (quantizer, codes) in self.quantizers.iter().zip(codes.iter()) {
            let z_p_i = quantizer.decode_code(codes)?;
            let z_q_i = z_p_i.apply(&quantizer.out_proj)?;
            let z_q_i = repeat_interleave(&z_q_i, quantizer.stride, D::Minus1)?;
            let s = match sum {
                None => z_q_i,
                Some(s) => (s + z_q_i)?,
            };
            sum = Some(s)
        }
        match sum {
            Some(s) => Ok(s),
            None => candle::bail!("empty codebooks"),
        }
    }
}

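// Euclid's algorithm; `lcm` is used by `Model::preprocess` to work out the
// padding multiple implied by the quantizer strides and the attention window.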
fn gcd(mut a: usize, mut b: usize) -> usize {
    while b != 0 {
        let t = b;
        b = a % b;
        a = t;
    }
    a
}

fn lcm(a: usize, b: usize) -> usize {
    a / gcd(a, b) * b
}

// https://github.com/hubertsiuzdak/snac/blob/main/snac/snac.py
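/// The full SNAC codec: `Encoder` -> `ResidualVectorQuantizer` -> `Decoder`.
/// `hop_length` (the product of the encoder strides) is the number of audio
/// samples covered by one latent frame.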
#[derive(Debug, Clone)]
pub struct Model {
    pub encoder: Encoder,
    pub quantizer: ResidualVectorQuantizer,
    pub decoder: Decoder,
    pub hop_length: usize,
    pub config: Config,
}

impl Model {
    pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
        let encoder = Encoder::new(
            cfg.encoder_dim,
            &cfg.encoder_rates,
            cfg.depthwise,
            cfg.attn_window_size,
            vb.pp("encoder"),
        )?;
        let latent_dim = cfg.encoder_dim * 2usize.pow(cfg.encoder_rates.len() as u32);
        let quantizer = ResidualVectorQuantizer::new(
            latent_dim,
            cfg.codebook_size,
            cfg.codebook_dim,
            &cfg.vq_strides,
            vb.pp("quantizer"),
        )?;
        let decoder = Decoder::new(
            latent_dim,
            cfg.decoder_dim,
            &cfg.decoder_rates,
            cfg.noise,
            cfg.depthwise,
            cfg.attn_window_size,
            /* d_out */ 1,
            vb.pp("decoder"),
        )?;
        let hop_length = cfg.encoder_rates.iter().product::<usize>();
        Ok(Self {
            encoder,
            decoder,
            quantizer,
            config: cfg.clone(),
            hop_length,
        })
    }

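    // Right-pad the waveform with zeros so its length is a multiple of
    // `hop_length * lcm(vq_strides[0], attn_window_size)`, i.e. so that every
    // downsampling stage and the windowed attention see whole frames.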
    fn preprocess(&self, audio_data: &Tensor) -> Result<Tensor> {
        let len = audio_data.dim(D::Minus1)?;
        let lcm = lcm(
            self.config.vq_strides[0],
            self.config.attn_window_size.unwrap_or(1),
        );
        let pad_to = self.hop_length * lcm;
        let right_pad = len.div_ceil(pad_to) * pad_to - len;
        let audio_data = audio_data.pad_with_zeros(D::Minus1, 0, right_pad)?;
        Ok(audio_data)
    }

    pub fn encode(&self, audio_data: &Tensor) -> Result<Vec<Tensor>> {
        let audio_data = self.preprocess(audio_data)?;
        let z = self.encoder.forward(&audio_data)?;
        let (_, codes) = self.quantizer.encode(&z)?;
        Ok(codes)
    }

    pub fn decode(&self, audio_codes: &[&Tensor]) -> Result<Tensor> {
        let audio_values = self.quantizer.from_codes(audio_codes)?;
        audio_values.apply(&self.decoder)
    }

    pub fn config(&self) -> &Config {
        &self.config
    }

    pub fn num_codebooks(&self) -> usize {
        self.quantizer.quantizers.len()
    }
}
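
// A rough usage sketch (not part of the model definition): assuming the
// weights live in a safetensors file matching `Config`, and with `config`,
// `device` and an `audio` tensor of shape (batch, 1, samples) as placeholders,
// an encode/decode round trip looks roughly like:
//
//     let vb = unsafe {
//         VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::F32, &device)?
//     };
//     let model = Model::new(&config, vb)?;
//     let codes = model.encode(&audio)?;
//     let codes = codes.iter().collect::<Vec<_>>();
//     let audio_hat = model.decode(&codes)?;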