diff --git a/src/config.rs b/src/config.rs index 01033d438e..4435281a2c 100644 --- a/src/config.rs +++ b/src/config.rs @@ -69,6 +69,7 @@ const CONNECTION_TERMINATION_DEADLINE: &str = "CONNECTION_TERMINATION_DEADLINE"; // (Our forceful shutdown is more graceful than a SIGKILL, as we can close connections cleanly). const TERMINATION_GRACE_PERIOD_SECONDS: &str = "TERMINATION_GRACE_PERIOD_SECONDS"; const ENABLE_ORIG_SRC: &str = "ENABLE_ORIG_SRC"; +const ENABLE_OUTBOUND_ORIG_SRC: &str = "ENABLE_OUTBOUND_ORIG_SRC"; const PROXY_CONFIG: &str = "PROXY_CONFIG"; const IPV6_ENABLED: &str = "IPV6_ENABLED"; @@ -280,6 +281,11 @@ pub struct Config { // If unset (recommended), this is automatically detected based on permissions. pub require_original_source: Option, + // Enable source IP preservation for outbound connections. + // If set to true, outbound connections explicitly bind to the downstream peer address + // If false (default), the system determines the outbound address. + pub enable_outbound_original_source: bool, + // CLI args passed to ztunnel at runtime pub proxy_args: String, @@ -810,6 +816,7 @@ pub fn construct_config(pc: ProxyConfig) -> Result { )?, require_original_source: parse(ENABLE_ORIG_SRC)?, + enable_outbound_original_source: parse_default(ENABLE_OUTBOUND_ORIG_SRC, false)?, proxy_args: parse_args(), dns_resolver_cfg, dns_resolver_opts, @@ -1224,4 +1231,24 @@ pub mod tests { env::remove_var(ZTUNNEL_CPU_LIMIT); } } + + #[test] + fn test_enable_outbound_original_source_parsing() { + unsafe { + // Test explicitly enabled + env::set_var(ENABLE_OUTBOUND_ORIG_SRC, "true"); + let cfg = construct_config(ProxyConfig::default()).unwrap(); + assert!(cfg.enable_outbound_original_source); + + // Test explicitly disabled + env::set_var(ENABLE_OUTBOUND_ORIG_SRC, "false"); + let cfg = construct_config(ProxyConfig::default()).unwrap(); + assert!(!cfg.enable_outbound_original_source); + + // Test unset (default is false) + env::remove_var(ENABLE_OUTBOUND_ORIG_SRC); + let cfg = construct_config(ProxyConfig::default()).unwrap(); + assert!(!cfg.enable_outbound_original_source); + } + } } diff --git a/src/proxy/outbound.rs b/src/proxy/outbound.rs index 644b6c132b..db91f68d45 100644 --- a/src/proxy/outbound.rs +++ b/src/proxy/outbound.rs @@ -227,7 +227,7 @@ impl OutboundConnection { .await } OutboundProtocol::TCP => { - self.proxy_to_tcp(source_stream, &req, connection_result_builder) + self.proxy_to_tcp(source_stream, source_addr, &req, connection_result_builder) .await } }; @@ -397,14 +397,19 @@ impl OutboundConnection { async fn proxy_to_tcp( &mut self, stream: TcpStream, + source_addr: SocketAddr, req: &Request, connection_stats_builder: Box, ) { let connection_stats = Box::new(connection_stats_builder.build()); - + let local = if self.pi.cfg.enable_outbound_original_source { + Some(source_addr.ip()) + } else { + None + }; let res = (async { let outbound = super::freebind_connect( - None, // No need to spoof source IP on outbound + local, req.actual_destination, self.pi.socket_factory.as_ref(), ) diff --git a/src/proxy/pool.rs b/src/proxy/pool.rs index e69d92797d..81d04489dc 100644 --- a/src/proxy/pool.rs +++ b/src/proxy/pool.rs @@ -83,7 +83,12 @@ impl ConnSpawner { let cert = self.local_workload.fetch_certificate().await?; let connector = cert.outbound_connector(key.dst_id.clone())?; - let tcp_stream = super::freebind_connect(None, key.dst, self.socket_factory.as_ref()) + let local = if self.cfg.enable_outbound_original_source { + Some(key.src) + } else { + None + }; + let tcp_stream = super::freebind_connect(local, key.dst, self.socket_factory.as_ref()) .await .map_err(|e: io::Error| match e.kind() { io::ErrorKind::TimedOut => Error::MaybeHBONENetworkPolicyError(e),