Index: linux/drivers/net/wireguard/allowedips.c =================================================================== --- linux.orig/drivers/net/wireguard/allowedips.c +++ linux/drivers/net/wireguard/allowedips.c @@ -5,6 +5,8 @@ #include "allowedips.h" #include "peer.h" +#include "net/dst.h" +#include "net/route.h" enum { MAX_ALLOWEDIPS_DEPTH = 129 }; @@ -356,6 +358,15 @@ int wg_allowedips_read_node(struct allow struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb) { + const struct dst_entry *dst = skb_dst(skb); + if (dst) { + const struct rtable *rt = container_of(dst, struct rtable, dst); + if (rt->rt_gw_family == AF_INET) + return lookup(table->root4, 32, &rt->rt_gw4); + else if (rt->rt_gw_family == AF_INET6) + return lookup(table->root6, 128, &rt->rt_gw6); + } + if (skb->protocol == htons(ETH_P_IP)) return lookup(table->root4, 32, &ip_hdr(skb)->daddr); else if (skb->protocol == htons(ETH_P_IPV6)) @@ -363,17 +374,6 @@ struct wg_peer *wg_allowedips_lookup_dst return NULL; } -/* Returns a strong reference to a peer */ -struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, - struct sk_buff *skb) -{ - if (skb->protocol == htons(ETH_P_IP)) - return lookup(table->root4, 32, &ip_hdr(skb)->saddr); - else if (skb->protocol == htons(ETH_P_IPV6)) - return lookup(table->root6, 128, &ipv6_hdr(skb)->saddr); - return NULL; -} - int __init wg_allowedips_slab_init(void) { node_cache = KMEM_CACHE(allowedips_node, 0); Index: linux/drivers/net/wireguard/receive.c =================================================================== --- linux.orig/drivers/net/wireguard/receive.c +++ linux/drivers/net/wireguard/receive.c @@ -338,7 +338,6 @@ static void wg_packet_consume_data_done( { struct net_device *dev = peer->device->dev; unsigned int len, len_before_trim; - struct wg_peer *routed_peer; wg_socket_set_peer_endpoint(peer, endpoint); @@ -401,24 +400,10 @@ static void wg_packet_consume_data_done( if (unlikely(pskb_trim(skb, len))) goto packet_processed; - routed_peer = wg_allowedips_lookup_src(&peer->device->peer_allowedips, - skb); - wg_peer_put(routed_peer); /* We don't need the extra reference. */ - - if (unlikely(routed_peer != peer)) - goto dishonest_packet_peer; - napi_gro_receive(&peer->napi, skb); update_rx_stats(peer, message_data_len(len_before_trim)); return; -dishonest_packet_peer: - net_dbg_skb_ratelimited("%s: Packet has unallowed src IP (%pISc) from peer %llu (%pISpfsc)\n", - dev->name, skb, peer->internal_id, - &peer->endpoint.addr); - ++dev->stats.rx_errors; - ++dev->stats.rx_frame_errors; - goto packet_processed; dishonest_packet_type: net_dbg_ratelimited("%s: Packet is neither ipv4 nor ipv6 from peer %llu (%pISpfsc)\n", dev->name, peer->internal_id, &peer->endpoint.addr); Index: linux/drivers/net/wireguard/allowedips.h =================================================================== --- linux.orig/drivers/net/wireguard/allowedips.h +++ linux/drivers/net/wireguard/allowedips.h @@ -43,11 +43,9 @@ void wg_allowedips_remove_by_peer(struct /* The ip input pointer should be __aligned(__alignof(u64))) */ int wg_allowedips_read_node(struct allowedips_node *node, u8 ip[16], u8 *cidr); -/* These return a strong reference to a peer: */ +/* Returns a strong reference to a peer: */ struct wg_peer *wg_allowedips_lookup_dst(struct allowedips *table, struct sk_buff *skb); -struct wg_peer *wg_allowedips_lookup_src(struct allowedips *table, - struct sk_buff *skb); #ifdef DEBUG bool wg_allowedips_selftest(void);