diff --git a/src/ipv6.rs b/src/ipv6.rs index 916ec66..90503f1 100644 --- a/src/ipv6.rs +++ b/src/ipv6.rs @@ -1,3 +1,4 @@ +use std::cmp; use std::fmt; use std::net::Ipv6Addr; use std::str::FromStr; @@ -41,6 +42,50 @@ impl Ipv6Network { pub fn prefix(&self) -> u8 { self.prefix } + + /// Returns the mask for this `Ipv6Network`. + /// That means the `prefix` most significant bits will be 1 and the rest 0 + /// + /// # Examples + /// + /// ``` + /// use std::net::Ipv6Addr; + /// use ipnetwork::Ipv6Network; + /// + /// let net = Ipv6Network::from_cidr("ff01::0/32").unwrap(); + /// assert_eq!(net.mask(), Ipv6Addr::new(0xffff, 0xffff, 0, 0, 0, 0, 0, 0)); + /// ``` + pub fn mask(&self) -> Ipv6Addr { + // Ipv6Addr::from is only implemented for [u8; 16] + let mut segments = [0; 16]; + for (i, segment) in segments.iter_mut().enumerate() { + let bits_remaining = self.prefix.saturating_sub(i as u8 * 8); + let set_bits = cmp::min(bits_remaining, 8); + *segment = !(0xff as u16 >> set_bits) as u8; + } + Ipv6Addr::from(segments) + } + + /// Checks if a given `Ipv6Addr` is in this `Ipv6Network` + /// + /// # Examples + /// + /// ``` + /// use std::net::Ipv6Addr; + /// use ipnetwork::Ipv6Network; + /// + /// let net = Ipv6Network::from_cidr("ff01::0/32").unwrap(); + /// assert!(net.contains(Ipv6Addr::new(0xff01, 0, 0, 0, 0, 0, 0, 0x1))); + /// assert!(!net.contains(Ipv6Addr::new(0xffff, 0, 0, 0, 0, 0, 0, 0x1))); + /// ``` + pub fn contains(&self, ip: Ipv6Addr) -> bool { + let a = self.addr.segments(); + let b = ip.segments(); + let addrs = Iterator::zip(a.iter(), b.iter()); + self.mask().segments().iter().zip(addrs).all(|(mask, (a, b))| + a & mask == b & mask + ) + } } impl fmt::Display for Ipv6Network { @@ -92,4 +137,25 @@ mod test { let cidr = Ipv6Network::from_cidr("::1/129"); assert!(cidr.is_err()); } + + #[test] + fn mask_v6() { + let cidr = Ipv6Network::new(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 0), 40).unwrap(); + let mask = cidr.mask(); + assert_eq!(mask, Ipv6Addr::new(0xffff, 0xffff, 0xff00, 0, 0, 0, 0, 0)); + } + + #[test] + fn contains_v6() { + let cidr = Ipv6Network::new(Ipv6Addr::new(0xff01, 0, 0, 0x17, 0, 0, 0, 0x2), 65).unwrap(); + let ip = Ipv6Addr::new(0xff01, 0, 0, 0x17, 0x7fff, 0, 0, 0x2); + assert!(cidr.contains(ip)); + } + + #[test] + fn not_contains_v6() { + let cidr = Ipv6Network::new(Ipv6Addr::new(0xff01, 0, 0, 0x17, 0, 0, 0, 0x2), 65).unwrap(); + let ip = Ipv6Addr::new(0xff01, 0, 0, 0x17, 0xffff, 0, 0, 0x2); + assert!(!cidr.contains(ip)); + } }