Self-contained safetensor wrappers (#946)

* Self-contained safetensor wrappers.

* Use the new safetensor container in varbuilders.
This commit is contained in:
Laurent Mazare
2023-09-23 20:39:52 +01:00
committed by GitHub
parent 5dbe46b389
commit 890d069092
3 changed files with 61 additions and 30 deletions

View File

@ -122,30 +122,16 @@ impl T5ModelBuilder {
}
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
let weights = self
.weights_filename
.iter()
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
}
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
let weights = self
.weights_filename
.iter()
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
.collect::<candle::Result<Vec<_>>>()?;
let weights = weights
.iter()
.map(|w| w.deserialize())
.collect::<candle::Result<Vec<_>>>()?;
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
};
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
}
}