Move the llama2-c model into candle-transformers. (#1205)

This commit is contained in:
Laurent Mazare
2023-10-28 17:51:19 +02:00
committed by GitHub
parent 612f5b8156
commit 95a857cf57
6 changed files with 12 additions and 9 deletions

View File

@ -6,10 +6,10 @@ extern crate accelerate_src;
#[cfg(feature = "mkl")] #[cfg(feature = "mkl")]
extern crate intel_mkl_src; extern crate intel_mkl_src;
mod model; use candle_transformers::models::llama2_c as model;
mod qmodel; use candle_transformers::models::llama2_c_weights as weights;
use candle_transformers::models::quantized_llama2_c as qmodel;
mod training; mod training;
mod weights;
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use anyhow::{Error as E, Result}; use anyhow::{Error as E, Result};

View File

@ -11,6 +11,7 @@ readme = "README.md"
[dependencies] [dependencies]
accelerate-src = { workspace = true, optional = true } accelerate-src = { workspace = true, optional = true }
byteorder = { workspace = true }
candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" } candle = { path = "../candle-core", version = "0.3.0", package = "candle-core" }
candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true } candle-flash-attn = { path = "../candle-flash-attn", version = "0.3.0", optional = true }
candle-nn = { path = "../candle-nn", version = "0.3.0" } candle-nn = { path = "../candle-nn", version = "0.3.0" }

View File

@ -1,9 +1,8 @@
use anyhow::Result;
use byteorder::{LittleEndian, ReadBytesExt}; use byteorder::{LittleEndian, ReadBytesExt};
use candle::{DType, Device, IndexOp, Shape, Tensor}; use candle::{DType, Device, IndexOp, Result, Shape, Tensor};
use candle_nn::VarBuilder; use candle_nn::VarBuilder;
use crate::model::Config; use super::llama2_c::Config;
pub struct TransformerWeights { pub struct TransformerWeights {
// token embedding table // token embedding table

View File

@ -8,6 +8,8 @@ pub mod efficientnet;
pub mod falcon; pub mod falcon;
pub mod jina_bert; pub mod jina_bert;
pub mod llama; pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod mistral; pub mod mistral;
pub mod mixformer; pub mod mixformer;
pub mod mpt; pub mod mpt;
@ -15,6 +17,7 @@ pub mod persimmon;
pub mod quantized_blip; pub mod quantized_blip;
pub mod quantized_blip_text; pub mod quantized_blip_text;
pub mod quantized_llama; pub mod quantized_llama;
pub mod quantized_llama2_c;
pub mod quantized_mistral; pub mod quantized_mistral;
pub mod quantized_mixformer; pub mod quantized_mixformer;
pub mod quantized_mpt; pub mod quantized_mpt;

View File

@ -1,7 +1,7 @@
use super::model::{Cache, Config}; use super::llama2_c::{Cache, Config};
use crate::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
pub use crate::quantized_var_builder::VarBuilder;
use candle::{DType, IndexOp, Module, Result, Tensor, D}; use candle::{DType, IndexOp, Module, Result, Tensor, D};
use candle_transformers::quantized_nn::{linear_no_bias as linear, Embedding, Linear, RmsNorm};
pub use candle_transformers::quantized_var_builder::VarBuilder;
fn silu(xs: &Tensor) -> Result<Tensor> { fn silu(xs: &Tensor) -> Result<Tensor> {
xs / (xs.neg()?.exp()? + 1.0)? xs / (xs.neg()?.exp()? + 1.0)?