refactor: simplify dns server trait

feat: update reverse dns
This commit is contained in:
2025-03-04 14:28:50 +00:00
parent 1391b8a594
commit fc735590d6
7 changed files with 271 additions and 106 deletions

View File

@ -191,6 +191,8 @@ pub struct ApiVmHostRegion {
pub struct VMPatchRequest { pub struct VMPatchRequest {
/// SSH key assigned to vm /// SSH key assigned to vm
pub ssh_key_id: Option<u64>, pub ssh_key_id: Option<u64>,
/// Reverse DNS PTR domain
pub reverse_dns: Option<String>,
} }
#[derive(Serialize, Deserialize, JsonSchema)] #[derive(Serialize, Deserialize, JsonSchema)]

View File

@ -205,6 +205,7 @@ async fn v1_get_vm(
async fn v1_patch_vm( async fn v1_patch_vm(
auth: Nip98Auth, auth: Nip98Auth,
db: &State<Arc<dyn LNVpsDb>>, db: &State<Arc<dyn LNVpsDb>>,
provisioner: &State<Arc<LNVpsProvisioner>>,
settings: &State<Settings>, settings: &State<Settings>,
id: u64, id: u64,
data: Json<VMPatchRequest>, data: Json<VMPatchRequest>,
@ -216,20 +217,31 @@ async fn v1_patch_vm(
return ApiData::err("VM doesnt belong to you"); return ApiData::err("VM doesnt belong to you");
} }
let mut vm_config = false;
if let Some(k) = data.ssh_key_id { if let Some(k) = data.ssh_key_id {
let ssh_key = db.get_user_ssh_key(k).await?; let ssh_key = db.get_user_ssh_key(k).await?;
if ssh_key.user_id != uid { if ssh_key.user_id != uid {
return ApiData::err("SSH key doesnt belong to you"); return ApiData::err("SSH key doesnt belong to you");
} }
vm.ssh_key_id = ssh_key.id; vm.ssh_key_id = ssh_key.id;
vm_config = true;
} }
db.update_vm(&vm).await?; if let Some(ptr) = &data.reverse_dns {
let mut ips = db.list_vm_ip_assignments(vm.id).await?;
for mut ip in ips.iter_mut() {
ip.dns_reverse = Some(ptr.to_string());
provisioner.update_reverse_ip_dns(&mut ip).await?;
}
}
let info = FullVmInfo::load(vm.id, (*db).clone()).await?; if vm_config {
let host = db.get_host(vm.host_id).await?; db.update_vm(&vm).await?;
let client = get_host_client(&host, &settings.provisioner)?; let info = FullVmInfo::load(vm.id, (*db).clone()).await?;
client.configure_vm(&info).await?; let host = db.get_host(vm.host_id).await?;
let client = get_host_client(&host, &settings.provisioner)?;
client.configure_vm(&info).await?;
}
ApiData::ok(()) ApiData::ok(())
} }

View File

@ -19,7 +19,6 @@ use std::net::{IpAddr, SocketAddr};
use std::path::PathBuf; use std::path::PathBuf;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use tokio::fs::create_dir_all;
use tokio::time::sleep; use tokio::time::sleep;
#[derive(Parser)] #[derive(Parser)]

View File

@ -1,8 +1,9 @@
use crate::dns::{BasicRecord, DnsServer, RecordType}; use crate::dns::{BasicRecord, DnsServer, RecordType};
use crate::json_api::JsonApi; use crate::json_api::JsonApi;
use anyhow::Context;
use lnvps_db::async_trait; use lnvps_db::async_trait;
use log::info;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::net::IpAddr;
pub struct Cloudflare { pub struct Cloudflare {
api: JsonApi, api: JsonApi,
@ -23,67 +24,101 @@ impl Cloudflare {
#[async_trait] #[async_trait]
impl DnsServer for Cloudflare { impl DnsServer for Cloudflare {
async fn add_ptr_record(&self, key: &str, value: &str) -> anyhow::Result<BasicRecord> { async fn add_record(&self, record: &BasicRecord) -> anyhow::Result<BasicRecord> {
let zone_id = match &record.kind {
RecordType::PTR => &self.reverse_zone_id,
_ => &self.forward_zone_id,
};
info!(
"Adding record: [{}] {} => {}",
record.kind, record.name, record.value
);
let id_response: CfResult<CfRecord> = self let id_response: CfResult<CfRecord> = self
.api .api
.post( .post(
&format!("/client/v4/zones/{}/dns_records", self.reverse_zone_id), &format!("/client/v4/zones/{zone_id}/dns_records"),
CfRecord { CfRecord {
content: value.to_string(), content: record.value.to_string(),
name: key.to_string(), name: record.name.to_string(),
r_type: "PTR".to_string(), r_type: Some(record.kind.to_string()),
id: None, id: None,
}, },
) )
.await?; .await?;
Ok(BasicRecord { Ok(BasicRecord {
name: id_response.result.name, name: id_response.result.name,
value: value.to_string(), value: id_response.result.content,
id: id_response.result.id, id: id_response.result.id,
kind: RecordType::PTR, kind: record.kind.clone(),
}) })
} }
async fn delete_ptr_record(&self, key: &str) -> anyhow::Result<()> { async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> {
todo!() let zone_id = match &record.kind {
RecordType::PTR => &self.reverse_zone_id,
_ => &self.forward_zone_id,
};
let record_id = record.id.as_ref().context("record id missing")?;
info!(
"Deleting record: [{}] {} => {}",
record.kind, record.name, record.value
);
self.api
.req(
reqwest::Method::DELETE,
&format!("/client/v4/zones/{}/dns_records/{}", zone_id, record_id),
CfRecord {
content: record.value.to_string(),
name: record.name.to_string(),
r_type: None,
id: None,
},
)
.await?;
Ok(())
} }
async fn add_a_record(&self, name: &str, ip: IpAddr) -> anyhow::Result<BasicRecord> { async fn update_record(&self, record: &BasicRecord) -> anyhow::Result<BasicRecord> {
let zone_id = match &record.kind {
RecordType::PTR => &self.reverse_zone_id,
_ => &self.forward_zone_id,
};
info!(
"Updating record: [{}] {} => {}",
record.kind, record.name, record.value
);
let record_id = record.id.as_ref().context("record id missing")?;
let id_response: CfResult<CfRecord> = self let id_response: CfResult<CfRecord> = self
.api .api
.post( .req(
&format!("/client/v4/zones/{}/dns_records", self.forward_zone_id), reqwest::Method::PATCH,
&format!("/client/v4/zones/{}/dns_records/{}", zone_id, record_id),
CfRecord { CfRecord {
content: ip.to_string(), content: record.value.to_string(),
name: name.to_string(), name: record.name.to_string(),
r_type: if ip.is_ipv4() { r_type: Some(record.kind.to_string()),
"A".to_string() id: Some(record_id.to_string()),
} else {
"AAAA".to_string()
},
id: None,
}, },
) )
.await?; .await?;
Ok(BasicRecord { Ok(BasicRecord {
name: id_response.result.name, name: id_response.result.name,
value: ip.to_string(), value: id_response.result.content,
id: id_response.result.id, id: id_response.result.id,
kind: RecordType::A, kind: record.kind.clone(),
}) })
} }
async fn delete_a_record(&self, name: &str) -> anyhow::Result<()> {
todo!()
}
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
struct CfRecord { struct CfRecord {
pub content: String, pub content: String,
pub name: String, pub name: String,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "type")] #[serde(rename = "type")]
pub r_type: String, pub r_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>, pub id: Option<String>,
} }

View File

@ -1,7 +1,9 @@
use anyhow::Result; use anyhow::{bail, Context, Result};
use lnvps_db::async_trait; use lnvps_db::{async_trait, VmIpAssignment};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fmt::{Display, Formatter};
use std::net::IpAddr; use std::net::IpAddr;
use std::str::FromStr;
#[cfg(feature = "cloudflare")] #[cfg(feature = "cloudflare")]
mod cloudflare; mod cloudflare;
@ -11,16 +13,13 @@ pub use cloudflare::*;
#[async_trait] #[async_trait]
pub trait DnsServer: Send + Sync { pub trait DnsServer: Send + Sync {
/// Add PTR record to the reverse zone /// Add PTR record to the reverse zone
async fn add_ptr_record(&self, key: &str, value: &str) -> Result<BasicRecord>; async fn add_record(&self, record: &BasicRecord) -> Result<BasicRecord>;
/// Delete PTR record from the reverse zone /// Delete PTR record from the reverse zone
async fn delete_ptr_record(&self, key: &str) -> Result<()>; async fn delete_record(&self, record: &BasicRecord) -> Result<()>;
/// Add A/AAAA record onto the forward zone /// Update a record
async fn add_a_record(&self, name: &str, ip: IpAddr) -> Result<BasicRecord>; async fn update_record(&self, record: &BasicRecord) -> Result<BasicRecord>;
/// Delete A/AAAA record from the forward zone
async fn delete_a_record(&self, name: &str) -> Result<()>;
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
@ -37,3 +36,117 @@ pub struct BasicRecord {
pub id: Option<String>, pub id: Option<String>,
pub kind: RecordType, pub kind: RecordType,
} }
impl Display for RecordType {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
match self {
RecordType::A => write!(f, "A"),
RecordType::AAAA => write!(f, "AAAA"),
RecordType::PTR => write!(f, "PTR"),
}
}
}
impl BasicRecord {
pub fn forward(ip: &VmIpAssignment) -> Result<Self> {
let addr = IpAddr::from_str(&ip.ip)?;
Ok(Self {
name: format!("vm-{}", &ip.vm_id),
value: addr.to_string(),
id: ip.dns_forward_ref.clone(),
kind: match addr {
IpAddr::V4(_) => RecordType::A,
IpAddr::V6(_) => RecordType::AAAA,
},
})
}
pub fn reverse_to_fwd(ip: &VmIpAssignment) -> Result<Self> {
let addr = IpAddr::from_str(&ip.ip)?;
// Use explicit reverse entry or use the forward entry
let fwd = ip
.dns_reverse
.as_ref()
.or(ip.dns_forward.as_ref())
.context("Reverse/Forward DNS name required for reverse entry")?
.to_string();
if !is_valid_fqdn(fwd.as_str()) {
bail!("Forward DNS name is not a valid FQDN");
}
Ok(Self {
name: match addr {
IpAddr::V4(i) => i.octets()[3].to_string(),
IpAddr::V6(_) => bail!("IPv6 PTR not supported"),
},
value: fwd,
id: ip.dns_reverse_ref.clone(),
kind: RecordType::PTR,
})
}
pub fn reverse(ip: &VmIpAssignment) -> Result<Self> {
let addr = IpAddr::from_str(&ip.ip)?;
let rev = ip
.dns_reverse
.as_ref()
.context("Reverse DNS name required for reverse entry")?
.to_string();
if !is_valid_fqdn(&rev) {
bail!("Reverse DNS name is not a valid FQDN");
}
Ok(Self {
name: match addr {
IpAddr::V4(i) => i.octets()[3].to_string(),
IpAddr::V6(_) => bail!("IPv6 PTR not supported"),
},
value: rev,
id: ip.dns_reverse_ref.clone(),
kind: RecordType::PTR,
})
}
}
/// Grok 3
pub fn is_valid_fqdn(s: &str) -> bool {
// Remove trailing dot if present (optional in practice)
let s = s.strip_suffix('.').unwrap_or(s);
// Check total length (max 255 chars, including dots)
if s.len() > 255 || s.is_empty() {
return false;
}
// Split into labels and validate each
let labels: Vec<&str> = s.split('.').collect();
// Must have at least two labels (e.g., "example.com")
if labels.len() < 2 {
return false;
}
for label in labels {
// Each label must be 1-63 chars
if label.len() > 63 || label.is_empty() {
return false;
}
// Must start with a letter or digit
if !label.chars().next().unwrap().is_alphanumeric() {
return false;
}
// Must end with a letter or digit
if !label.chars().last().unwrap().is_alphanumeric() {
return false;
}
// Only letters, digits, and hyphens allowed
if !label.chars().all(|c| c.is_alphanumeric() || c == '-') {
return false;
}
}
true
}

View File

@ -666,7 +666,7 @@ impl VmHostClient for MockVmHost {
} }
} }
async fn configure_vm(&self, vm: &Vm) -> anyhow::Result<()> { async fn configure_vm(&self, vm: &FullVmInfo) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
} }
@ -696,62 +696,42 @@ impl MockDnsServer {
} }
#[async_trait] #[async_trait]
impl DnsServer for MockDnsServer { impl DnsServer for MockDnsServer {
async fn add_ptr_record(&self, key: &str, value: &str) -> anyhow::Result<BasicRecord> { async fn add_record(&self, record: &BasicRecord) -> anyhow::Result<BasicRecord> {
let mut rev = self.reverse.lock().await; let mut table = match record.kind {
RecordType::PTR => self.reverse.lock().await,
_ => self.forward.lock().await,
};
if rev.values().any(|v| v.name == key) { if table.values().any(|v| v.name == record.name) {
bail!("Duplicate record with name {}", key); bail!("Duplicate record with name {}", record.name);
} }
let rnd_id: [u8; 12] = rand::random(); let rnd_id: [u8; 12] = rand::random();
let id = hex::encode(rnd_id); let id = hex::encode(rnd_id);
rev.insert( table.insert(
id.clone(), id.clone(),
MockDnsEntry { MockDnsEntry {
name: key.to_string(), name: record.name.to_string(),
value: value.to_string(), value: record.value.to_string(),
kind: "PTR".to_string(), kind: record.kind.to_string(),
}, },
); );
Ok(BasicRecord { Ok(BasicRecord {
name: format!("{}.X.Y.Z.in-addr.arpa", key), name: match record.kind {
value: value.to_string(), RecordType::PTR => format!("{}.X.Y.Z.addr.in-arpa", record.name),
_ => format!("{}.lnvps.mock", record.name),
},
value: record.value.clone(),
id: Some(id), id: Some(id),
kind: RecordType::PTR, kind: record.kind.clone(),
}) })
} }
async fn delete_ptr_record(&self, key: &str) -> anyhow::Result<()> { async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> {
todo!() todo!()
} }
async fn add_a_record(&self, name: &str, ip: IpAddr) -> anyhow::Result<BasicRecord> { async fn update_record(&self, name: &BasicRecord) -> anyhow::Result<BasicRecord> {
let mut rev = self.forward.lock().await;
if rev.values().any(|v| v.name == name) {
bail!("Duplicate record with name {}", name);
}
let fqdn = format!("{}.lnvps.mock", name);
let rnd_id: [u8; 12] = rand::random();
let id = hex::encode(rnd_id);
rev.insert(
id.clone(),
MockDnsEntry {
name: fqdn.clone(),
value: ip.to_string(),
kind: "A".to_string(),
},
);
Ok(BasicRecord {
name: fqdn,
value: ip.to_string(),
id: Some(id),
kind: RecordType::A,
})
}
async fn delete_a_record(&self, name: &str) -> anyhow::Result<()> {
todo!() todo!()
} }
} }

View File

@ -1,11 +1,11 @@
use crate::dns::DnsServer; use crate::dns::{BasicRecord, DnsServer};
use crate::exchange::{ExchangeRateService, Ticker}; use crate::exchange::{ExchangeRateService, Ticker};
use crate::host::{get_host_client, FullVmInfo}; use crate::host::{get_host_client, FullVmInfo};
use crate::lightning::{AddInvoiceRequest, LightningNode}; use crate::lightning::{AddInvoiceRequest, LightningNode};
use crate::provisioner::{NetworkProvisioner, ProvisionerMethod}; use crate::provisioner::{NetworkProvisioner, ProvisionerMethod};
use crate::router::Router; use crate::router::Router;
use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings}; use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings};
use anyhow::{bail, Result}; use anyhow::{bail, Context, Result};
use chrono::{Days, Months, Utc}; use chrono::{Days, Months, Utc};
use futures::future::join_all; use futures::future::join_all;
use lnvps_db::{IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment}; use lnvps_db::{IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment};
@ -55,7 +55,8 @@ impl LNVpsProvisioner {
} }
} }
async fn delete_ip_assignment(&self, vm: &Vm) -> Result<()> { pub async fn delete_ip_assignment(&self, vm: &Vm) -> Result<()> {
// Delete access policy
if let NetworkAccessPolicy::StaticArp { .. } = &self.network_policy.access { if let NetworkAccessPolicy::StaticArp { .. } = &self.network_policy.access {
if let Some(r) = self.router.as_ref() { if let Some(r) = self.router.as_ref() {
let ent = r.list_arp_entry().await?; let ent = r.list_arp_entry().await?;
@ -72,6 +73,45 @@ impl LNVpsProvisioner {
Ok(()) Ok(())
} }
pub async fn update_forward_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> {
if let Some(dns) = &self.dns {
let fwd = BasicRecord::forward(assignment)?;
let ret_fwd = if fwd.id.is_some() {
dns.update_record(&fwd).await?
} else {
dns.add_record(&fwd).await?
};
assignment.dns_forward = Some(ret_fwd.name);
assignment.dns_forward_ref = Some(ret_fwd.id.context("Record id is missing")?);
// save to db
if assignment.id != 0 {
self.db.update_vm_ip_assignment(assignment).await?;
}
}
Ok(())
}
pub async fn update_reverse_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> {
if let Some(dns) = &self.dns {
let ret_rev = if assignment.dns_reverse_ref.is_some() {
dns.update_record(&BasicRecord::reverse(assignment)?)
.await?
} else {
dns.add_record(&BasicRecord::reverse_to_fwd(assignment)?)
.await?
};
assignment.dns_reverse = Some(ret_rev.value);
assignment.dns_reverse_ref = Some(ret_rev.id.context("Record id is missing")?);
// save to db
if assignment.id != 0 {
self.db.update_vm_ip_assignment(assignment).await?;
}
}
Ok(())
}
async fn save_ip_assignment(&self, vm: &Vm, assignment: &mut VmIpAssignment) -> Result<()> { async fn save_ip_assignment(&self, vm: &Vm, assignment: &mut VmIpAssignment) -> Result<()> {
let ip = IpAddr::from_str(&assignment.ip)?; let ip = IpAddr::from_str(&assignment.ip)?;
@ -91,27 +131,11 @@ impl LNVpsProvisioner {
} }
// Add DNS records // Add DNS records
if let Some(dns) = &self.dns { self.update_forward_ip_dns(assignment).await?;
let sub_name = format!("vm-{}", vm.id); self.update_reverse_ip_dns(assignment).await?;
let fwd = dns.add_a_record(&sub_name, ip.clone()).await?;
assignment.dns_forward = Some(fwd.name.clone());
assignment.dns_forward_ref = fwd.id;
match ip {
IpAddr::V4(ip) => {
let last_octet = ip.octets()[3].to_string();
let rev = dns.add_ptr_record(&last_octet, &fwd.name).await?;
assignment.dns_reverse = Some(fwd.name.clone());
assignment.dns_reverse_ref = rev.id;
}
IpAddr::V6(_) => {
warn!("IPv6 forward DNS not supported yet")
}
}
}
// save to db // save to db
self.db.insert_vm_ip_assignment(&assignment).await?; self.db.insert_vm_ip_assignment(assignment).await?;
Ok(()) Ok(())
} }
@ -416,7 +440,7 @@ mod tests {
assert!(ip.dns_forward.is_some()); assert!(ip.dns_forward.is_some());
assert!(ip.dns_reverse.is_some()); assert!(ip.dns_reverse.is_some());
assert!(ip.dns_reverse_ref.is_some()); assert!(ip.dns_reverse_ref.is_some());
assert!(ip.dns_forward.is_some()); assert!(ip.dns_forward_ref.is_some());
assert_eq!(ip.dns_reverse, ip.dns_forward); assert_eq!(ip.dns_reverse, ip.dns_forward);
// assert IP address is not CIDR // assert IP address is not CIDR