diff --git a/include/net/sock.h b/include/net/sock.h index 66fd3951e6f3..73b7830b0bb8 100644 --- a/include/net/sock.h +++ b/include/net/sock.h @@ -72,6 +72,7 @@ #include #include #include +#include /* * This structure really needs to be cleaned up. @@ -2399,4 +2400,23 @@ static inline void sk_pacing_shift_update(struct sock *sk, int val) sk->sk_pacing_shift = val; } +/* if a socket is bound to a device, check that the given device + * index is either the same or that the socket is bound to an L3 + * master device and the given device index is also enslaved to + * that L3 master + */ +static inline bool sk_dev_equal_l3scope(struct sock *sk, int dif) +{ + int mdif; + + if (!sk->sk_bound_dev_if || sk->sk_bound_dev_if == dif) + return true; + + mdif = l3mdev_master_ifindex_by_index(sock_net(sk), dif); + if (mdif && mdif == sk->sk_bound_dev_if) + return true; + + return false; +} + #endif /* _SOCK_H */ diff --git a/net/ipv6/datagram.c b/net/ipv6/datagram.c index a1f918713006..fbf08ce3f5ab 100644 --- a/net/ipv6/datagram.c +++ b/net/ipv6/datagram.c @@ -221,8 +221,7 @@ int __ip6_datagram_connect(struct sock *sk, struct sockaddr *uaddr, if (__ipv6_addr_needs_scope_id(addr_type)) { if (addr_len >= sizeof(struct sockaddr_in6) && usin->sin6_scope_id) { - if (sk->sk_bound_dev_if && - sk->sk_bound_dev_if != usin->sin6_scope_id) { + if (!sk_dev_equal_l3scope(sk, usin->sin6_scope_id)) { err = -EINVAL; goto out; } diff --git a/net/ipv6/tcp_ipv6.c b/net/ipv6/tcp_ipv6.c index aa12a26a96c6..c0f7e69f2e6c 100644 --- a/net/ipv6/tcp_ipv6.c +++ b/net/ipv6/tcp_ipv6.c @@ -176,8 +176,7 @@ static int tcp_v6_connect(struct sock *sk, struct sockaddr *uaddr, /* If interface is set while binding, indices * must coincide. */ - if (sk->sk_bound_dev_if && - sk->sk_bound_dev_if != usin->sin6_scope_id) + if (!sk_dev_equal_l3scope(sk, usin->sin6_scope_id)) return -EINVAL; sk->sk_bound_dev_if = usin->sin6_scope_id;