ipv6: rewrite core ipv6 methods to operate on u128s (#187)

* errors: adds `From<AddrParseError>` for `IpNetworkError`

* ipv4: adds safety comment to `new_unchecked` and debug assertions to verify compliance

* ipv6: rewrite core ipv6 methods to operate on `u128`s
This commit is contained in:
Christopher Mahoney
2024-04-30 22:41:17 -04:00
committed by GitHub
parent 068959e2c4
commit 70d1f74e29
3 changed files with 117 additions and 81 deletions

View File

@ -1,4 +1,4 @@
use std::{error::Error, fmt};
use std::{error::Error, fmt, net::AddrParseError};
use crate::error::IpNetworkError::*;
@ -9,7 +9,7 @@ pub enum IpNetworkError {
InvalidAddr(String),
InvalidPrefix,
InvalidCidrFormat(String),
NetworkSizeError(NetworkSizeError)
NetworkSizeError(NetworkSizeError),
}
impl fmt::Display for IpNetworkError {
@ -18,7 +18,7 @@ impl fmt::Display for IpNetworkError {
InvalidAddr(ref s) => write!(f, "invalid address: {s}"),
InvalidPrefix => write!(f, "invalid prefix"),
InvalidCidrFormat(ref s) => write!(f, "invalid cidr format: {s}"),
NetworkSizeError(ref e) => write!(f, "network size error: {e}")
NetworkSizeError(ref e) => write!(f, "network size error: {e}"),
}
}
}
@ -29,16 +29,22 @@ impl Error for IpNetworkError {
InvalidAddr(_) => "address is invalid",
InvalidPrefix => "prefix is invalid",
InvalidCidrFormat(_) => "cidr is invalid",
NetworkSizeError(_) => "network size error"
NetworkSizeError(_) => "network size error",
}
}
}
impl From<AddrParseError> for IpNetworkError {
fn from(e: AddrParseError) -> Self {
InvalidAddr(e.to_string())
}
}
/// Cannot convert an IPv6 network size to a u32 as it is a 128-bit value.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum NetworkSizeError {
NetworkIsTooLarge
NetworkIsTooLarge,
}
impl fmt::Display for NetworkSizeError {

View File

@ -76,6 +76,23 @@ impl Ipv4Network {
/// Constructs without checking prefix a new `Ipv4Network` from any `Ipv4Addr,
/// and a prefix denoting the network size.
///
/// # Safety
///
/// The caller must ensure that the prefix is less than or equal to 32.
///
/// # Examples
///
/// ```
/// use std::net::Ipv4Addr;
/// use ipnetwork::Ipv4Network;
///
/// let prefix = 24;
/// let addr = Ipv4Addr::new(192, 168, 1, 1);
///
/// debug_assert!(prefix <= 32);
/// let network = unsafe { Ipv4Network::new_unchecked(addr, prefix) };
/// ```
pub const unsafe fn new_unchecked(addr: Ipv4Addr, prefix: u8) -> Ipv4Network {
Ipv4Network { addr, prefix }
}
@ -128,8 +145,9 @@ impl Ipv4Network {
/// Checks if the given `Ipv4Network` is partly contained in other.
pub fn overlaps(self, other: Ipv4Network) -> bool {
other.contains(self.ip())
|| (other.contains(self.broadcast())
|| (self.contains(other.ip()) || (self.contains(other.broadcast()))))
|| other.contains(self.broadcast())
|| self.contains(other.ip())
|| self.contains(other.broadcast())
}
/// Returns the mask for this `Ipv4Network`.
@ -147,9 +165,11 @@ impl Ipv4Network {
/// assert_eq!(net.mask(), Ipv4Addr::new(255, 255, 0, 0));
/// ```
pub fn mask(&self) -> Ipv4Addr {
let mask = !(0xffff_ffff_u64 >> u64::from(self.prefix)) as u32;
debug_assert!(self.prefix <= 32);
let mask = u32::MAX << (IPV4_BITS - self.prefix);
Ipv4Addr::from(mask)
}
}
/// Returns the address of the network denoted by this `Ipv4Network`.
/// This means the lowest possible IPv4 address inside of the network.
@ -201,6 +221,8 @@ impl Ipv4Network {
/// ```
#[inline]
pub fn contains(&self, ip: Ipv4Addr) -> bool {
debug_assert!(self.prefix <= IPV4_BITS);
let mask = !(0xffff_ffff_u64 >> self.prefix) as u32;
let net = u32::from(self.addr) & mask;
(u32::from(ip) & mask) == net
@ -221,8 +243,9 @@ impl Ipv4Network {
/// assert_eq!(tinynet.size(), 1);
/// ```
pub fn size(self) -> u32 {
1 << (u32::from(IPV4_BITS - self.prefix))
}
debug_assert!(self.prefix <= 32);
1 << (IPV4_BITS - self.prefix)
}
/// Returns the `n`:th address within this network.
/// The adresses are indexed from 0 and `n` must be smaller than the size of the network.
@ -274,8 +297,7 @@ impl FromStr for Ipv4Network {
type Err = IpNetworkError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (addr_str, prefix_str) = cidr_parts(s)?;
let addr = Ipv4Addr::from_str(addr_str)
.map_err(|_| IpNetworkError::InvalidAddr(addr_str.to_string()))?;
let addr = Ipv4Addr::from_str(addr_str)?;
let prefix = match prefix_str {
Some(v) => {
if let Ok(netmask) = Ipv4Addr::from_str(v) {
@ -458,6 +480,7 @@ mod test {
}
#[test]
#[allow(dropping_copy_types)]
fn copy_compatibility_v4() {
let net = Ipv4Network::new(Ipv4Addr::new(127, 0, 0, 1), 16).unwrap();
mem::drop(net);

View File

@ -1,6 +1,6 @@
use crate::error::IpNetworkError;
use crate::parse::{cidr_parts, parse_prefix};
use std::{cmp, convert::TryFrom, fmt, net::Ipv6Addr, str::FromStr};
use std::{convert::TryFrom, fmt, net::Ipv6Addr, str::FromStr};
const IPV6_BITS: u8 = 128;
const IPV6_SEGMENT_BITS: u8 = 16;
@ -87,6 +87,23 @@ impl Ipv6Network {
/// Constructs without checking prefix a new `Ipv6Network` from any `Ipv6Addr,
/// and a prefix denoting the network size.
///
/// # Safety
///
/// The caller must ensure that the prefix is less than or equal to 32.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let prefix = 64;
/// let addr = Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0);
///
/// debug_assert!(prefix <= 128);
/// let net = unsafe { Ipv6Network::new_unchecked(addr, prefix) };
/// ```
pub const unsafe fn new_unchecked(addr: Ipv6Addr, prefix: u8) -> Ipv6Network {
Ipv6Network { addr, prefix }
}
@ -106,6 +123,10 @@ impl Ipv6Network {
/// Returns an iterator over `Ipv6Network`. Each call to `next` will return the next
/// `Ipv6Addr` in the given network. `None` will be returned when there are no more
/// addresses.
///
/// # Warning
///
/// This can return up to 2^128 addresses, which will take a _long_ time to iterate over.
pub fn iter(&self) -> Ipv6NetworkIterator {
let dec = u128::from(self.addr);
let max = u128::max_value();
@ -123,42 +144,6 @@ impl Ipv6Network {
}
}
/// Returns the address of the network denoted by this `Ipv6Network`.
/// This means the lowest possible IPv6 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.network(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0));
/// ```
pub fn network(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let ip = u128::from(self.addr) & mask;
Ipv6Addr::from(ip)
}
/// Returns the broadcast address of this `Ipv6Network`.
/// This means the highest possible IPv4 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.broadcast(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0xffff, 0xffff));
/// ```
pub fn broadcast(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let broadcast = u128::from(self.addr) | !mask;
Ipv6Addr::from(broadcast)
}
pub fn ip(&self) -> Ipv6Addr {
self.addr
}
@ -180,8 +165,9 @@ impl Ipv6Network {
/// Checks if the given `Ipv6Network` is partly contained in other.
pub fn overlaps(self, other: Ipv6Network) -> bool {
other.contains(self.ip())
|| (other.contains(self.broadcast())
|| (self.contains(other.ip()) || (self.contains(other.broadcast()))))
|| other.contains(self.broadcast())
|| self.contains(other.ip())
|| self.contains(other.broadcast())
}
/// Returns the mask for this `Ipv6Network`.
@ -199,17 +185,47 @@ impl Ipv6Network {
/// assert_eq!(net.mask(), Ipv6Addr::new(0xffff, 0xffff, 0, 0, 0, 0, 0, 0));
/// ```
pub fn mask(&self) -> Ipv6Addr {
let mut segments = [0; 16];
for (i, chunk) in segments.chunks_mut(2).enumerate() {
let bits_remaining = self.prefix.saturating_sub(i as u8 * 16);
let set_bits = cmp::min(bits_remaining, 16);
let mask = !(0xffff >> set_bits) as u16;
chunk[0] = (mask >> 8) as u8;
chunk[1] = mask as u8;
}
Ipv6Addr::from(segments)
debug_assert!(self.prefix <= IPV6_BITS);
let mask = u128::MAX << (IPV6_BITS - self.prefix);
Ipv6Addr::from(mask)
}
/// Returns the address of the network denoted by this `Ipv6Network`.
/// This means the lowest possible IPv6 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.network(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0, 0));
/// ```
pub fn network(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let network = u128::from(self.addr) & mask;
Ipv6Addr::from(network)
}
/// Returns the broadcast address of this `Ipv6Network`.
/// This means the highest possible IPv4 address inside of the network.
///
/// # Examples
///
/// ```
/// use std::net::Ipv6Addr;
/// use ipnetwork::Ipv6Network;
///
/// let net: Ipv6Network = "2001:db8::/96".parse().unwrap();
/// assert_eq!(net.broadcast(), Ipv6Addr::new(0x2001, 0xdb8, 0, 0, 0, 0, 0xffff, 0xffff));
/// ```
pub fn broadcast(&self) -> Ipv6Addr {
let mask = u128::from(self.mask());
let broadcast = u128::from(self.addr) | !mask;
Ipv6Addr::from(broadcast)
}
/// Checks if a given `Ipv6Addr` is in this `Ipv6Network`
///
@ -225,14 +241,10 @@ impl Ipv6Network {
/// ```
#[inline]
pub fn contains(&self, ip: Ipv6Addr) -> bool {
let a = self.addr.segments();
let b = ip.segments();
let addrs = Iterator::zip(a.iter(), b.iter());
self.mask()
.segments()
.iter()
.zip(addrs)
.all(|(mask, (a, b))| a & mask == b & mask)
let ip = u128::from(ip);
let net = u128::from(self.network());
let mask = u128::from(self.mask());
(ip & mask) == net
}
/// Returns number of possible host addresses in this `Ipv6Network`.
@ -250,12 +262,12 @@ impl Ipv6Network {
/// assert_eq!(tinynet.size(), 1);
/// ```
pub fn size(&self) -> u128 {
let host_bits = u32::from(IPV6_BITS - self.prefix);
2u128.pow(host_bits)
debug_assert!(self.prefix <= IPV6_BITS);
1 << (IPV6_BITS - self.prefix)
}
/// Returns the `n`:th address within this network.
/// The adresses are indexed from 0 and `n` must be smaller than the size of the network.
/// The addresses are indexed from 0 and `n` must be smaller than the size of the network.
///
/// # Examples
///
@ -296,14 +308,12 @@ impl FromStr for Ipv6Network {
type Err = IpNetworkError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (addr_str, prefix_str) = cidr_parts(s)?;
let addr = Ipv6Addr::from_str(addr_str).map_err(|e| IpNetworkError::InvalidAddr(e.to_string()))?;
let addr = Ipv6Addr::from_str(addr_str)?;
let prefix = parse_prefix(prefix_str.unwrap_or(&IPV6_BITS.to_string()), IPV6_BITS)?;
Ipv6Network::new(addr, prefix)
}
}
impl TryFrom<&str> for Ipv6Network {
type Error = IpNetworkError;
@ -694,7 +704,7 @@ mod test {
let other: Ipv6Network = "2001:DB8:ACAD::1/64".parse().unwrap();
let other2: Ipv6Network = "2001:DB8:ACAD::20:2/64".parse().unwrap();
assert_eq!(other2.overlaps(other), true);
assert!(other2.overlaps(other));
}
#[test]
@ -726,9 +736,6 @@ mod test {
net.nth(65538).unwrap(),
Ipv6Addr::from_str("ff01::1:2").unwrap()
);
assert_eq!(
net.nth(net.size()),
None
);
assert_eq!(net.nth(net.size()), None);
}
}