mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Use the same padding in metavoice as in the python version. (#1794)
This commit is contained in:
@ -129,7 +129,7 @@ fn main() -> Result<()> {
|
|||||||
VarBuilder::from_mmaped_safetensors(&[second_stage_weights], DType::F32, &device)?
|
VarBuilder::from_mmaped_safetensors(&[second_stage_weights], DType::F32, &device)?
|
||||||
};
|
};
|
||||||
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
||||||
let second_stage_model = gpt::Model::new(second_stage_config, second_stage_vb)?;
|
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
|
||||||
|
|
||||||
let encodec_device = if device.is_metal() {
|
let encodec_device = if device.is_metal() {
|
||||||
&candle::Device::Cpu
|
&candle::Device::Cpu
|
||||||
@ -182,13 +182,16 @@ fn main() -> Result<()> {
|
|||||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
|
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
|
||||||
// TODO: Use the config rather than hardcoding the offset here.
|
// TODO: Use the config rather than hardcoding the offset here.
|
||||||
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
|
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
|
||||||
let hierarchies_in1 = [encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
|
let mut hierarchies_in1 =
|
||||||
let hierarchies_in2 = [
|
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
|
||||||
|
let mut hierarchies_in2 = [
|
||||||
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
|
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
|
||||||
ids2.as_slice(),
|
ids2.as_slice(),
|
||||||
&[ENCODEC_NTOKENS],
|
&[ENCODEC_NTOKENS],
|
||||||
]
|
]
|
||||||
.concat();
|
.concat();
|
||||||
|
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||||
|
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||||
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
|
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
|
||||||
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
||||||
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
||||||
|
Reference in New Issue
Block a user