mirror of
https://github.com/huggingface/candle.git
synced 2025-06-18 03:28:50 +00:00
Compare commits
5 Commits
0.9.0-alph
...
spkemb
Author | SHA1 | Date | |
---|---|---|---|
9dc53ec8ad | |||
577316bc4e | |||
b5ee026cea | |||
52ed77c16f | |||
dae32d13d6 |
@ -100,4 +100,6 @@ required-features = ["candle-datasets"]
|
|||||||
name = "encodec"
|
name = "encodec"
|
||||||
required-features = ["symphonia"]
|
required-features = ["symphonia"]
|
||||||
|
|
||||||
|
[[example]]
|
||||||
|
name = "metavoice"
|
||||||
|
required-features = ["symphonia"]
|
||||||
|
@ -10,7 +10,9 @@ use std::io::Write;
|
|||||||
|
|
||||||
use candle_transformers::generation::LogitsProcessor;
|
use candle_transformers::generation::LogitsProcessor;
|
||||||
use candle_transformers::models::encodec;
|
use candle_transformers::models::encodec;
|
||||||
use candle_transformers::models::metavoice::{adapters, gpt, tokenizers, transformer};
|
use candle_transformers::models::metavoice::{
|
||||||
|
adapters, gpt, speaker_encoder, tokenizers, transformer,
|
||||||
|
};
|
||||||
|
|
||||||
use candle::{DType, IndexOp, Tensor};
|
use candle::{DType, IndexOp, Tensor};
|
||||||
use candle_nn::VarBuilder;
|
use candle_nn::VarBuilder;
|
||||||
@ -19,6 +21,60 @@ use rand::{distributions::Distribution, SeedableRng};
|
|||||||
|
|
||||||
pub const ENCODEC_NTOKENS: u32 = 1024;
|
pub const ENCODEC_NTOKENS: u32 = 1024;
|
||||||
|
|
||||||
|
fn conv<T>(samples: &mut Vec<f32>, data: std::borrow::Cow<symphonia::core::audio::AudioBuffer<T>>)
|
||||||
|
where
|
||||||
|
T: symphonia::core::sample::Sample,
|
||||||
|
f32: symphonia::core::conv::FromSample<T>,
|
||||||
|
{
|
||||||
|
use symphonia::core::audio::Signal;
|
||||||
|
use symphonia::core::conv::FromSample;
|
||||||
|
samples.extend(data.chan(0).iter().map(|v| f32::from_sample(*v)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pcm_decode<P: AsRef<std::path::Path>>(path: P) -> anyhow::Result<(Vec<f32>, u32)> {
|
||||||
|
use symphonia::core::audio::{AudioBufferRef, Signal};
|
||||||
|
|
||||||
|
let src = std::fs::File::open(path)?;
|
||||||
|
let mss = symphonia::core::io::MediaSourceStream::new(Box::new(src), Default::default());
|
||||||
|
let hint = symphonia::core::probe::Hint::new();
|
||||||
|
let meta_opts: symphonia::core::meta::MetadataOptions = Default::default();
|
||||||
|
let fmt_opts: symphonia::core::formats::FormatOptions = Default::default();
|
||||||
|
let probed = symphonia::default::get_probe().format(&hint, mss, &fmt_opts, &meta_opts)?;
|
||||||
|
let mut format = probed.format;
|
||||||
|
let track = format
|
||||||
|
.tracks()
|
||||||
|
.iter()
|
||||||
|
.find(|t| t.codec_params.codec != symphonia::core::codecs::CODEC_TYPE_NULL)
|
||||||
|
.expect("no supported audio tracks");
|
||||||
|
let mut decoder = symphonia::default::get_codecs()
|
||||||
|
.make(&track.codec_params, &Default::default())
|
||||||
|
.expect("unsupported codec");
|
||||||
|
let track_id = track.id;
|
||||||
|
let sample_rate = track.codec_params.sample_rate.unwrap_or(0);
|
||||||
|
let mut pcm_data = Vec::new();
|
||||||
|
while let Ok(packet) = format.next_packet() {
|
||||||
|
while !format.metadata().is_latest() {
|
||||||
|
format.metadata().pop();
|
||||||
|
}
|
||||||
|
if packet.track_id() != track_id {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
match decoder.decode(&packet)? {
|
||||||
|
AudioBufferRef::F32(buf) => pcm_data.extend(buf.chan(0)),
|
||||||
|
AudioBufferRef::U8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::U32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S8(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S16(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S24(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::S32(data) => conv(&mut pcm_data, data),
|
||||||
|
AudioBufferRef::F64(data) => conv(&mut pcm_data, data),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok((pcm_data, sample_rate))
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)]
|
||||||
enum ArgDType {
|
enum ArgDType {
|
||||||
F32,
|
F32,
|
||||||
@ -69,9 +125,14 @@ struct Args {
|
|||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
second_stage_weights: Option<String>,
|
second_stage_weights: Option<String>,
|
||||||
|
|
||||||
|
#[arg(long)]
|
||||||
|
speaker_encoder_weights: Option<String>,
|
||||||
|
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
encodec_weights: Option<String>,
|
encodec_weights: Option<String>,
|
||||||
|
|
||||||
|
/// The speaker embeddings, either an audio file in which case they are extracted, or a
|
||||||
|
/// safetensors file with the embeddings already extracted.
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
spk_emb: Option<String>,
|
spk_emb: Option<String>,
|
||||||
|
|
||||||
@ -79,6 +140,13 @@ struct Args {
|
|||||||
dtype: ArgDType,
|
dtype: ArgDType,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn mel_filters() -> Result<Vec<f32>> {
|
||||||
|
let mel_bytes = include_bytes!("melfilters40.bytes").as_slice();
|
||||||
|
let mut mel_filters = vec![0f32; mel_bytes.len() / 4];
|
||||||
|
<byteorder::LittleEndian as byteorder::ByteOrder>::read_f32_into(mel_bytes, &mut mel_filters);
|
||||||
|
Ok(mel_filters)
|
||||||
|
}
|
||||||
|
|
||||||
fn main() -> Result<()> {
|
fn main() -> Result<()> {
|
||||||
use tracing_chrome::ChromeLayerBuilder;
|
use tracing_chrome::ChromeLayerBuilder;
|
||||||
use tracing_subscriber::prelude::*;
|
use tracing_subscriber::prelude::*;
|
||||||
@ -120,7 +188,7 @@ fn main() -> Result<()> {
|
|||||||
Some(w) => std::path::PathBuf::from(w),
|
Some(w) => std::path::PathBuf::from(w),
|
||||||
None => repo.get("first_stage.safetensors")?,
|
None => repo.get("first_stage.safetensors")?,
|
||||||
};
|
};
|
||||||
let second_stage_weights = match &args.first_stage_weights {
|
let second_stage_weights = match &args.second_stage_weights {
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
Some(w) => std::path::PathBuf::from(w),
|
||||||
None => repo.get("second_stage.safetensors")?,
|
None => repo.get("second_stage.safetensors")?,
|
||||||
};
|
};
|
||||||
@ -159,16 +227,41 @@ fn main() -> Result<()> {
|
|||||||
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
|
let prompt_tokens = fs_tokenizer.encode(&args.prompt)?;
|
||||||
let mut tokens = prompt_tokens.clone();
|
let mut tokens = prompt_tokens.clone();
|
||||||
println!("{tokens:?}");
|
println!("{tokens:?}");
|
||||||
let spk_emb_file = match &args.spk_emb {
|
let safetensors_embeddings = args
|
||||||
Some(w) => std::path::PathBuf::from(w),
|
.spk_emb
|
||||||
None => repo.get("spk_emb.safetensors")?,
|
.as_ref()
|
||||||
|
.map_or(true, |v| v.ends_with("safetensors"));
|
||||||
|
let spk_emb = if safetensors_embeddings {
|
||||||
|
let spk_emb_file = match &args.spk_emb {
|
||||||
|
Some(w) => std::path::PathBuf::from(w),
|
||||||
|
None => repo.get("spk_emb.safetensors")?,
|
||||||
|
};
|
||||||
|
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
||||||
|
match spk_emb.get("spk_emb") {
|
||||||
|
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
||||||
|
Some(spk_emb) => spk_emb.to_dtype(dtype)?.to_device(&device)?,
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let weights = match &args.speaker_encoder_weights {
|
||||||
|
Some(w) => std::path::PathBuf::from(w),
|
||||||
|
None => repo.get("speaker_encoder.safetensors")?,
|
||||||
|
};
|
||||||
|
let mel_filters = mel_filters()?;
|
||||||
|
let config = speaker_encoder::Config::cfg();
|
||||||
|
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[weights], dtype, &device)? };
|
||||||
|
let model = speaker_encoder::Model::new(config, vb)?;
|
||||||
|
let (pcm, sample_rate) = pcm_decode(&args.spk_emb.unwrap())?;
|
||||||
|
if sample_rate != 16_000 {
|
||||||
|
eprintln!("WARNING: speaker embedding input should use a 16kHz sample rate!")
|
||||||
|
}
|
||||||
|
model.embed_utterance(
|
||||||
|
&pcm,
|
||||||
|
&mel_filters,
|
||||||
|
/* rate */ 1.3,
|
||||||
|
/* min_c */ 0.75,
|
||||||
|
&device,
|
||||||
|
)?
|
||||||
};
|
};
|
||||||
let spk_emb = candle::safetensors::load(&spk_emb_file, &candle::Device::Cpu)?;
|
|
||||||
let spk_emb = match spk_emb.get("spk_emb") {
|
|
||||||
None => anyhow::bail!("missing spk_emb tensor in {spk_emb_file:?}"),
|
|
||||||
Some(spk_emb) => spk_emb.to_dtype(dtype)?,
|
|
||||||
};
|
|
||||||
let spk_emb = spk_emb.to_device(&device)?;
|
|
||||||
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
|
let mut logits_processor = LogitsProcessor::new(args.seed, Some(args.temperature), Some(0.95));
|
||||||
|
|
||||||
// First stage generation.
|
// First stage generation.
|
||||||
|
BIN
candle-examples/examples/metavoice/melfilters40.bytes
Normal file
BIN
candle-examples/examples/metavoice/melfilters40.bytes
Normal file
Binary file not shown.
@ -55,12 +55,12 @@ pub mod speaker_encoder {
|
|||||||
layer_idx,
|
layer_idx,
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
let lstm = candle_nn::lstm(
|
let in_c = if layer_idx == 0 {
|
||||||
cfg.mel_n_channels,
|
cfg.mel_n_channels
|
||||||
cfg.model_hidden_size,
|
} else {
|
||||||
c,
|
cfg.model_hidden_size
|
||||||
vb_l.pp(layer_idx),
|
};
|
||||||
)?;
|
let lstm = candle_nn::lstm(in_c, cfg.model_hidden_size, c, vb_l.clone())?;
|
||||||
lstms.push(lstm)
|
lstms.push(lstm)
|
||||||
}
|
}
|
||||||
let linear = linear_b(
|
let linear = linear_b(
|
||||||
@ -143,7 +143,7 @@ pub mod speaker_encoder {
|
|||||||
.iter()
|
.iter()
|
||||||
.flat_map(|s| [mel[s.0], mel[s.1]])
|
.flat_map(|s| [mel[s.0], mel[s.1]])
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let mels = Tensor::from_vec(mels, (mel_slices.len(), 2), device)?;
|
let mels = Tensor::from_vec(mels, (1, mel_slices.len(), 2), device)?;
|
||||||
let partial_embeds = self.forward(&mels)?;
|
let partial_embeds = self.forward(&mels)?;
|
||||||
let raw_embed = partial_embeds.mean(0)?;
|
let raw_embed = partial_embeds.mean(0)?;
|
||||||
let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
|
let norm = raw_embed.sqr()?.sum_all()?.sqrt()?;
|
||||||
|
Reference in New Issue
Block a user