Adapt more examples to the updated safetensor api. (#947)

* Simplify the safetensor usage.

* Convert more examples.

* Move more examples.

* Adapt stable-diffusion.
This commit is contained in:
Laurent Mazare
2023-09-23 21:26:03 +01:00
committed by GitHub
parent 890d069092
commit bb3471ea31
14 changed files with 31 additions and 84 deletions

View File

@ -86,9 +86,8 @@ impl Args {
let config: Config = serde_json::from_str(&config)?; let config: Config = serde_json::from_str(&config)?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let vb =
let weights = weights.deserialize()?; unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], DTYPE, &device)? };
let vb = VarBuilder::from_safetensors(vec![weights], DTYPE, &device);
let model = BertModel::load(vb, &config)?; let model = BertModel::load(vb, &config)?;
Ok((model, tokenizer)) Ok((model, tokenizer))
} }

View File

@ -138,18 +138,9 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::starcoder_1b(); let config = Config::starcoder_1b();
let model = GPTBigCode::load(vb, config)?; let model = GPTBigCode::load(vb, config)?;
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());

View File

@ -42,9 +42,7 @@ pub fn main() -> anyhow::Result<()> {
} }
Some(model) => model.into(), Some(model) => model.into(),
}; };
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = dinov2::vit_small(vb)?; let model = dinov2::vit_small(vb)?;
println!("model built"); println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?; let logits = model.forward(&image.unsqueeze(0)?)?;

View File

@ -68,9 +68,7 @@ pub fn main() -> anyhow::Result<()> {
} }
Some(model) => model.into(), Some(model) => model.into(),
}; };
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model_file], DType::F32, &device)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let cfg = match args.which { let cfg = match args.which {
Which::B0 => MBConvConfig::b0(), Which::B0 => MBConvConfig::b0(),
Which::B1 => MBConvConfig::b1(), Which::B1 => MBConvConfig::b1(),

View File

@ -177,21 +177,12 @@ fn main() -> Result<()> {
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let dtype = if args.use_f32 { let dtype = if args.use_f32 {
DType::F32 DType::F32
} else { } else {
DType::BF16 DType::BF16
}; };
let vb = VarBuilder::from_safetensors(weights, dtype, &device); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let config = Config::falcon7b(); let config = Config::falcon7b();
config.validate()?; config.validate()?;
let model = Falcon::load(vb, config)?; let model = Falcon::load(vb, config)?;

View File

@ -172,17 +172,9 @@ fn main() -> Result<()> {
} }
println!("building the model"); println!("building the model");
let handles = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f.as_path())? }))
.collect::<Result<Vec<_>>>()?;
let tensors: Vec<_> = handles
.iter()
.map(|h| Ok(h.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?; let cache = model::Cache::new(!args.no_kv_cache, dtype, &config, &device)?;
let vb = VarBuilder::from_safetensors(tensors, dtype, &device); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
(Llama::load(vb, &cache, &config)?, tokenizer_filename, cache) (Llama::load(vb, &cache, &config)?, tokenizer_filename, cache)
} }
}; };

View File

@ -73,9 +73,7 @@ fn main() -> Result<()> {
)) ))
.get("model.safetensors")?, .get("model.safetensors")?,
}; };
let model = unsafe { candle::safetensors::MmapedFile::new(model)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DTYPE, &device)? };
let model = model.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![model], DTYPE, &device);
let config = GenConfig::small(); let config = GenConfig::small();
let mut model = MusicgenForConditionalGeneration::load(vb, config)?; let mut model = MusicgenForConditionalGeneration::load(vb, config)?;

View File

@ -149,18 +149,9 @@ fn main() -> Result<()> {
println!("retrieved the files in {:?}", start.elapsed()); println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let weights = filenames
.iter()
.map(|f| Ok(unsafe { candle::safetensors::MmapedFile::new(f)? }))
.collect::<Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|f| Ok(f.deserialize()?))
.collect::<Result<Vec<_>>>()?;
let start = std::time::Instant::now(); let start = std::time::Instant::now();
let device = candle_examples::device(args.cpu)?; let device = candle_examples::device(args.cpu)?;
let vb = VarBuilder::from_safetensors(weights, DType::F32, &device); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, DType::F32, &device)? };
let config = Config::v1_5(); let config = Config::v1_5();
let model = Model::new(&config, vb)?; let model = Model::new(&config, vb)?;
println!("loaded the model in {:?}", start.elapsed()); println!("loaded the model in {:?}", start.elapsed());

View File

@ -82,9 +82,7 @@ pub fn main() -> anyhow::Result<()> {
api.get(filename)? api.get(filename)?
} }
}; };
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny { let sam = if args.use_tiny {
sam::Sam::new_tiny(vb)? // tiny vit_t sam::Sam::new_tiny(vb)? // tiny vit_t
} else { } else {

View File

@ -481,9 +481,8 @@ fn main() -> Result<()> {
let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?; let mel = Tensor::from_vec(mel, (1, m::N_MELS, mel_len / m::N_MELS), &device)?;
println!("loaded mel: {:?}", mel.dims()); println!("loaded mel: {:?}", mel.dims());
let weights = unsafe { candle::safetensors::MmapedFile::new(weights_filename)? }; let vb =
let weights = weights.deserialize()?; unsafe { VarBuilder::from_mmaped_safetensors(&[weights_filename], m::DTYPE, &device)? };
let vb = VarBuilder::from_safetensors(vec![weights], m::DTYPE, &device);
let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?; let config: Config = serde_json::from_str(&std::fs::read_to_string(config_filename)?)?;
let mut model = Whisper::load(&vb, config)?; let mut model = Whisper::load(&vb, config)?;

View File

@ -287,10 +287,10 @@ fn run(args: Args) -> Result<()> {
)?; )?;
let prior = { let prior = {
let prior_weights = ModelFile::Prior.get(prior_weights)?; let file = ModelFile::Prior.get(prior_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(prior_weights)? }; let vb = unsafe {
let weights = weights.deserialize()?; candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); };
wuerstchen::prior::WPrior::new( wuerstchen::prior::WPrior::new(
/* c_in */ PRIOR_CIN, /* c_in */ PRIOR_CIN,
/* c */ 1536, /* c */ 1536,
@ -324,10 +324,10 @@ fn run(args: Args) -> Result<()> {
println!("Building the vqgan."); println!("Building the vqgan.");
let vqgan = { let vqgan = {
let vqgan_weights = ModelFile::VqGan.get(vqgan_weights)?; let file = ModelFile::VqGan.get(vqgan_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(vqgan_weights)? }; let vb = unsafe {
let weights = weights.deserialize()?; candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); };
wuerstchen::paella_vq::PaellaVQ::new(vb)? wuerstchen::paella_vq::PaellaVQ::new(vb)?
}; };
@ -335,10 +335,10 @@ fn run(args: Args) -> Result<()> {
// https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json // https://huggingface.co/warp-ai/wuerstchen/blob/main/decoder/config.json
let decoder = { let decoder = {
let decoder_weights = ModelFile::Decoder.get(decoder_weights)?; let file = ModelFile::Decoder.get(decoder_weights)?;
let weights = unsafe { candle::safetensors::MmapedFile::new(decoder_weights)? }; let vb = unsafe {
let weights = weights.deserialize()?; candle_nn::VarBuilder::from_mmaped_safetensors(&[file], DType::F32, &device)?
let vb = candle_nn::VarBuilder::from_safetensors(vec![weights], DType::F32, &device); };
wuerstchen::diffnext::WDiffNeXt::new( wuerstchen::diffnext::WDiffNeXt::new(
/* c_in */ DECODER_CIN, /* c_in */ DECODER_CIN,
/* c_out */ DECODER_CIN, /* c_out */ DECODER_CIN,

View File

@ -146,9 +146,7 @@ pub fn main() -> Result<()> {
// Create the model and load the weights from the file. // Create the model and load the weights from the file.
let model = args.model()?; let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &Device::Cpu)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &Device::Cpu);
let config = args.config()?; let config = args.config()?;
let darknet = darknet::parse_config(config)?; let darknet = darknet::parse_config(config)?;
let model = darknet.build_model(vb)?; let model = darknet.build_model(vb)?;

View File

@ -381,9 +381,7 @@ pub fn run<T: Task>(args: Args) -> anyhow::Result<()> {
Which::X => Multiples::x(), Which::X => Multiples::x(),
}; };
let model = args.model()?; let model = args.model()?;
let weights = unsafe { candle::safetensors::MmapedFile::new(model)? }; let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[model], DType::F32, &device)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = T::load(vb, multiples)?; let model = T::load(vb, multiples)?;
println!("model loaded"); println!("model loaded");
for image_name in args.images.iter() { for image_name in args.images.iter() {

View File

@ -255,9 +255,8 @@ impl StableDiffusionConfig {
device: &Device, device: &Device,
dtype: DType, dtype: DType,
) -> Result<vae::AutoEncoderKL> { ) -> Result<vae::AutoEncoderKL> {
let weights = unsafe { candle::safetensors::MmapedFile::new(vae_weights)? }; let vs_ae =
let weights = weights.deserialize()?; unsafe { nn::VarBuilder::from_mmaped_safetensors(&[vae_weights], dtype, device)? };
let vs_ae = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
// https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json // https://huggingface.co/runwayml/stable-diffusion-v1-5/blob/main/vae/config.json
let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?; let autoencoder = vae::AutoEncoderKL::new(vs_ae, 3, 3, self.autoencoder.clone())?;
Ok(autoencoder) Ok(autoencoder)
@ -271,9 +270,8 @@ impl StableDiffusionConfig {
use_flash_attn: bool, use_flash_attn: bool,
dtype: DType, dtype: DType,
) -> Result<unet_2d::UNet2DConditionModel> { ) -> Result<unet_2d::UNet2DConditionModel> {
let weights = unsafe { candle::safetensors::MmapedFile::new(unet_weights)? }; let vs_unet =
let weights = weights.deserialize()?; unsafe { nn::VarBuilder::from_mmaped_safetensors(&[unet_weights], dtype, device)? };
let vs_unet = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let unet = unet_2d::UNet2DConditionModel::new( let unet = unet_2d::UNet2DConditionModel::new(
vs_unet, vs_unet,
in_channels, in_channels,
@ -295,9 +293,7 @@ pub fn build_clip_transformer<P: AsRef<std::path::Path>>(
device: &Device, device: &Device,
dtype: DType, dtype: DType,
) -> Result<clip::ClipTextTransformer> { ) -> Result<clip::ClipTextTransformer> {
let weights = unsafe { candle::safetensors::MmapedFile::new(clip_weights)? }; let vs = unsafe { nn::VarBuilder::from_mmaped_safetensors(&[clip_weights], dtype, device)? };
let weights = weights.deserialize()?;
let vs = nn::VarBuilder::from_safetensors(vec![weights], dtype, device);
let text_model = clip::ClipTextTransformer::new(vs, clip)?; let text_model = clip::ClipTextTransformer::new(vs, clip)?;
Ok(text_model) Ok(text_model)
} }