Mirror of https://github.com/huggingface/candle.git

Adapt more examples to the updated safetensor api. (#947)
* Simplify the safetensor usage.
* Convert more examples.
* Move more examples.
* Adapt stable-diffusion.
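
In short, the three-step loading pattern used across the examples (mmap a safetensors file, deserialize it, then wrap the tensors in a VarBuilder) collapses into a single VarBuilder::from_mmaped_safetensors call. Below is a minimal before/after sketch, assuming candle and candle-nn as dependencies; the file name and CPU device are illustrative, each example in the diff supplies its own paths, dtype and device:

use candle::{DType, Device};
use candle_nn::VarBuilder;

fn main() -> candle::Result<()> {
    // Illustrative path and device; the real examples pick these from CLI args / the HF hub.
    let weights_filename = "model.safetensors";
    let device = Device::Cpu;

    // Old pattern (removed in this commit): mmap the file, deserialize it,
    // then wrap the tensors in a VarBuilder.
    //   let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
    //   let weights = weights.deserialize()?;
    //   let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);

    // New pattern: one unsafe call mmaps the listed files and builds the VarBuilder directly.
    let vb =
        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DType::F32, &device)? };
    let _ = vb;
    Ok(())
}

Because the new constructor takes a slice of paths, the examples that previously mapped and deserialized every file in an iterator chain (the GPTBigCode, Falcon, and Llama ones below) now pass the whole &filenames slice in a single call.
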
@@ -86,9 +86,8 @@ impl Args {
         let config: Config = serde_json::from_str(&config)?;
         let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
-        let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
-        let weights = weights.deserialize()?;
-        let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
+        let vb =
+            unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
         let model = BertModel::load(vb, &config)?;
         Ok((model, tokenizer))
     }

@@ -138,18 +138,9 @@ fn main() -> Result<()> {
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
-    let weights = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
-        .collect::<Result<Vec<_>>>()?;
-    let weights = weights
-        .iter()
-        .map(|f| Ok(f.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;
-
     let start = std::time::Instant::now();
     let device = candle_examples::device(args.cpu)?;
-    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
     let config = Config::starcoder_1b();
     let model = GPTBigCode::load(vb, config)?;
     println!("loaded the model in {:?}", start.elapsed());

@@ -42,9 +42,7 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let model = dinov2::vit_small(vb)?;
     println!("model built");
     let logits = model.forward(&image.unsqueeze(0)?)?;

@@ -68,9 +68,7 @@ pub fn main() -> anyhow::Result<()> {
         }
         Some(model) => model.into(),
     };
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
     let cfg = match args.which {
         Which::B0 => MBConvConfig::b0(),
         Which::B1 => MBConvConfig::b1(),

@@ -177,21 +177,12 @@ fn main() -> Result<()> {
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
     let start = std::time::Instant::now();
-    let weights = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
-        .collect::<Result<Vec<_>>>()?;
-    let weights = weights
-        .iter()
-        .map(|f| Ok(f.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;
-
     let dtype = if args.use_f32 {
         DType::F32
     } else {
         DType::BF16
     };
-    let vb = VarBuilder::from_safetensors(weights, dtype, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
     let config = Config::falcon7b();
     config.validate()?;
     let model = Falcon::load(vb, config)?;

@@ -172,17 +172,9 @@ fn main() -> Result<()> {
             }
 
             println!("building the model");
-            let handles = filenames
-                .iter()
-                .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
-                .collect::<Result<Vec<_>>>()?;
-            let tensors: Vec<_> = handles
-                .iter()
-                .map(|h| Ok(h.deserialize()?))
-                .collect::<Result<Vec<_>>>()?;
             let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
 
-            let vb = VarBuilder::from_safetensors(tensors, dtype, &device);
+            let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
             (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
         }
     };

@@ -73,9 +73,7 @@ fn main() -> Result<()> {
             ))
             .get("model.safetensors")?,
     };
-    let model = unsafe { candle::safetensors::MmapedFile::new(model)? };
-    let model = model.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
     let config = GenConfig::small();
     let mut model = MusicgenForConditionalGeneration::load(vb, config)?;
 

@@ -149,18 +149,9 @@ fn main() -> Result<()> {
     println!("retrieved the files in {:?}", start.elapsed());
     let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
 
-    let weights = filenames
-        .iter()
-        .map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
-        .collect::<Result<Vec<_>>>()?;
-    let weights = weights
-        .iter()
-        .map(|f| Ok(f.deserialize()?))
-        .collect::<Result<Vec<_>>>()?;
-
     let start = std::time::Instant::now();
     let device = candle_examples::device(args.cpu)?;
-    let vb = VarBuilder::from_safetensors(weights, DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
     let config = Config::v1_5();
     let model = Model::new(&config, vb)?;
     println!("loaded the model in {:?}", start.elapsed());

@@ -82,9 +82,7 @@ pub fn main() -> anyhow::Result<()> {
             api.get(filename)?
         }
     };
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
     let sam = if args.use_tiny {
         sam::Sam::new_tiny(vb)? // tiny vit_t
     } else {

@@ -481,9 +481,8 @@ fn main() -> Result<()> {
     let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
     println!("loaded mel: {:?}", mel.dims());
 
-    let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
+    let vb =
+        unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
     let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
     let mut model = Whisper::load(&vb, config)?;
 

@@ -287,10 +287,10 @@ fn run(args: Args) -> Result<()> {
     )?;
 
     let prior = {
-        let prior_weights = ModelFile::Prior.get(prior_weights)?;
-        let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? };
-        let weights = weights.deserialize()?;
-        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        let file = ModelFile::Prior.get(prior_weights)?;
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
+        };
         wuerstchen::prior::WPrior::new(
             /* c_in */ PRIOR_CIN,
             /* c */ 1536,

@@ -324,10 +324,10 @@ fn run(args: Args) -> Result<()> {
 
     println!("Building the vqgan.");
     let vqgan = {
-        let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?;
-        let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? };
-        let weights = weights.deserialize()?;
-        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        let file = ModelFile::VqGan.get(vqgan_weights)?;
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
+        };
         wuerstchen::paella_vq::PaellaVQ::new(vb)?
     };
 

@@ -335,10 +335,10 @@ fn run(args: Args) -> Result<()> {
 
     // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
     let decoder = {
-        let decoder_weights = ModelFile::Decoder.get(decoder_weights)?;
-        let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? };
-        let weights = weights.deserialize()?;
-        let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+        let file = ModelFile::Decoder.get(decoder_weights)?;
+        let vb = unsafe {
+            candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
+        };
         wuerstchen::diffnext::WDiffNeXt::new(
             /* c_in */ DECODER_CIN,
             /* c_out */ DECODER_CIN,

@@ -146,9 +146,7 @@ pub fn main() -> Result<()> {
 
     // Create the model and load the weights from the file.
     let model = args.model()?;
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &Device::Cpu)? };
     let config = args.config()?;
     let darknet = darknet::parse_config(config)?;
     let model = darknet.build_model(vb)?;

@@ -381,9 +381,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
         Which::X => Multiples::x(),
     };
     let model = args.model()?;
-    let weights = unsafe { candle::safetensors::MmapedFile::new(model)? };
-    let weights = weights.deserialize()?;
-    let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
+    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
     let model = T::load(vb, multiples)?;
     println!("model loaded");
     for image_name in args.images.iter() {

@@ -255,9 +255,8 @@ impl StableDiffusionConfig {
         device: &Device,
         dtype: DType,
     ) -> Result<vae::AutoEncoderKL> {
-        let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? };
-        let weights = weights.deserialize()?;
-        let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+        let vs_ae =
+            unsafe { nn::VarBuilder::from_mmaped_safetensors(&[vae_weights], dtype, device)? };
         // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
         let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
         Ok(autoencoder)

@@ -271,9 +270,8 @@ impl StableDiffusionConfig {
         use_flash_attn: bool,
         dtype: DType,
     ) -> Result<unet_2d::UNet2DConditionModel> {
-        let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? };
-        let weights = weights.deserialize()?;
-        let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+        let vs_unet =
+            unsafe { nn::VarBuilder::from_mmaped_safetensors(&[unet_weights], dtype, device)? };
         let unet = unet_2d::UNet2DConditionModel::new(
             vs_unet,
             in_channels,

@@ -295,9 +293,7 @@ pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
     device: &Device,
     dtype: DType,
 ) -> Result<clip::ClipTextTransformer> {
-    let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? };
-    let weights = weights.deserialize()?;
-    let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
+    let vs = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[clip_weights], dtype, device)? };
     let text_model = clip::ClipTextTransformer::new(vs, clip)?;
     Ok(text_model)
 }