refactor: simplify dns server trait
feat: update reverse dns
This commit is contained in:
@ -191,6 +191,8 @@ pub struct ApiVmHostRegion {
|
||||
pub struct VMPatchRequest {
|
||||
/// SSH key assigned to vm
|
||||
pub ssh_key_id: Option<u64>,
|
||||
/// Reverse DNS PTR domain
|
||||
pub reverse_dns: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, JsonSchema)]
|
||||
|
@ -205,6 +205,7 @@ async fn v1_get_vm(
|
||||
async fn v1_patch_vm(
|
||||
auth: Nip98Auth,
|
||||
db: &State<Arc<dyn LNVpsDb>>,
|
||||
provisioner: &State<Arc<LNVpsProvisioner>>,
|
||||
settings: &State<Settings>,
|
||||
id: u64,
|
||||
data: Json<VMPatchRequest>,
|
||||
@ -216,20 +217,31 @@ async fn v1_patch_vm(
|
||||
return ApiData::err("VM doesnt belong to you");
|
||||
}
|
||||
|
||||
let mut vm_config = false;
|
||||
if let Some(k) = data.ssh_key_id {
|
||||
let ssh_key = db.get_user_ssh_key(k).await?;
|
||||
if ssh_key.user_id != uid {
|
||||
return ApiData::err("SSH key doesnt belong to you");
|
||||
}
|
||||
vm.ssh_key_id = ssh_key.id;
|
||||
vm_config = true;
|
||||
}
|
||||
|
||||
db.update_vm(&vm).await?;
|
||||
if let Some(ptr) = &data.reverse_dns {
|
||||
let mut ips = db.list_vm_ip_assignments(vm.id).await?;
|
||||
for mut ip in ips.iter_mut() {
|
||||
ip.dns_reverse = Some(ptr.to_string());
|
||||
provisioner.update_reverse_ip_dns(&mut ip).await?;
|
||||
}
|
||||
}
|
||||
|
||||
if vm_config {
|
||||
db.update_vm(&vm).await?;
|
||||
let info = FullVmInfo::load(vm.id, (*db).clone()).await?;
|
||||
let host = db.get_host(vm.host_id).await?;
|
||||
let client = get_host_client(&host, &settings.provisioner)?;
|
||||
client.configure_vm(&info).await?;
|
||||
}
|
||||
|
||||
ApiData::ok(())
|
||||
}
|
||||
|
@ -19,7 +19,6 @@ use std::net::{IpAddr, SocketAddr};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::fs::create_dir_all;
|
||||
use tokio::time::sleep;
|
||||
|
||||
#[derive(Parser)]
|
||||
|
@ -1,8 +1,9 @@
|
||||
use crate::dns::{BasicRecord, DnsServer, RecordType};
|
||||
use crate::json_api::JsonApi;
|
||||
use anyhow::Context;
|
||||
use lnvps_db::async_trait;
|
||||
use log::info;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::net::IpAddr;
|
||||
|
||||
pub struct Cloudflare {
|
||||
api: JsonApi,
|
||||
@ -23,67 +24,101 @@ impl Cloudflare {
|
||||
|
||||
#[async_trait]
|
||||
impl DnsServer for Cloudflare {
|
||||
async fn add_ptr_record(&self, key: &str, value: &str) -> anyhow::Result<BasicRecord> {
|
||||
async fn add_record(&self, record: &BasicRecord) -> anyhow::Result<BasicRecord> {
|
||||
let zone_id = match &record.kind {
|
||||
RecordType::PTR => &self.reverse_zone_id,
|
||||
_ => &self.forward_zone_id,
|
||||
};
|
||||
info!(
|
||||
"Adding record: [{}] {} => {}",
|
||||
record.kind, record.name, record.value
|
||||
);
|
||||
let id_response: CfResult<CfRecord> = self
|
||||
.api
|
||||
.post(
|
||||
&format!("/client/v4/zones/{}/dns_records", self.reverse_zone_id),
|
||||
&format!("/client/v4/zones/{zone_id}/dns_records"),
|
||||
CfRecord {
|
||||
content: value.to_string(),
|
||||
name: key.to_string(),
|
||||
r_type: "PTR".to_string(),
|
||||
content: record.value.to_string(),
|
||||
name: record.name.to_string(),
|
||||
r_type: Some(record.kind.to_string()),
|
||||
id: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(BasicRecord {
|
||||
name: id_response.result.name,
|
||||
value: value.to_string(),
|
||||
value: id_response.result.content,
|
||||
id: id_response.result.id,
|
||||
kind: RecordType::PTR,
|
||||
kind: record.kind.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn delete_ptr_record(&self, key: &str) -> anyhow::Result<()> {
|
||||
todo!()
|
||||
async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> {
|
||||
let zone_id = match &record.kind {
|
||||
RecordType::PTR => &self.reverse_zone_id,
|
||||
_ => &self.forward_zone_id,
|
||||
};
|
||||
let record_id = record.id.as_ref().context("record id missing")?;
|
||||
info!(
|
||||
"Deleting record: [{}] {} => {}",
|
||||
record.kind, record.name, record.value
|
||||
);
|
||||
self.api
|
||||
.req(
|
||||
reqwest::Method::DELETE,
|
||||
&format!("/client/v4/zones/{}/dns_records/{}", zone_id, record_id),
|
||||
CfRecord {
|
||||
content: record.value.to_string(),
|
||||
name: record.name.to_string(),
|
||||
r_type: None,
|
||||
id: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn add_a_record(&self, name: &str, ip: IpAddr) -> anyhow::Result<BasicRecord> {
|
||||
async fn update_record(&self, record: &BasicRecord) -> anyhow::Result<BasicRecord> {
|
||||
let zone_id = match &record.kind {
|
||||
RecordType::PTR => &self.reverse_zone_id,
|
||||
_ => &self.forward_zone_id,
|
||||
};
|
||||
info!(
|
||||
"Updating record: [{}] {} => {}",
|
||||
record.kind, record.name, record.value
|
||||
);
|
||||
let record_id = record.id.as_ref().context("record id missing")?;
|
||||
let id_response: CfResult<CfRecord> = self
|
||||
.api
|
||||
.post(
|
||||
&format!("/client/v4/zones/{}/dns_records", self.forward_zone_id),
|
||||
.req(
|
||||
reqwest::Method::PATCH,
|
||||
&format!("/client/v4/zones/{}/dns_records/{}", zone_id, record_id),
|
||||
CfRecord {
|
||||
content: ip.to_string(),
|
||||
name: name.to_string(),
|
||||
r_type: if ip.is_ipv4() {
|
||||
"A".to_string()
|
||||
} else {
|
||||
"AAAA".to_string()
|
||||
},
|
||||
id: None,
|
||||
content: record.value.to_string(),
|
||||
name: record.name.to_string(),
|
||||
r_type: Some(record.kind.to_string()),
|
||||
id: Some(record_id.to_string()),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
Ok(BasicRecord {
|
||||
name: id_response.result.name,
|
||||
value: ip.to_string(),
|
||||
value: id_response.result.content,
|
||||
id: id_response.result.id,
|
||||
kind: RecordType::A,
|
||||
kind: record.kind.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn delete_a_record(&self, name: &str) -> anyhow::Result<()> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct CfRecord {
|
||||
pub content: String,
|
||||
pub name: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "type")]
|
||||
pub r_type: String,
|
||||
pub r_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
}
|
||||
|
||||
|
131
src/dns/mod.rs
131
src/dns/mod.rs
@ -1,7 +1,9 @@
|
||||
use anyhow::Result;
|
||||
use lnvps_db::async_trait;
|
||||
use anyhow::{bail, Context, Result};
|
||||
use lnvps_db::{async_trait, VmIpAssignment};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::{Display, Formatter};
|
||||
use std::net::IpAddr;
|
||||
use std::str::FromStr;
|
||||
|
||||
#[cfg(feature = "cloudflare")]
|
||||
mod cloudflare;
|
||||
@ -11,16 +13,13 @@ pub use cloudflare::*;
|
||||
#[async_trait]
|
||||
pub trait DnsServer: Send + Sync {
|
||||
/// Add PTR record to the reverse zone
|
||||
async fn add_ptr_record(&self, key: &str, value: &str) -> Result<BasicRecord>;
|
||||
async fn add_record(&self, record: &BasicRecord) -> Result<BasicRecord>;
|
||||
|
||||
/// Delete PTR record from the reverse zone
|
||||
async fn delete_ptr_record(&self, key: &str) -> Result<()>;
|
||||
async fn delete_record(&self, record: &BasicRecord) -> Result<()>;
|
||||
|
||||
/// Add A/AAAA record onto the forward zone
|
||||
async fn add_a_record(&self, name: &str, ip: IpAddr) -> Result<BasicRecord>;
|
||||
|
||||
/// Delete A/AAAA record from the forward zone
|
||||
async fn delete_a_record(&self, name: &str) -> Result<()>;
|
||||
/// Update a record
|
||||
async fn update_record(&self, record: &BasicRecord) -> Result<BasicRecord>;
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
@ -37,3 +36,117 @@ pub struct BasicRecord {
|
||||
pub id: Option<String>,
|
||||
pub kind: RecordType,
|
||||
}
|
||||
|
||||
impl Display for RecordType {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
RecordType::A => write!(f, "A"),
|
||||
RecordType::AAAA => write!(f, "AAAA"),
|
||||
RecordType::PTR => write!(f, "PTR"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BasicRecord {
|
||||
pub fn forward(ip: &VmIpAssignment) -> Result<Self> {
|
||||
let addr = IpAddr::from_str(&ip.ip)?;
|
||||
Ok(Self {
|
||||
name: format!("vm-{}", &ip.vm_id),
|
||||
value: addr.to_string(),
|
||||
id: ip.dns_forward_ref.clone(),
|
||||
kind: match addr {
|
||||
IpAddr::V4(_) => RecordType::A,
|
||||
IpAddr::V6(_) => RecordType::AAAA,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reverse_to_fwd(ip: &VmIpAssignment) -> Result<Self> {
|
||||
let addr = IpAddr::from_str(&ip.ip)?;
|
||||
|
||||
// Use explicit reverse entry or use the forward entry
|
||||
let fwd = ip
|
||||
.dns_reverse
|
||||
.as_ref()
|
||||
.or(ip.dns_forward.as_ref())
|
||||
.context("Reverse/Forward DNS name required for reverse entry")?
|
||||
.to_string();
|
||||
|
||||
if !is_valid_fqdn(fwd.as_str()) {
|
||||
bail!("Forward DNS name is not a valid FQDN");
|
||||
}
|
||||
Ok(Self {
|
||||
name: match addr {
|
||||
IpAddr::V4(i) => i.octets()[3].to_string(),
|
||||
IpAddr::V6(_) => bail!("IPv6 PTR not supported"),
|
||||
},
|
||||
value: fwd,
|
||||
id: ip.dns_reverse_ref.clone(),
|
||||
kind: RecordType::PTR,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn reverse(ip: &VmIpAssignment) -> Result<Self> {
|
||||
let addr = IpAddr::from_str(&ip.ip)?;
|
||||
let rev = ip
|
||||
.dns_reverse
|
||||
.as_ref()
|
||||
.context("Reverse DNS name required for reverse entry")?
|
||||
.to_string();
|
||||
if !is_valid_fqdn(&rev) {
|
||||
bail!("Reverse DNS name is not a valid FQDN");
|
||||
}
|
||||
Ok(Self {
|
||||
name: match addr {
|
||||
IpAddr::V4(i) => i.octets()[3].to_string(),
|
||||
IpAddr::V6(_) => bail!("IPv6 PTR not supported"),
|
||||
},
|
||||
value: rev,
|
||||
id: ip.dns_reverse_ref.clone(),
|
||||
kind: RecordType::PTR,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Grok 3
|
||||
pub fn is_valid_fqdn(s: &str) -> bool {
|
||||
// Remove trailing dot if present (optional in practice)
|
||||
let s = s.strip_suffix('.').unwrap_or(s);
|
||||
|
||||
// Check total length (max 255 chars, including dots)
|
||||
if s.len() > 255 || s.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Split into labels and validate each
|
||||
let labels: Vec<&str> = s.split('.').collect();
|
||||
|
||||
// Must have at least two labels (e.g., "example.com")
|
||||
if labels.len() < 2 {
|
||||
return false;
|
||||
}
|
||||
|
||||
for label in labels {
|
||||
// Each label must be 1-63 chars
|
||||
if label.len() > 63 || label.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must start with a letter or digit
|
||||
if !label.chars().next().unwrap().is_alphanumeric() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Must end with a letter or digit
|
||||
if !label.chars().last().unwrap().is_alphanumeric() {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Only letters, digits, and hyphens allowed
|
||||
if !label.chars().all(|c| c.is_alphanumeric() || c == '-') {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
64
src/mocks.rs
64
src/mocks.rs
@ -666,7 +666,7 @@ impl VmHostClient for MockVmHost {
|
||||
}
|
||||
}
|
||||
|
||||
async fn configure_vm(&self, vm: &Vm) -> anyhow::Result<()> {
|
||||
async fn configure_vm(&self, vm: &FullVmInfo) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
@ -696,62 +696,42 @@ impl MockDnsServer {
|
||||
}
|
||||
#[async_trait]
|
||||
impl DnsServer for MockDnsServer {
|
||||
async fn add_ptr_record(&self, key: &str, value: &str) -> anyhow::Result<BasicRecord> {
|
||||
let mut rev = self.reverse.lock().await;
|
||||
async fn add_record(&self, record: &BasicRecord) -> anyhow::Result<BasicRecord> {
|
||||
let mut table = match record.kind {
|
||||
RecordType::PTR => self.reverse.lock().await,
|
||||
_ => self.forward.lock().await,
|
||||
};
|
||||
|
||||
if rev.values().any(|v| v.name == key) {
|
||||
bail!("Duplicate record with name {}", key);
|
||||
if table.values().any(|v| v.name == record.name) {
|
||||
bail!("Duplicate record with name {}", record.name);
|
||||
}
|
||||
|
||||
let rnd_id: [u8; 12] = rand::random();
|
||||
let id = hex::encode(rnd_id);
|
||||
rev.insert(
|
||||
table.insert(
|
||||
id.clone(),
|
||||
MockDnsEntry {
|
||||
name: key.to_string(),
|
||||
value: value.to_string(),
|
||||
kind: "PTR".to_string(),
|
||||
name: record.name.to_string(),
|
||||
value: record.value.to_string(),
|
||||
kind: record.kind.to_string(),
|
||||
},
|
||||
);
|
||||
Ok(BasicRecord {
|
||||
name: format!("{}.X.Y.Z.in-addr.arpa", key),
|
||||
value: value.to_string(),
|
||||
id: Some(id),
|
||||
kind: RecordType::PTR,
|
||||
})
|
||||
}
|
||||
|
||||
async fn delete_ptr_record(&self, key: &str) -> anyhow::Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn add_a_record(&self, name: &str, ip: IpAddr) -> anyhow::Result<BasicRecord> {
|
||||
let mut rev = self.forward.lock().await;
|
||||
|
||||
if rev.values().any(|v| v.name == name) {
|
||||
bail!("Duplicate record with name {}", name);
|
||||
}
|
||||
|
||||
let fqdn = format!("{}.lnvps.mock", name);
|
||||
let rnd_id: [u8; 12] = rand::random();
|
||||
let id = hex::encode(rnd_id);
|
||||
rev.insert(
|
||||
id.clone(),
|
||||
MockDnsEntry {
|
||||
name: fqdn.clone(),
|
||||
value: ip.to_string(),
|
||||
kind: "A".to_string(),
|
||||
name: match record.kind {
|
||||
RecordType::PTR => format!("{}.X.Y.Z.addr.in-arpa", record.name),
|
||||
_ => format!("{}.lnvps.mock", record.name),
|
||||
},
|
||||
);
|
||||
Ok(BasicRecord {
|
||||
name: fqdn,
|
||||
value: ip.to_string(),
|
||||
value: record.value.clone(),
|
||||
id: Some(id),
|
||||
kind: RecordType::A,
|
||||
kind: record.kind.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn delete_a_record(&self, name: &str) -> anyhow::Result<()> {
|
||||
async fn delete_record(&self, record: &BasicRecord) -> anyhow::Result<()> {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn update_record(&self, name: &BasicRecord) -> anyhow::Result<BasicRecord> {
|
||||
todo!()
|
||||
}
|
||||
}
|
||||
|
@ -1,11 +1,11 @@
|
||||
use crate::dns::DnsServer;
|
||||
use crate::dns::{BasicRecord, DnsServer};
|
||||
use crate::exchange::{ExchangeRateService, Ticker};
|
||||
use crate::host::{get_host_client, FullVmInfo};
|
||||
use crate::lightning::{AddInvoiceRequest, LightningNode};
|
||||
use crate::provisioner::{NetworkProvisioner, ProvisionerMethod};
|
||||
use crate::router::Router;
|
||||
use crate::settings::{NetworkAccessPolicy, NetworkPolicy, ProvisionerConfig, Settings};
|
||||
use anyhow::{bail, Result};
|
||||
use anyhow::{bail, Context, Result};
|
||||
use chrono::{Days, Months, Utc};
|
||||
use futures::future::join_all;
|
||||
use lnvps_db::{IpRange, LNVpsDb, Vm, VmCostPlanIntervalType, VmIpAssignment, VmPayment};
|
||||
@ -55,7 +55,8 @@ impl LNVpsProvisioner {
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_ip_assignment(&self, vm: &Vm) -> Result<()> {
|
||||
pub async fn delete_ip_assignment(&self, vm: &Vm) -> Result<()> {
|
||||
// Delete access policy
|
||||
if let NetworkAccessPolicy::StaticArp { .. } = &self.network_policy.access {
|
||||
if let Some(r) = self.router.as_ref() {
|
||||
let ent = r.list_arp_entry().await?;
|
||||
@ -72,6 +73,45 @@ impl LNVpsProvisioner {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_forward_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> {
|
||||
if let Some(dns) = &self.dns {
|
||||
let fwd = BasicRecord::forward(assignment)?;
|
||||
let ret_fwd = if fwd.id.is_some() {
|
||||
dns.update_record(&fwd).await?
|
||||
} else {
|
||||
dns.add_record(&fwd).await?
|
||||
};
|
||||
assignment.dns_forward = Some(ret_fwd.name);
|
||||
assignment.dns_forward_ref = Some(ret_fwd.id.context("Record id is missing")?);
|
||||
|
||||
// save to db
|
||||
if assignment.id != 0 {
|
||||
self.db.update_vm_ip_assignment(assignment).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_reverse_ip_dns(&self, assignment: &mut VmIpAssignment) -> Result<()> {
|
||||
if let Some(dns) = &self.dns {
|
||||
let ret_rev = if assignment.dns_reverse_ref.is_some() {
|
||||
dns.update_record(&BasicRecord::reverse(assignment)?)
|
||||
.await?
|
||||
} else {
|
||||
dns.add_record(&BasicRecord::reverse_to_fwd(assignment)?)
|
||||
.await?
|
||||
};
|
||||
assignment.dns_reverse = Some(ret_rev.value);
|
||||
assignment.dns_reverse_ref = Some(ret_rev.id.context("Record id is missing")?);
|
||||
|
||||
// save to db
|
||||
if assignment.id != 0 {
|
||||
self.db.update_vm_ip_assignment(assignment).await?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn save_ip_assignment(&self, vm: &Vm, assignment: &mut VmIpAssignment) -> Result<()> {
|
||||
let ip = IpAddr::from_str(&assignment.ip)?;
|
||||
|
||||
@ -91,27 +131,11 @@ impl LNVpsProvisioner {
|
||||
}
|
||||
|
||||
// Add DNS records
|
||||
if let Some(dns) = &self.dns {
|
||||
let sub_name = format!("vm-{}", vm.id);
|
||||
let fwd = dns.add_a_record(&sub_name, ip.clone()).await?;
|
||||
assignment.dns_forward = Some(fwd.name.clone());
|
||||
assignment.dns_forward_ref = fwd.id;
|
||||
|
||||
match ip {
|
||||
IpAddr::V4(ip) => {
|
||||
let last_octet = ip.octets()[3].to_string();
|
||||
let rev = dns.add_ptr_record(&last_octet, &fwd.name).await?;
|
||||
assignment.dns_reverse = Some(fwd.name.clone());
|
||||
assignment.dns_reverse_ref = rev.id;
|
||||
}
|
||||
IpAddr::V6(_) => {
|
||||
warn!("IPv6 forward DNS not supported yet")
|
||||
}
|
||||
}
|
||||
}
|
||||
self.update_forward_ip_dns(assignment).await?;
|
||||
self.update_reverse_ip_dns(assignment).await?;
|
||||
|
||||
// save to db
|
||||
self.db.insert_vm_ip_assignment(&assignment).await?;
|
||||
self.db.insert_vm_ip_assignment(assignment).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -416,7 +440,7 @@ mod tests {
|
||||
assert!(ip.dns_forward.is_some());
|
||||
assert!(ip.dns_reverse.is_some());
|
||||
assert!(ip.dns_reverse_ref.is_some());
|
||||
assert!(ip.dns_forward.is_some());
|
||||
assert!(ip.dns_forward_ref.is_some());
|
||||
assert_eq!(ip.dns_reverse, ip.dns_forward);
|
||||
|
||||
// assert IP address is not CIDR
|
||||
|
Reference in New Issue
Block a user