mirror of https://github.com/huggingface/candle.git
Quantized version for phi-v2. (#1430)
* Quantized version for phi-v2.
* More quantized support.
@@ -1,14 +1,36 @@
-# candle-phi: 1.3b LLM with state of the art performance for <10b models.
+# candle-phi: 1.3b and 2.7b LLM with state of the art performance for <10b models.
 
-[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) is a language model using
-only 1.3 billion parameters but with state of the art performance compared to
+[Phi-1.5](https://huggingface.co/microsoft/phi-1_5) and
+[Phi-2](https://huggingface.co/microsoft/phi-2) are language models using
+only 1.3 and 2.7 billion parameters but with state of the art performance compared to
 models with up to 10 billion parameters.
 
+The candle implementation provides both the standard version as well as a
+quantized variant.
 
-## Running some example
+## Running some examples
 
+For the v2 version.
+```bash
+$ cargo run --example phi --release --features cuda -- --prompt "def print_prime(n): " --model 2
+def print_prime(n):
+    if n <= 1:
+        print("Not a prime number")
+    else:
+        for i in range(2, int(n**0.5)+1):
+            if (n % i) == 0:
+                print("Not a prime number")
+                break
+        else:
+            print("Prime number")
+
+
+# Driver code
+n = 17
+print_prime(n)
+```
+
+For the v1.5 version.
 ```bash
 $ cargo run --example phi --release -- --prompt "def print_prime(n): "
 
 
@@ -268,7 +268,7 @@ fn main() -> Result<()> {
                 match args.model {
                     WhichModel::V1 => vec![repo.get("model-v1-q4k.gguf")?],
                     WhichModel::V1_5 => vec![repo.get("model-q4k.gguf")?],
-                    WhichModel::V2 => anyhow::bail!("phi-2 is not supported in quantized mode"),
+                    WhichModel::V2 => vec![repo.get("model-v2-q4k.gguf")?],
                     WhichModel::PuffinPhiV2 => vec![repo.get("model-puffin-phi-v2-q4k.gguf")?],
                     WhichModel::PhiHermes => vec![repo.get("model-phi-hermes-1_3B-q4k.gguf")?],
                 }
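For context on the hunk above: `repo` comes from the hf-hub API, so each `WhichModel` arm maps a model variant to a q4k gguf file downloaded into the local hub cache. A minimal, self-contained sketch of the same fetch-and-load path; the hub repo id is an assumption (the diff only shows filenames), and `Config::v2()` is assumed to exist from the earlier phi-2 support:

```rust
use anyhow::Result;
use candle_transformers::models::mixformer::Config;
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
use candle_transformers::quantized_var_builder::VarBuilder;
use hf_hub::api::sync::Api;

fn load_quantized_phi_v2() -> Result<QMixFormer> {
    // Fetch the q4k gguf through the hf-hub cache. The repo id below is an
    // illustrative assumption; the diff only shows the file name.
    let api = Api::new()?;
    let repo = api.model("lmz/candle-quantized-phi".to_string());
    let weights = repo.get("model-v2-q4k.gguf")?;

    // Load it the same way the example does: a gguf-backed VarBuilder,
    // then the v2 constructor added by this commit.
    let vb = VarBuilder::from_gguf(&weights)?;
    let model = QMixFormer::new_v2(&Config::v2(), vb)?;
    Ok(model)
}
```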
@@ -298,7 +298,10 @@ fn main() -> Result<()> {
     };
     let (model, device) = if args.quantized {
         let vb = candle_transformers::quantized_var_builder::VarBuilder::from_gguf(&filenames[0])?;
-        let model = QMixFormer::new(&config, vb)?;
+        let model = match args.model {
+            WhichModel::V2 => QMixFormer::new_v2(&config, vb)?,
+            _ => QMixFormer::new(&config, vb)?,
+        };
         (Model::Quantized(model), Device::Cpu)
     } else {
         let device = candle_examples::device(args.cpu)?;
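After this hunk, quantized phi-2 shares the whole generation path with the other variants; the only phi-2-specific piece is the constructor picked in the `match`. A hedged sketch of a minimal greedy decoding loop over the resulting model — the actual example tokenizes the prompt and samples through a logits processor, both simplified away here:

```rust
use anyhow::Result;
use candle::{Device, Tensor};
use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;

/// Greedy decoding sketch. Assumes `model` was built as in the diff above and
/// `prompt_ids` already holds tokenizer output; quantized inference stays on
/// the CPU, matching the `Device::Cpu` choice in the hunk.
fn generate(model: &mut QMixFormer, prompt_ids: &[u32], steps: usize) -> Result<Vec<u32>> {
    let device = Device::Cpu;
    let mut tokens = prompt_ids.to_vec();
    for i in 0..steps {
        // Full prompt on the first step, then one token at a time: the model
        // keeps a kv-cache internally, so past positions are not re-fed.
        let ctx: &[u32] = if i == 0 { &tokens } else { &tokens[tokens.len() - 1..] };
        let input = Tensor::new(ctx, &device)?.unsqueeze(0)?;
        // forward returns the logits for the last position, shape (1, vocab).
        let logits = model.forward(&input)?.squeeze(0)?.to_vec1::<f32>()?;
        // Greedy argmax instead of the example's temperature sampling.
        let next = logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .unwrap_or(0);
        tokens.push(next);
    }
    Ok(tokens)
}
```

Note the design choice visible in the hunk: the quantized branch pins `Device::Cpu` instead of calling `candle_examples::device(args.cpu)`, presumably because the quantized kernels ran on the CPU only at the time of this commit.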
@@ -287,6 +287,24 @@ pub struct MixFormerSequentialForCausalLM {
 }
 
 impl MixFormerSequentialForCausalLM {
+    pub fn new_v2(cfg: &Config, vb: VarBuilder) -> Result<Self> {
+        let vb_head = vb.pp("lm_head");
+        let vb = vb.pp("transformer");
+        let embedding = Embedding::new(cfg, vb.pp("embd"))?;
+        let mut blocks = Vec::new();
+        for i in 0..cfg.n_layer {
+            let block = ParallelBlock::new(cfg, vb.pp("h").pp(i))?;
+            blocks.push(block)
+        }
+        let head = CausalLMHead::new(cfg, vb_head)?;
+        Ok(Self {
+            embedding,
+            blocks,
+            head,
+            span: tracing::span!(tracing::Level::TRACE, "mixformer"),
+        })
+    }
+
     pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
         let vb = vb.pp("layers");
         let embedding = Embedding::new(cfg, vb.pp(0))?;
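A note on `new_v2` versus `new` in this hunk: the phi-2 checkpoint nests its weights under a top-level `lm_head` plus a `transformer` tree with per-layer `h.{i}` subtrees, while `new` keeps the flat `layers.{i}` indexing of the v1/v1.5 checkpoints. A small sketch of how the `pp` ("push prefix") calls compose into full tensor names; the leaf names in the comments are assumptions drawn from the surrounding model code, not from this diff:

```rust
use candle_transformers::quantized_var_builder::VarBuilder;

// How the pp calls in new_v2 compose into the tensor names looked up
// in the gguf file.
fn v2_weight_paths(vb: &VarBuilder) {
    let _head = vb.pp("lm_head");    // lm_head.*  (e.g. lm_head.linear.weight — assumed leaf name)
    let tr = vb.pp("transformer");   // transformer.*
    let _embd = tr.pp("embd");       // transformer.embd.*  (e.g. transformer.embd.wte.weight — assumed)
    let _layer0 = tr.pp("h").pp(0);  // transformer.h.0.*  — one subtree per layer, 0..cfg.n_layer
}
```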