mirror of https://github.com/huggingface/candle.git
MPT alibi fixes. (#1120)
* MPT alibi fixes.
* Some more fixes.
* Finally get the model to return some sensible outputs.
* Add a readme.
candle-examples/examples/replit-code/README.md (new file, 45 lines)

# candle-replit-code: code completion specialized model.

[replit-code-v1_5-3b](https://huggingface.co/replit/replit-code-v1_5-3b) is a
language model specialized for code completion. This model uses 3.3B parameters
in `bfloat16` (so the GPU version will only work on recent nvidia cards).

## Running an example

```bash
cargo run --example replit-code --release -- --prompt 'def fibonacci(n): '
```

This produces the following output, which doesn't actually generate the
Fibonacci series properly.

```
def fibonacci(n): # write Fibonacci series up to n
    """Print a Fibonacci series up to n."""

    assert type(n) == int, "n must be an integer"

    if (type(fib_list)==None or len==0 ):
        fib_list = [1]

    for i in range((len-2)): # start at 2nd element of list and go until end.
        n += 1

        print("Fibonacci number",n,"is:",i)

def main():
    """Call the functions."""

    userInput=input('Enter a positive integer: ')

    fibonacci(userInput)






if __name__ == '__main__': # only run if this file is called directly.
    print("This program prints out Fibonacci numbers.")
    main()
```
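Under the hood the example mostly has to get the checkpoint onto a device in `bf16` before it can complete anything. Here is a minimal sketch of that loading step with candle; the file name is a placeholder, and the real example also sets up the tokenizer and model config:

```rust
use candle_core::{DType, Device, Result};
use candle_nn::VarBuilder;

fn main() -> Result<()> {
    // Use the GPU when available; bf16 matmuls need a recent Nvidia card.
    let device = Device::cuda_if_available(0)?;
    // "model.safetensors" is a placeholder for the downloaded checkpoint.
    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&["model.safetensors"], DType::BF16, &device)?
    };
    // A real example would now build the MPT model from `vb`.
    let _ = vb;
    Ok(())
}
```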
```diff
@@ -139,7 +139,7 @@ struct Args {
     seed: u64,
 
     /// The length of the sample to generate (in tokens).
-    #[arg(long, short = 'n', default_value_t = 100)]
+    #[arg(long, short = 'n', default_value_t = 1000)]
     sample_len: usize,
 
     #[arg(long)]
```
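For reference, a trimmed-down sketch of how a `clap` flag like `sample_len` behaves with `default_value_t`; only the field itself is taken from the diff, the rest is illustrative:

```rust
use clap::Parser;

/// Sketch of the relevant slice of the example's CLI (trimmed down).
#[derive(Parser, Debug)]
struct Args {
    /// The length of the sample to generate (in tokens).
    #[arg(long, short = 'n', default_value_t = 1000)]
    sample_len: usize,
}

fn main() {
    // `-n 256` or `--sample-len 256` overrides the new default of 1000.
    let args = Args::parse();
    println!("generating {} tokens", args.sample_len);
}
```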
```diff
@@ -103,23 +103,25 @@ impl GroupedQueryAttention {
                 (k, v)
             }
         };
-        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?;
-        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?;
+        self.kv_cache = Some((key.clone(), value.clone()));
+        let query = query.contiguous()?;
+        let key = repeat_kv(key, self.n_heads / self.kv_n_heads)?.contiguous()?;
+        let value = repeat_kv(value, self.n_heads / self.kv_n_heads)?.contiguous()?;
         let attn_weights = (query.matmul(&key)? * self.softmax_scale)?;
         let attn_bias = {
             let s_q = query.dim(D::Minus2)?;
             let s_k = key.dim(D::Minus1)?;
             let (_, _, a_q, a_k) = self.attn_bias.dims4()?;
-            self.attn_bias
-                .narrow(2, a_q - s_q, s_q)?
-                .narrow(3, a_k - s_k, s_k)?
+            let start_q = a_q.saturating_sub(s_q);
+            let start_k = a_k.saturating_sub(s_k);
+            self.attn_bias.i((.., .., start_q.., start_k..))?
         };
-        let attn_weights = (attn_weights + attn_bias)?;
+        let attn_weights = attn_weights.broadcast_add(&attn_bias)?;
         let attn_weights = match mask {
             None => attn_weights,
             Some(mask) => masked_fill(
                 &attn_weights,
-                &mask.broadcast_left(b_size * self.n_heads)?,
+                &mask.broadcast_as(attn_weights.shape())?,
                 f32::NEG_INFINITY,
             )?,
         };
```
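The switch to `saturating_sub` plus `i((.., .., start_q.., start_k..))` is the heart of the alibi fix: the precomputed bias of shape `(1, n_heads, a_q, a_k)` has to be cut down to the trailing `(s_q, s_k)` window, and once a kv cache is in play the live sizes can disagree with the naive `a_q - s_q` arithmetic, which would underflow and panic on `usize`. A minimal sketch of just that offset arithmetic (the function name and sizes are made up):

```rust
/// Sketch: start offsets for the trailing window of a precomputed ALiBi
/// bias. `a_q` x `a_k` is the precomputed size, `s_q` x `s_k` the sizes
/// needed for the current step; names follow the diff above.
fn bias_window(a_q: usize, a_k: usize, s_q: usize, s_k: usize) -> (usize, usize) {
    // saturating_sub clamps to 0 instead of panicking on underflow when
    // the live sizes ever exceed the precomputed ones.
    (a_q.saturating_sub(s_q), a_k.saturating_sub(s_k))
}

fn main() {
    // Prompt processing: query and key lengths are both 7.
    assert_eq!(bias_window(4096, 4096, 7, 7), (4089, 4089));
    // Single-token decoding step with 8 cached key/value positions.
    assert_eq!(bias_window(4096, 4096, 1, 8), (4095, 4088));
}
```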
```diff
@@ -128,7 +130,8 @@ impl GroupedQueryAttention {
             .matmul(&value)?
             .transpose(1, 2)?
             .flatten_from(D::Minus2)?;
-        attn_output.apply(&self.out_proj)
+        let out = attn_output.apply(&self.out_proj)?;
+        Ok(out)
     }
 }
 
```
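The `repeat_kv` calls in the hunk above expand the shared key/value heads so that every query head has a matching kv head, and note that the fixed code stores the kv cache before that expansion, keeping the cache in the compact `kv_n_heads` layout. A sketch of what such a helper typically looks like in candle (an illustration, not necessarily the exact helper used here):

```rust
use candle_core::{DType, Device, Result, Tensor};

/// Give every query head a matching kv head by repeating each of the
/// `n_kv` key/value heads `n_rep` times along the head dimension.
fn repeat_kv(xs: Tensor, n_rep: usize) -> Result<Tensor> {
    if n_rep == 1 {
        return Ok(xs);
    }
    let (b_sz, n_kv, seq_len, head_dim) = xs.dims4()?;
    xs.unsqueeze(2)?
        .expand((b_sz, n_kv, n_rep, seq_len, head_dim))?
        .contiguous()? // materialize the broadcast before reshaping
        .reshape((b_sz, n_kv * n_rep, seq_len, head_dim))
}

fn main() -> Result<()> {
    // 1 batch, 2 kv heads, 5 positions, head_dim 4; 8 query heads in total.
    let kv = Tensor::zeros((1, 2, 5, 4), DType::F32, &Device::Cpu)?;
    assert_eq!(repeat_kv(kv, 8 / 2)?.dims(), &[1, 8, 5, 4]);
    Ok(())
}
```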
```diff
@@ -199,7 +202,7 @@ impl MPTBlock {
         let xs = self.attn.forward(&xs, mask)?;
         let xs = (xs + residual)?;
         let residual = &xs;
-        let xs = xs.apply(&self.norm2)?.apply(&self.ffn);
+        let xs = xs.apply(&self.norm2)?.apply(&self.ffn)?;
         xs + residual
     }
 }
```
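The added `?` matters because `.apply(&self.ffn)` returns a `Result<Tensor>`; without unwrapping it, the residual add below cannot operate on a plain `Tensor`. The surrounding control flow is the standard pre-norm residual pattern, sketched below with plain vectors and closures standing in for the tensors and layers (everything in this sketch is illustrative):

```rust
/// Elementwise add, standing in for the tensor residual additions.
fn add(a: &[f32], b: &[f32]) -> Vec<f32> {
    a.iter().zip(b).map(|(x, y)| x + y).collect()
}

/// Toy pre-norm residual block: normalize, run the sublayer, add back.
fn block_forward(
    xs: &[f32],
    norm1: impl Fn(&[f32]) -> Vec<f32>,
    attn: impl Fn(&[f32]) -> Vec<f32>,
    norm2: impl Fn(&[f32]) -> Vec<f32>,
    ffn: impl Fn(&[f32]) -> Vec<f32>,
) -> Vec<f32> {
    let residual = xs;          // save the input
    let xs = attn(&norm1(xs));  // attention sublayer
    let xs = add(&xs, residual); // first residual add
    let residual = &xs;         // save again
    let xs = ffn(&norm2(&xs));  // feed-forward sublayer
    add(&xs, residual)          // second residual add
}

fn main() {
    let id = |v: &[f32]| v.to_vec();
    // With identity layers the block just multiplies the input by 4.
    assert_eq!(block_forward(&[1.0, 2.0], &id, &id, &id, &id), vec![4.0, 8.0]);
}
```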
```diff
@@ -275,12 +278,15 @@ impl Model {
             Some(get_mask(seq_len, xs.device())?)
         };
         for block in self.blocks.iter_mut() {
-            xs = block.forward(&xs, mask.as_ref())?
+            xs = block.forward(&xs, mask.as_ref())?;
         }
-        xs.narrow(1, seq_len - 1, 1)?
+        let xs = xs.apply(&self.norm_f)?;
+        let logits = xs
+            .narrow(1, seq_len - 1, 1)?
             .squeeze(1)?
             .matmul(&self.wte.embeddings().t()?)?
-            .squeeze(1)
+            .squeeze(1)?;
+        Ok(logits)
     }
 }
 
```
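Two things change in the model tail: the final layer norm `norm_f` is now applied before computing logits (previously it was never applied), and the logits are produced by multiplying the last hidden state with the transposed token-embedding matrix, i.e. weight tying instead of a separate `lm_head`. A self-contained toy sketch of that tying step (sizes are made up):

```rust
use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Toy sizes: vocabulary of 16 tokens, hidden size 8.
    let wte = Tensor::randn(0f32, 1f32, (16, 8), &dev)?; // token embedding table
    // Final hidden state of the last position: (batch=1, seq=1, hidden=8).
    let last_hidden = Tensor::randn(0f32, 1f32, (1, 1, 8), &dev)?;
    // Weight tying: project back to vocabulary space with the transposed
    // embedding matrix instead of a separate output projection layer.
    let logits = last_hidden.squeeze(1)?.matmul(&wte.t()?)?;
    assert_eq!(logits.dims(), &[1, 16]);
    Ok(())
}
```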