Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions devolutions-agent/src/endpoint.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
//! Helpers for parsing `host:port` endpoint strings used by the agent tunnel.
//!
//! The agent persists gateway endpoints as `format_endpoint(host, port)` (see
//! [`crate::enrollment::format_endpoint`]) — DNS / IPv4 stay as-is, IPv6
//! literals are wrapped in brackets: `[fd00::7]:4433`.
//!
//! When that string is later split back into `(host, port)` we MUST drop the
//! brackets from the IPv6 host before passing it to Rustls / Quinn: Rustls'
//! [`rustls_pki_types::ServerName`] does not accept a bracketed IPv6 literal,
//! and a naive `rsplit_once(':')` would leave `[fd00::7]` as the "host" half.
//!
//! Both `tunnel.rs` (runtime) and `verify_tunnel` (one-shot probe) need this
//! same split, hence the shared module.

use anyhow::{Context as _, Result, bail};

/// The host part of a parsed endpoint, ready to be used as a TLS server name
/// and/or DNS-resolved.
///
/// IPv6 literals are returned **without** surrounding brackets — that's the
/// form Rustls expects.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EndpointHost(String);

impl EndpointHost {
/// View the host as a plain string (no brackets for IPv6 literals).
pub fn as_str(&self) -> &str {
&self.0
}
}

impl std::fmt::Display for EndpointHost {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(&self.0)
}
}

/// Split a `host:port` endpoint string into its host and port components.
///
/// Accepts:
/// - `gateway.example.com:4433` (DNS)
/// - `10.10.0.7:4433` (IPv4)
/// - `[fd00::7]:4433` (IPv6 literal, bracketed)
///
/// The returned host is always unbracketed — safe to pass to
/// [`rustls_pki_types::ServerName::try_from`] and to DNS resolvers. The full
/// original string (with brackets, if any) is still appropriate for
/// `tokio::net::lookup_host` because both bracketed and unbracketed IPv6
/// `host:port` forms are accepted there; callers that already have the raw
/// endpoint can keep using it directly for lookup.
pub fn split_endpoint(endpoint: &str) -> Result<(EndpointHost, u16)> {
let trimmed = endpoint.trim();
if trimmed.is_empty() {
bail!("endpoint is empty");
}

// IPv6 bracketed form first: "[<host>]:<port>".
if let Some(after_open) = trimmed.strip_prefix('[') {
let (host_part, rest) = after_open
.split_once(']')
.with_context(|| format!("missing ']' in bracketed endpoint: {endpoint}"))?;
let port_str = rest
.strip_prefix(':')
.with_context(|| format!("missing ':' after ']' in bracketed endpoint: {endpoint}"))?;
let port: u16 = port_str
.parse()
.with_context(|| format!("invalid port in endpoint: {endpoint}"))?;
if host_part.is_empty() {
bail!("empty host inside brackets: {endpoint}");
}
return Ok((EndpointHost(host_part.to_owned()), port));
}

// Unbracketed: DNS or IPv4. Split on the last ':' — DNS / IPv4 have no
// other colons in the host part.
let (host, port_str) = trimmed
.rsplit_once(':')
.with_context(|| format!("endpoint missing port: {endpoint}"))?;
if host.is_empty() {
bail!("empty host in endpoint: {endpoint}");
}
let port: u16 = port_str
.parse()
.with_context(|| format!("invalid port in endpoint: {endpoint}"))?;
Ok((EndpointHost(host.to_owned()), port))
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn split_dns_endpoint() {
let (host, port) = split_endpoint("gateway.example.com:4433").expect("dns");
assert_eq!(host.as_str(), "gateway.example.com");
assert_eq!(port, 4433);
}

#[test]
fn split_ipv4_endpoint() {
let (host, port) = split_endpoint("10.10.0.7:4433").expect("ipv4");
assert_eq!(host.as_str(), "10.10.0.7");
assert_eq!(port, 4433);
}

#[test]
fn split_ipv6_bracketed_endpoint_unbrackets_host() {
let (host, port) = split_endpoint("[fd00::7]:4433").expect("ipv6 bracketed");
// Critical: the host must NOT include the surrounding brackets so it
// can be passed straight to `rustls_pki_types::ServerName::try_from`.
assert_eq!(host.as_str(), "fd00::7");
assert_eq!(port, 4433);
}

#[test]
fn split_ipv6_bracketed_host_parses_as_rustls_server_name() {
let (host, _port) = split_endpoint("[fd00::7]:4433").expect("ipv6 bracketed");
let server_name = rustls_pki_types::ServerName::try_from(host.as_str().to_owned());
assert!(
server_name.is_ok(),
"unbracketed IPv6 literal must be a valid rustls ServerName, got: {:?}",
server_name.err()
);
}

#[test]
fn split_dns_host_parses_as_rustls_server_name() {
let (host, _port) = split_endpoint("gateway.example.com:4433").expect("dns");
let server_name = rustls_pki_types::ServerName::try_from(host.as_str().to_owned());
assert!(server_name.is_ok());
}

#[test]
fn split_rejects_missing_port() {
let err = split_endpoint("gateway.example.com").expect_err("must reject");
let msg = format!("{err:#}");
assert!(msg.contains("missing port"), "got: {msg}");
}

#[test]
fn split_rejects_empty_host_brackets() {
let err = split_endpoint("[]:4433").expect_err("must reject empty brackets");
let msg = format!("{err:#}");
assert!(msg.contains("empty host"), "got: {msg}");
}

#[test]
fn split_rejects_unparseable_port() {
let err = split_endpoint("gateway.example.com:notaport").expect_err("must reject");
let msg = format!("{err:#}");
assert!(msg.contains("invalid port"), "got: {msg}");
}
}
187 changes: 183 additions & 4 deletions devolutions-agent/src/enrollment.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,26 @@ struct EnrollRequest {
}

/// Response from enrollment API
///
/// The compat bridge (per the identity refactor design) means both
/// `quic_endpoint` and `quic_port` may be present:
///
/// - `quic_port` is the canonical new field. Agents should pair it with the
/// host they already enrolled through (parsed from the JWT's `jet_gw_url`).
/// - `quic_endpoint` is kept for one release so older gateways still work.
///
/// Both fields are `#[serde(default)]` so the deserializer accepts either or
/// both. After enroll, the agent picks `quic_port` when available, otherwise
/// it parses the port off `quic_endpoint`.
#[derive(Deserialize)]
struct EnrollResponse {
agent_id: Uuid,
client_cert_pem: String,
gateway_ca_cert_pem: String,
quic_endpoint: String,
#[serde(default)]
quic_endpoint: Option<String>,
#[serde(default)]
quic_port: Option<u16>,
server_spki_sha256: String,
}

Expand Down Expand Up @@ -107,7 +121,23 @@ pub async fn enroll_agent(
let (key_pem, csr_pem) = generate_key_and_csr(agent_name)?;

let enroll_response = request_enrollment(gateway_url, enrollment_token, agent_name, &csr_pem).await?;
persist_enrollment_response(agent_name, advertise_subnets, enroll_response, &key_pem)

// The agent dials the QUIC tunnel at whichever host the operator already
// proved is reachable from this agent's network — that's `gateway_url`'s
// host. The Gateway tells the agent which *port* to dial (via `quic_port`),
// not which host. For older Gateways the host is parsed off the legacy
// `quic_endpoint` field.
let enrollment_host = url::Url::parse(gateway_url)
.ok()
.and_then(|u| u.host_str().map(str::to_owned));

persist_enrollment_response(
agent_name,
advertise_subnets,
enroll_response,
enrollment_host.as_deref(),
&key_pem,
)
}

/// Generate an ECDSA P-256 key pair and a CSR containing the agent name as CN.
Expand Down Expand Up @@ -169,10 +199,34 @@ fn persist_enrollment_response(
client_cert_pem,
gateway_ca_cert_pem,
quic_endpoint,
quic_port,
server_spki_sha256,
}: EnrollResponse,
enrollment_host: Option<&str>,
key_pem: &str,
) -> Result<PersistedEnrollment> {
// Pick the QUIC port: prefer the new `quic_port` field, otherwise parse
// the port off the legacy `quic_endpoint` (compat with older gateways).
let quic_port_resolved = if let Some(port) = quic_port {
port
} else {
let endpoint = quic_endpoint
.as_deref()
.context("enrollment response carries neither `quic_port` nor `quic_endpoint`")?;
parse_endpoint_port(endpoint).with_context(|| format!("parse legacy quic_endpoint {endpoint:?}"))?
};

// Compose the gateway endpoint from `(enrollment_host, quic_port)` when we
// know the enrollment host (new agents talking to new gateways and to old
// gateways alike). If the caller did not pass it — only possible when
// running against the unit tests or a malformed URL — fall back to the
// legacy `quic_endpoint` verbatim.
let resolved_endpoint = match enrollment_host {
Some(host) => format_endpoint(host, quic_port_resolved),
None => quic_endpoint
.clone()
.context("enrollment URL has no host and response did not include a usable quic_endpoint")?,
};
let config_path = config::get_conf_file_path();
let config_dir = config_path
.parent()
Expand Down Expand Up @@ -219,7 +273,7 @@ fn persist_enrollment_response(

let tunnel_conf = config::dto::TunnelConf {
enabled: true,
gateway_endpoint: quic_endpoint.clone(),
gateway_endpoint: resolved_endpoint.clone(),
client_cert_path: Some(client_cert_path.clone()),
client_key_path: Some(client_key_path.clone()),
gateway_ca_cert_path: Some(gateway_ca_path.clone()),
Expand All @@ -241,10 +295,55 @@ fn persist_enrollment_response(
client_cert_path,
client_key_path,
gateway_ca_path,
quic_endpoint,
quic_endpoint: resolved_endpoint,
})
}

/// Format a `host:port` endpoint string, bracketing IPv6 literals so the
/// resulting string is parseable as a `SocketAddr` and unambiguous to humans.
///
/// | host kind | output |
/// |---|---|
/// | DNS | `gateway.example.com:4433` |
/// | IPv4 | `10.10.0.7:4433` |
/// | IPv6 | `[fd00::7]:4433` |
///
/// The IPv6 case strips any pre-existing surrounding brackets first, so both
/// `fd00::7` and `[fd00::7]` produce the same canonical bracketed form.
pub fn format_endpoint(host: &str, port: u16) -> String {
let trimmed = host.trim();
// url::Url surfaces IPv6 hosts already bracketed; strip them here so we
// can detect "it's an IPv6 literal" by trying to parse as Ipv6Addr.
let unbracketed = trimmed
.strip_prefix('[')
.and_then(|s| s.strip_suffix(']'))
.unwrap_or(trimmed);
if unbracketed.parse::<std::net::Ipv6Addr>().is_ok() {
format!("[{unbracketed}]:{port}")
} else {
format!("{trimmed}:{port}")
}
}

/// Parse the port off a legacy `quic_endpoint` string of the form
/// `<host>:<port>` (DNS / IPv4) or `[<ipv6>]:<port>`.
fn parse_endpoint_port(endpoint: &str) -> Result<u16> {
let trimmed = endpoint.trim();
let port_str = if let Some(rest) = trimmed.rsplit_once(']') {
// IPv6: "[host]:port" — `rest.0` is "[host", `rest.1` is ":port".
rest.1
.strip_prefix(':')
.context("missing ':' before port in bracketed endpoint")?
} else {
// DNS / IPv4: "host:port" — split on the last ':' since DNS / IPv4 have no colons in the host.
trimmed
.rsplit_once(':')
.map(|(_, p)| p)
.context("missing ':' between host and port in endpoint")?
};
port_str.parse::<u16>().context("endpoint port is not a valid u16")
}

// ---------------------------------------------------------------------------
// Certificate renewal helpers
// ---------------------------------------------------------------------------
Expand Down Expand Up @@ -349,4 +448,84 @@ mod tests {
}));
assert!(parse_enrollment_jwt(&jwt).is_err());
}

// ---- format_endpoint -----------------------------------------------------

#[test]
fn format_endpoint_dns() {
assert_eq!(format_endpoint("gateway.example.com", 4433), "gateway.example.com:4433");
}

#[test]
fn format_endpoint_ipv4() {
assert_eq!(format_endpoint("10.10.0.7", 4433), "10.10.0.7:4433");
}

#[test]
fn format_endpoint_ipv6_bracketed() {
assert_eq!(format_endpoint("fd00::7", 4433), "[fd00::7]:4433");
}

#[test]
fn format_endpoint_ipv6_already_bracketed_input() {
// Defensive: if the caller already pre-bracketed (as `url::Url::host_str`
// does for IPv6), the helper still produces the canonical form once.
assert_eq!(format_endpoint("[fd00::7]", 4433), "[fd00::7]:4433");
}

// ---- parse_endpoint_port -------------------------------------------------

#[test]
fn parse_endpoint_port_dns() {
assert_eq!(parse_endpoint_port("gateway.example.com:4433").unwrap(), 4433);
}

#[test]
fn parse_endpoint_port_ipv4() {
assert_eq!(parse_endpoint_port("10.10.0.7:4433").unwrap(), 4433);
}

#[test]
fn parse_endpoint_port_ipv6_bracketed() {
assert_eq!(parse_endpoint_port("[fd00::7]:4433").unwrap(), 4433);
}

#[test]
fn parse_endpoint_port_rejects_no_colon() {
assert!(parse_endpoint_port("gateway.example.com").is_err());
}

// ---- EnrollResponse deserialization --------------------------------------

/// New gateway: both `quic_endpoint` and `quic_port` present. Agent prefers
/// `quic_port`.
#[test]
fn enroll_response_accepts_new_compat_bridge_payload() {
let body = serde_json::json!({
"agent_id": "00000000-0000-0000-0000-000000000001",
"client_cert_pem": "stub",
"gateway_ca_cert_pem": "stub",
"quic_endpoint": "10.10.0.7:4433",
"quic_port": 4433,
"server_spki_sha256": "deadbeef",
});
let parsed: EnrollResponse = serde_json::from_value(body).expect("parse new payload");
assert_eq!(parsed.quic_port, Some(4433));
assert_eq!(parsed.quic_endpoint.as_deref(), Some("10.10.0.7:4433"));
}

/// Legacy gateway: only `quic_endpoint`. Agent must fall back to parsing it.
#[test]
fn enroll_response_accepts_legacy_payload_without_quic_port() {
let body = serde_json::json!({
"agent_id": "00000000-0000-0000-0000-000000000001",
"client_cert_pem": "stub",
"gateway_ca_cert_pem": "stub",
"quic_endpoint": "10.10.0.7:4433",
"server_spki_sha256": "deadbeef",
});
let parsed: EnrollResponse = serde_json::from_value(body).expect("parse legacy payload");
assert_eq!(parsed.quic_port, None);
assert_eq!(parsed.quic_endpoint.as_deref(), Some("10.10.0.7:4433"));
}
}
Loading
Loading