Skip to content

Commit f641800

Browse files
authored
Test that local address actually changes (#12595)
1 parent 52f8f19 commit f641800

File tree

4 files changed

+109
-53
lines changed

4 files changed

+109
-53
lines changed

crates/test-programs/src/bin/p3_sockets_udp_connect.rs

Lines changed: 65 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
1-
use test_programs::p3::wasi::sockets::types::{
2-
ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, UdpSocket,
1+
use test_programs::{
2+
p3::wasi::sockets::types::{
3+
ErrorCode, IpAddress, IpAddressFamily, IpSocketAddress, Ipv4Address, Ipv6Address, UdpSocket,
4+
},
5+
sockets::supports_ipv6,
36
};
47

58
struct Component;
69

710
test_programs::p3::export!(Component);
811

9-
const SOME_PORT: u16 = 47; // If the tests pass, this will never actually be connected to.
12+
// If the tests work as expected, these will never actually be connected to:
13+
const SOME_PORT: u16 = 47;
14+
const SOME_PUBLIC_IPV4: Ipv4Address = (123, 234, 12, 34);
15+
const SOME_PUBLIC_IPV6: Ipv6Address = (123, 234, 0, 0, 0, 0, 0, 34);
1016

1117
fn test_udp_connect_disconnect_reconnect(family: IpAddressFamily) {
1218
let remote1 = IpSocketAddress::new(IpAddress::new_loopback(family), 4321);
@@ -48,6 +54,38 @@ fn test_udp_connect_unspec(family: IpAddressFamily) {
4854
));
4955
}
5056

57+
/// If not explicitly bound, connecting a UDP socket should update the local
58+
/// address to reflect the best network path.
59+
fn test_udp_connect_local_address_change(family: IpAddressFamily) {
60+
fn connect(sock: &UdpSocket, ip: IpAddress, port: u16) -> IpSocketAddress {
61+
let remote = IpSocketAddress::new(ip, port);
62+
sock.connect(remote).unwrap();
63+
let local = sock.get_local_address().unwrap();
64+
println!("connect({remote:?}) changed local address to: {local:?}",);
65+
local
66+
}
67+
68+
if !has_public_interface(family) {
69+
println!("No public interface detected, skipping test");
70+
return;
71+
}
72+
73+
let loopback_ip = IpAddress::new_loopback(family);
74+
let public_ip = some_public_ip(family);
75+
76+
let client = UdpSocket::create(family).unwrap();
77+
78+
let loopback_if1 = connect(&client, loopback_ip, 4321);
79+
let loopback_if2 = connect(&client, loopback_ip, 4322);
80+
let public_if = connect(&client, public_ip, 4323);
81+
82+
// Note: these assertions are based on observed behavior on Linux, MacOS and
83+
// Windows, but there is nothing in their official documentation to
84+
// corroborate this.
85+
assert_eq!(loopback_if1, loopback_if2);
86+
assert_ne!(loopback_if1, public_if);
87+
}
88+
5189
/// 0 is not a valid remote port.
5290
fn test_udp_connect_port_0(family: IpAddressFamily) {
5391
let addr = IpSocketAddress::new(IpAddress::new_loopback(family), 0);
@@ -131,26 +169,38 @@ fn test_udp_connect_dual_stack() {
131169
impl test_programs::p3::exports::wasi::cli::run::Guest for Component {
132170
async fn run() -> Result<(), ()> {
133171
test_udp_connect_disconnect_reconnect(IpAddressFamily::Ipv4);
134-
test_udp_connect_disconnect_reconnect(IpAddressFamily::Ipv6);
135-
136172
test_udp_connect_unspec(IpAddressFamily::Ipv4);
137-
test_udp_connect_unspec(IpAddressFamily::Ipv6);
138-
173+
test_udp_connect_local_address_change(IpAddressFamily::Ipv4);
139174
test_udp_connect_port_0(IpAddressFamily::Ipv4);
140-
test_udp_connect_port_0(IpAddressFamily::Ipv6);
141-
142175
test_udp_connect_wrong_family(IpAddressFamily::Ipv4);
143-
test_udp_connect_wrong_family(IpAddressFamily::Ipv6);
144-
145176
test_udp_connect_without_bind(IpAddressFamily::Ipv4);
146-
test_udp_connect_without_bind(IpAddressFamily::Ipv6);
147-
148177
test_udp_connect_with_bind(IpAddressFamily::Ipv4);
149-
test_udp_connect_with_bind(IpAddressFamily::Ipv6);
150178

151-
test_udp_connect_dual_stack();
179+
if supports_ipv6() {
180+
test_udp_connect_disconnect_reconnect(IpAddressFamily::Ipv6);
181+
test_udp_connect_unspec(IpAddressFamily::Ipv6);
182+
test_udp_connect_local_address_change(IpAddressFamily::Ipv6);
183+
test_udp_connect_port_0(IpAddressFamily::Ipv6);
184+
test_udp_connect_wrong_family(IpAddressFamily::Ipv6);
185+
test_udp_connect_without_bind(IpAddressFamily::Ipv6);
186+
test_udp_connect_with_bind(IpAddressFamily::Ipv6);
187+
test_udp_connect_dual_stack();
188+
}
152189
Ok(())
153190
}
154191
}
155192

193+
fn some_public_ip(family: IpAddressFamily) -> IpAddress {
194+
match family {
195+
IpAddressFamily::Ipv4 => IpAddress::Ipv4(SOME_PUBLIC_IPV4),
196+
IpAddressFamily::Ipv6 => IpAddress::Ipv6(SOME_PUBLIC_IPV6),
197+
}
198+
}
199+
200+
fn has_public_interface(family: IpAddressFamily) -> bool {
201+
let sock = UdpSocket::create(family).unwrap();
202+
sock.connect(IpSocketAddress::new(some_public_ip(family), SOME_PORT))
203+
.is_ok()
204+
}
205+
156206
fn main() {}

crates/wasi/src/p2/host/udp.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -64,24 +64,14 @@ impl udp::HostUdpSocket for WasiSocketsCtxView<'_> {
6464
return Err(ErrorCode::InvalidState.into());
6565
}
6666

67-
// We disconnect & (re)connect in two distinct steps for two reasons:
68-
// - To leave our socket instance in a consistent state in case the
69-
// connect fails.
70-
// - When reconnecting to a different address, Linux sometimes fails
71-
// if there isn't a disconnect in between.
72-
73-
// Step #1: Disconnect
74-
if socket.is_connected() {
75-
socket.disconnect()?;
76-
}
77-
78-
// Step #2: (Re)connect
7967
if let Some(connect_addr) = remote_address {
8068
let Some(check) = socket.socket_addr_check() else {
8169
return Err(ErrorCode::InvalidState.into());
8270
};
8371
check.check(connect_addr, SocketAddrUse::UdpConnect).await?;
8472
socket.connect_p2(connect_addr)?;
73+
} else if socket.is_connected() {
74+
socket.disconnect()?;
8575
}
8676

8777
let incoming_stream = IncomingDatagramStream {

crates/wasi/src/sockets/udp.rs

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,13 @@ use crate::runtime::with_ambient_tokio_runtime;
22
use crate::sockets::util::{
33
ErrorCode, get_unicast_hop_limit, is_valid_address_family, is_valid_remote_address,
44
receive_buffer_size, send_buffer_size, set_receive_buffer_size, set_send_buffer_size,
5-
set_unicast_hop_limit, udp_bind, udp_disconnect, udp_socket,
5+
set_unicast_hop_limit, udp_bind, udp_connect, udp_disconnect, udp_socket,
66
};
77
use crate::sockets::{SocketAddrCheck, SocketAddressFamily, WasiSocketsCtx};
88
use cap_net_ext::AddressFamily;
99
use io_lifetimes::AsSocketlike as _;
1010
use io_lifetimes::raw::{FromRawSocketlike as _, IntoRawSocketlike as _};
1111
use rustix::io::Errno;
12-
use rustix::net::connect;
1312
use std::net::SocketAddr;
1413
use std::sync::Arc;
1514
use tracing::debug;
@@ -155,28 +154,26 @@ impl UdpSocket {
155154
return Err(ErrorCode::InvalidArgument);
156155
}
157156

158-
// We disconnect & (re)connect in two distinct steps for two reasons:
159-
// - To leave our socket instance in a consistent state in case the
160-
// connect fails.
161-
// - When reconnecting to a different address, Linux sometimes fails
162-
// if there isn't a disconnect in between.
157+
match udp_connect(&self.socket, addr) {
158+
Ok(()) => {
159+
self.udp_state = UdpState::Connected(addr);
160+
Ok(())
161+
}
162+
Err(e) => {
163+
// Revert to a consistent state:
164+
_ = udp_disconnect(&self.socket);
165+
self.udp_state = UdpState::Bound;
163166

164-
// Step #1: Disconnect
165-
if let UdpState::Connected(..) = self.udp_state {
166-
udp_disconnect(&self.socket)?;
167-
self.udp_state = UdpState::Bound;
168-
}
169-
// Step #2: (Re)connect
170-
connect(&self.socket, &addr).map_err(|error| match error {
171-
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.
172-
Errno::INPROGRESS => {
173-
debug!("UDP connect returned EINPROGRESS, which should never happen");
174-
ErrorCode::Unknown
167+
Err(match e {
168+
Errno::AFNOSUPPORT => ErrorCode::InvalidArgument, // See `udp_bind` implementation.
169+
Errno::INPROGRESS => {
170+
debug!("UDP connect returned EINPROGRESS, which should never happen");
171+
ErrorCode::Unknown
172+
}
173+
err => err.into(),
174+
})
175175
}
176-
err => err.into(),
177-
})?;
178-
self.udp_state = UdpState::Connected(addr);
179-
Ok(())
176+
}
180177
}
181178

182179
/// Send data using p3 semantics. (with implicit bind)

crates/wasi/src/sockets/util.rs

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use core::time::Duration;
66
use cap_net_ext::{AddressFamily, Blocking, UdpSocketExt};
77
use rustix::fd::AsFd;
88
use rustix::io::Errno;
9-
use rustix::net::{bind, connect_unspec, sockopt};
9+
use rustix::net::{bind, connect, connect_unspec, sockopt};
1010
use tracing::debug;
1111

1212
use crate::sockets::SocketAddressFamily;
@@ -396,7 +396,27 @@ pub fn udp_bind(sockfd: impl AsFd, addr: SocketAddr) -> Result<(), ErrorCode> {
396396
})
397397
}
398398

399-
pub fn udp_disconnect(sockfd: impl AsFd) -> Result<(), ErrorCode> {
399+
pub fn udp_connect(sockfd: impl AsFd, addr: SocketAddr) -> Result<(), Errno> {
400+
match connect(sockfd.as_fd(), &addr) {
401+
// When connecting a UDP socket, the OS looks up the best route to the
402+
// remote address and selects an appropriate outgoing interface.
403+
// If the new destination routes through an interface different than the
404+
// previously selected interface, most operating systems will
405+
// automatically update the socket's local address to match that route.
406+
//
407+
// Linux however doesn't do that automatically and we manually
408+
// dissolve the existing association and then connect again to the
409+
// new destination.
410+
#[cfg(target_os = "linux")]
411+
Err(Errno::INVAL) => {
412+
_ = udp_disconnect(sockfd.as_fd());
413+
return connect(sockfd.as_fd(), &addr);
414+
}
415+
r => r,
416+
}
417+
}
418+
419+
pub fn udp_disconnect(sockfd: impl AsFd) -> Result<(), Errno> {
400420
match connect_unspec(sockfd) {
401421
// BSD platforms return an error even if the UDP socket was disconnected successfully.
402422
//
@@ -411,8 +431,7 @@ pub fn udp_disconnect(sockfd: impl AsFd) -> Result<(), ErrorCode> {
411431
// address family of the socket.
412432
#[cfg(target_os = "macos")]
413433
Err(Errno::INVAL | Errno::AFNOSUPPORT) => Ok(()),
414-
Err(err) => Err(err.into()),
415-
Ok(()) => Ok(()),
434+
r => r,
416435
}
417436
}
418437

0 commit comments

Comments
 (0)