mirror of
https://github.com/huggingface/candle.git
synced 2025-06-15 10:26:33 +00:00
Use the same padding in metavoice as in the python version. (#1794)
This commit is contained in:
@ -129,7 +129,7 @@ fn main() -> Result<()> {
|
||||
VarBuilder::from_mmaped_safetensors(&[second_stage_weights], DType::F32, &device)?
|
||||
};
|
||||
let second_stage_config = gpt::Config::cfg1b_v0_1();
|
||||
let second_stage_model = gpt::Model::new(second_stage_config, second_stage_vb)?;
|
||||
let second_stage_model = gpt::Model::new(second_stage_config.clone(), second_stage_vb)?;
|
||||
|
||||
let encodec_device = if device.is_metal() {
|
||||
&candle::Device::Cpu
|
||||
@ -182,13 +182,16 @@ fn main() -> Result<()> {
|
||||
let mut rng = rand::rngs::StdRng::seed_from_u64(args.seed + 1337);
|
||||
// TODO: Use the config rather than hardcoding the offset here.
|
||||
let encoded_text: Vec<_> = prompt_tokens.iter().map(|v| v - 1024).collect();
|
||||
let hierarchies_in1 = [encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
|
||||
let hierarchies_in2 = [
|
||||
let mut hierarchies_in1 =
|
||||
[encoded_text.as_slice(), ids1.as_slice(), &[ENCODEC_NTOKENS]].concat();
|
||||
let mut hierarchies_in2 = [
|
||||
vec![ENCODEC_NTOKENS; encoded_text.len()].as_slice(),
|
||||
ids2.as_slice(),
|
||||
&[ENCODEC_NTOKENS],
|
||||
]
|
||||
.concat();
|
||||
hierarchies_in1.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||
hierarchies_in2.resize(second_stage_config.block_size, ENCODEC_NTOKENS);
|
||||
let in_x1 = Tensor::new(hierarchies_in1, &device)?;
|
||||
let in_x2 = Tensor::new(hierarchies_in2, &device)?;
|
||||
let in_x = Tensor::stack(&[in_x1, in_x2], 0)?.unsqueeze(0)?;
|
||||
|
Reference in New Issue
Block a user