mirror of
https://github.com/huggingface/candle.git
synced 2025-06-17 11:08:52 +00:00
Compare commits
2 Commits
fix-1.86
...
0.9.0-alph
Author | SHA1 | Date | |
---|---|---|---|
cf9d7bf24c | |||
9d31361c4f |
@ -816,7 +816,7 @@ impl PthTensors {
|
|||||||
/// # Arguments
|
/// # Arguments
|
||||||
/// * `path` - Path to the pth file.
|
/// * `path` - Path to the pth file.
|
||||||
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
/// * `key` - Optional key to retrieve `state_dict` from the pth file. Sometimes the pth file
|
||||||
/// contains multiple objects and the state_dict is the one we are interested in.
|
/// contains multiple objects and the state_dict is the one we are interested in.
|
||||||
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
pub fn read_all_with_key<P: AsRef<std::path::Path>>(
|
||||||
path: P,
|
path: P,
|
||||||
key: Option<&str>,
|
key: Option<&str>,
|
||||||
|
14
candle-examples/examples/csm/README.md
Normal file
14
candle-examples/examples/csm/README.md
Normal file
@ -0,0 +1,14 @@
|
|||||||
|
# Conversational Speech Model (CSM)
|
||||||
|
|
||||||
|
CSM is a speech generation model from Sesame,
|
||||||
|
[SesameAILabs/csm](https://github.com/SesameAILabs/csm).
|
||||||
|
|
||||||
|
It can generate conversational speech between two different speakers.
|
||||||
|
Speaker turns are delimited by the `|` character in the prompt.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo run --example csm --features cuda -r -- \
|
||||||
|
--voices voices.safetensors \
|
||||||
|
--prompt "Hey how are you doing?|Pretty good, pretty good. How about you?"
|
||||||
|
```
|
||||||
|
|
@ -34,9 +34,18 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
use_flash_attn: bool,
|
use_flash_attn: bool,
|
||||||
|
|
||||||
#[arg(long, default_value = "[0]Hey how are you doing?")]
|
/// The prompt to be used for the generation, use a | to separate the speakers.
|
||||||
|
#[arg(long, default_value = "Hey how are you doing today?")]
|
||||||
prompt: String,
|
prompt: String,
|
||||||
|
|
||||||
|
/// The voices to be used, in safetensors format.
|
||||||
|
#[arg(long)]
|
||||||
|
voices: String,
|
||||||
|
|
||||||
|
/// The output file using the wav format.
|
||||||
|
#[arg(long, default_value = "out.wav")]
|
||||||
|
out_file: String,
|
||||||
|
|
||||||
/// The temperature used to generate samples.
|
/// The temperature used to generate samples.
|
||||||
#[arg(long, default_value_t = 0.7)]
|
#[arg(long, default_value_t = 0.7)]
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
@ -162,7 +171,7 @@ fn main() -> Result<()> {
|
|||||||
};
|
};
|
||||||
let device = candle_examples::device(args.cpu)?;
|
let device = candle_examples::device(args.cpu)?;
|
||||||
let (mut model, device) = {
|
let (mut model, device) = {
|
||||||
let dtype = DType::F32;
|
let dtype = device.bf16_default_to_f32();
|
||||||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
|
||||||
let model = Model::new(&config, vb)?;
|
let model = Model::new(&config, vb)?;
|
||||||
(model, device)
|
(model, device)
|
||||||
@ -177,45 +186,58 @@ fn main() -> Result<()> {
|
|||||||
let cb = config.audio_num_codebooks;
|
let cb = config.audio_num_codebooks;
|
||||||
|
|
||||||
println!("loaded the model in {:?}", start.elapsed());
|
println!("loaded the model in {:?}", start.elapsed());
|
||||||
if args.prompt.ends_with(".safetensors") {
|
|
||||||
let prompt = candle::safetensors::load(args.prompt, &device)?;
|
let voices = candle::safetensors::load(args.voices, &device)?;
|
||||||
let mut tokens = prompt
|
let mut lp = candle_transformers::generation::LogitsProcessor::new(
|
||||||
.get("tokens")
|
args.seed,
|
||||||
.expect("no tokens in prompt")
|
Some(args.temperature),
|
||||||
.to_dtype(DType::U32)?;
|
None,
|
||||||
let mut mask = prompt.get("mask").expect("no mask in prompt").clone();
|
);
|
||||||
println!("tokens:\n{tokens:?}");
|
let tokens = voices
|
||||||
println!("mask:\n{mask:?}");
|
.get("tokens")
|
||||||
let mut lp = candle_transformers::generation::LogitsProcessor::new(42, None, None);
|
.expect("no tokens in prompt")
|
||||||
let mut const_mask = vec![1u8; cb];
|
.to_dtype(DType::U32)?;
|
||||||
const_mask.push(0);
|
let mask = voices.get("mask").expect("no mask in prompt").clone();
|
||||||
let const_mask = Tensor::from_vec(const_mask, (1, 1, cb + 1), &device)?;
|
|
||||||
let mut pos = 0;
|
let mut pos = 0;
|
||||||
let mut all_tokens = vec![];
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
for i in 0.. {
|
pos += tokens.dim(1)?;
|
||||||
let mut frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
|
||||||
|
let mut all_pcms = vec![];
|
||||||
|
for (turn_idx, prompt) in args.prompt.split('|').enumerate() {
|
||||||
|
println!("{prompt:?}");
|
||||||
|
let speaker_idx = turn_idx % 2;
|
||||||
|
let prompt = format!("[{speaker_idx}]{}<|end_of_text|>", prompt);
|
||||||
|
let prompt = tokenizer.encode(prompt, true).map_err(E::msg)?;
|
||||||
|
|
||||||
|
let (mut tokens, mut mask) = model.text_tokens_and_mask(prompt.get_ids())?;
|
||||||
|
|
||||||
|
let mut generated_tokens = vec![];
|
||||||
|
loop {
|
||||||
|
let frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
pos += tokens.dim(1)?;
|
pos += tokens.dim(1)?;
|
||||||
frame.push(0);
|
let is_done = frame.iter().all(|&x| x == 0);
|
||||||
if frame.iter().all(|&x| x == 0) {
|
(tokens, mask) = model.audio_tokens_and_mask(frame)?;
|
||||||
|
print!("\rframe {pos}");
|
||||||
|
if is_done {
|
||||||
|
let _frame = model.generate_frame(&tokens, &mask, pos, &mut lp)?;
|
||||||
|
pos += tokens.dim(1)?;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
println!("frame {i} {pos}:\n{frame:?}");
|
generated_tokens.push(tokens.clone());
|
||||||
tokens = Tensor::from_vec(frame, (1, 1, cb + 1), &device)?;
|
|
||||||
all_tokens.push(tokens.clone());
|
|
||||||
mask = const_mask.clone();
|
|
||||||
}
|
}
|
||||||
let all_tokens = Tensor::cat(&all_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
println!();
|
||||||
println!("all_tokens:\n{all_tokens:?}");
|
let generated_tokens = Tensor::cat(&generated_tokens, 1)?.narrow(2, 0, cb)?.t()?;
|
||||||
let pcm = mimi_model.decode(&all_tokens)?;
|
let pcm = mimi_model.decode(&generated_tokens)?;
|
||||||
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
let pcm = pcm.i(0)?.i(0)?.to_dtype(DType::F32)?;
|
||||||
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
let pcm = candle_examples::audio::normalize_loudness(&pcm, 24_000, true)?;
|
||||||
let pcm = pcm.to_vec1::<f32>()?;
|
all_pcms.push(pcm);
|
||||||
let mut output = std::fs::File::create("out.wav")?;
|
|
||||||
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
|
||||||
} else {
|
|
||||||
let prompt = tokenizer.encode(args.prompt, true).map_err(E::msg)?;
|
|
||||||
println!("{prompt:?}");
|
|
||||||
}
|
}
|
||||||
|
let pcm = Tensor::cat(&all_pcms, 0)?;
|
||||||
|
let pcm = pcm.to_vec1::<f32>()?;
|
||||||
|
println!("writing output file {}", args.out_file);
|
||||||
|
let mut output = std::fs::File::create(args.out_file)?;
|
||||||
|
candle_examples::wav::write_pcm_as_wav(&mut output, &pcm, 24_000)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -21,7 +21,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_conv(&self) -> usize {
|
fn d_conv(&self) -> usize {
|
||||||
|
@ -498,4 +498,36 @@ impl Model {
|
|||||||
}
|
}
|
||||||
Ok(all_samples)
|
Ok(all_samples)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Builds the `(tokens, mask)` input pair for a single generated audio frame.
///
/// A trailing `0` is appended to `frame`, and the result is packed into a tensor of shape
/// `(1, 1, cb + 1)`, where `cb` is `config.audio_num_codebooks`. The returned mask has the
/// same shape and holds `1u8` in the first `cb` slots and `0` in the last slot.
///
/// Both tensors are created on the backbone's device.
///
/// NOTE(review): `frame` is presumably expected to contain exactly `cb` values so that it
/// matches the `(1, 1, cb + 1)` shape after the push — confirm against `generate_frame`.
pub fn audio_tokens_and_mask(&self, mut frame: Vec<u32>) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut mask = vec![1u8; cb];
|
||||||
|
mask.push(0);
|
||||||
|
let mask = Tensor::from_vec(mask, (1, 1, cb + 1), device)?;
|
||||||
|
|
||||||
|
frame.push(0);
|
||||||
|
let tokens = Tensor::from_vec(frame, (1, 1, cb + 1), device)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Builds the `(tokens, mask)` input pair for a sequence of text token ids.
///
/// Each id in `ids` becomes one row of width `cb + 1`, where `cb` is
/// `config.audio_num_codebooks`: the first `cb` slots are `0` and the last slot holds the id.
/// The matching mask row is `0u8` in the first `cb` slots and `1` in the last slot.
/// Rows are concatenated along dimension 1, so both results have shape
/// `(1, ids.len(), cb + 1)`.
///
/// Both tensors are created on the backbone's device.
///
/// NOTE(review): with an empty `ids` slice `Tensor::cat` receives an empty list — presumably
/// this returns an error; confirm before calling with empty input.
pub fn text_tokens_and_mask(&self, ids: &[u32]) -> Result<(Tensor, Tensor)> {
|
||||||
|
let cb = self.config.audio_num_codebooks;
|
||||||
|
let device = &self.backbone.device;
|
||||||
|
let mut tokens = vec![];
|
||||||
|
let mut mask = vec![];
|
||||||
|
for &v in ids.iter() {
|
||||||
|
let mut token = vec![0; cb];
|
||||||
|
token.push(v);
|
||||||
|
let token = Tensor::from_vec(token, (1, 1, cb + 1), device)?;
|
||||||
|
tokens.push(token);
|
||||||
|
let mut m = vec![0u8; cb];
|
||||||
|
m.push(1);
|
||||||
|
let m = Tensor::from_vec(m, (1, 1, cb + 1), device)?;
|
||||||
|
mask.push(m);
|
||||||
|
}
|
||||||
|
let tokens = Tensor::cat(&tokens, 1)?;
|
||||||
|
let mask = Tensor::cat(&mask, 1)?;
|
||||||
|
Ok((tokens, mask))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,7 +104,7 @@ impl EncoderBlock {
|
|||||||
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
let snake1 = Snake1d::new(dim / 2, vb.pp(3))?;
|
||||||
let cfg1 = Conv1dConfig {
|
let cfg1 = Conv1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
let conv1 = encodec::conv1d_weight_norm(dim / 2, dim, 2 * stride, cfg1, vb.pp(4))?;
|
||||||
@ -196,7 +196,7 @@ impl DecoderBlock {
|
|||||||
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
let snake1 = Snake1d::new(in_dim, vb.pp(0))?;
|
||||||
let cfg = ConvTranspose1dConfig {
|
let cfg = ConvTranspose1dConfig {
|
||||||
stride,
|
stride,
|
||||||
padding: (stride + 1) / 2,
|
padding: stride.div_ceil(2),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
let conv_tr1 = encodec::conv_transpose1d_weight_norm(
|
||||||
|
@ -6,8 +6,8 @@ pub fn get_noise(
|
|||||||
width: usize,
|
width: usize,
|
||||||
device: &Device,
|
device: &Device,
|
||||||
) -> Result<Tensor> {
|
) -> Result<Tensor> {
|
||||||
let height = (height + 15) / 16 * 2;
|
let height = height.div_ceil(16) * 2;
|
||||||
let width = (width + 15) / 16 * 2;
|
let width = width.div_ceil(16) * 2;
|
||||||
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
Tensor::randn(0f32, 1., (num_samples, 16, height, width), device)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -84,8 +84,8 @@ pub fn get_schedule(num_steps: usize, shift: Option<(usize, f64, f64)>) -> Vec<f
|
|||||||
|
|
||||||
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
pub fn unpack(xs: &Tensor, height: usize, width: usize) -> Result<Tensor> {
|
||||||
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
let (b, _h_w, c_ph_pw) = xs.dims3()?;
|
||||||
let height = (height + 15) / 16;
|
let height = height.div_ceil(16);
|
||||||
let width = (width + 15) / 16;
|
let width = width.div_ceil(16);
|
||||||
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
xs.reshape((b, height, width, c_ph_pw / 4, 2, 2))? // (b, h, w, c, ph, pw)
|
||||||
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
.permute((0, 3, 1, 4, 2, 5))? // (b, c, h, ph, w, pw)
|
||||||
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
.reshape((b, c_ph_pw / 4, height * 2, width * 2))
|
||||||
|
@ -27,7 +27,7 @@ impl Config {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn dt_rank(&self) -> usize {
|
fn dt_rank(&self) -> usize {
|
||||||
(self.d_model + 15) / 16
|
self.d_model.div_ceil(16)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn d_inner(&self) -> usize {
|
fn d_inner(&self) -> usize {
|
||||||
|
@ -716,7 +716,7 @@ pub mod transformer {
|
|||||||
None => {
|
None => {
|
||||||
let hidden_dim = self.dim * 4;
|
let hidden_dim = self.dim * 4;
|
||||||
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
let n_hidden = ((2 * hidden_dim) as f64 / 3.) as usize;
|
||||||
(n_hidden + 255) / 256 * 256
|
n_hidden.div_ceil(256) * 256
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -198,7 +198,7 @@ pub fn log_mel_spectrogram_<T: Float>(
|
|||||||
let samples = {
|
let samples = {
|
||||||
let mut samples_padded = samples.to_vec();
|
let mut samples_padded = samples.to_vec();
|
||||||
let to_add = n_len * fft_step - samples.len();
|
let to_add = n_len * fft_step - samples.len();
|
||||||
samples_padded.extend(std::iter::repeat(zero).take(to_add));
|
samples_padded.extend(std::iter::repeat_n(zero, to_add));
|
||||||
samples_padded
|
samples_padded
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -177,7 +177,7 @@ fn log_mel_spectrogram_<T: Float + std::fmt::Display>(
|
|||||||
let samples = {
|
let samples = {
|
||||||
let mut samples_padded = samples.to_vec();
|
let mut samples_padded = samples.to_vec();
|
||||||
let to_add = n_len * fft_step - samples.len();
|
let to_add = n_len * fft_step - samples.len();
|
||||||
samples_padded.extend(std::iter::repeat(zero).take(to_add));
|
samples_padded.extend(std::iter::repeat_n(zero, to_add));
|
||||||
samples_padded
|
samples_padded
|
||||||
};
|
};
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user