mirror of
https://github.com/huggingface/candle.git
synced 2025-06-16 10:38:54 +00:00
Self-contained safetensor wrappers (#946)
* Self-contained safetensor wrappers. * Use the new safetensor container in varbuilders.
This commit is contained in:
@ -122,30 +122,16 @@ impl T5ModelBuilder {
|
||||
}
|
||||
|
||||
pub fn build_encoder(&self) -> Result<t5::T5EncoderModel> {
|
||||
let weights = self
|
||||
.weights_filename
|
||||
.iter()
|
||||
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|w| w.deserialize())
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
|
||||
};
|
||||
Ok(t5::T5EncoderModel::load(vb, &self.config)?)
|
||||
}
|
||||
|
||||
pub fn build_conditional_generation(&self) -> Result<t5::T5ForConditionalGeneration> {
|
||||
let weights = self
|
||||
.weights_filename
|
||||
.iter()
|
||||
.map(|f| unsafe { candle::safetensors::MmapedFile::new(f) })
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let weights = weights
|
||||
.iter()
|
||||
.map(|w| w.deserialize())
|
||||
.collect::<candle::Result<Vec<_>>>()?;
|
||||
let vb = VarBuilder::from_safetensors(weights, DTYPE, &self.device);
|
||||
let vb = unsafe {
|
||||
VarBuilder::from_mmaped_safetensors(&self.weights_filename, DTYPE, &self.device)?
|
||||
};
|
||||
Ok(t5::T5ForConditionalGeneration::load(vb, &self.config)?)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user