Development discussion of WireGuard
 help / color / mirror / Atom feed
* [PATCH] [wireguard-go] Pool for endpoint objects
@ 2022-04-24  9:20 Natan Elul
  0 siblings, 0 replies; only message in thread
From: Natan Elul @ 2022-04-24  9:20 UTC (permalink / raw)
  To: wireguard

Use sync.Pool for endpoints to avoid a memory allocation on each receive.
When an endpoint is returned from the Linux bind, Go allocates it on the heap.
By using sync.Pool, these allocations can be reused, which can be
dramatically more efficient.

This patch includes the changes for the Linux bind, and an optimization
for SetEndpointFromPacket, which now takes the write lock only when the
endpoint has actually changed.

Signed-off-by: Natan Elul <elul.natan@gmail.com>
---
 conn/bind_linux.go        | 58 +++++++++++++++++++++++++++++++--------
 conn/bind_std.go          | 15 ++++++++++
 conn/bind_windows.go      | 14 ++++++++++
 conn/bindtest/bindtest.go | 11 ++++++++
 conn/conn.go              |  5 ++++
 device/peer.go            | 11 +++++++-
 device/pools.go           |  1 +
 device/receive.go         |  1 +
 8 files changed, 103 insertions(+), 13 deletions(-)

diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index f11f031..5630481 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -38,6 +38,24 @@ func (endpoint *LinuxSocketEndpoint) Src4()
*ipv4Source         { return endpoin
 func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 {
return endpoint.dst4() }
 func (endpoint *LinuxSocketEndpoint) IsV6() bool                {
return endpoint.isV6 }

+func (endpoint *LinuxSocketEndpoint) IsEqual(ep Endpoint) bool {
+ // Protect from mutable sendmsg
+ endpoint.mu.Lock()
+ defer endpoint.mu.Unlock()
+
+ linuxEp := ep.(*LinuxSocketEndpoint)
+ return endpoint.dst == linuxEp.dst && endpoint.src == linuxEp.src
+}
+
+func (endpoint *LinuxSocketEndpoint) Copy() Endpoint {
+ return &LinuxSocketEndpoint{
+ mu:   sync.Mutex{},
+ dst:  endpoint.dst,
+ src:  endpoint.src,
+ isV6: endpoint.isV6,
+ }
+}
+
 func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
  return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
 }
@@ -58,13 +76,28 @@ func (endpoint *LinuxSocketEndpoint) dst6()
*unix.SockaddrInet6 {
 type LinuxSocketBind struct {
  // mu guards sock4 and sock6 and the associated fds.
  // As long as someone holds mu (read or write), the associated fds are valid.
- mu    sync.RWMutex
- sock4 int
- sock6 int
+ mu             sync.RWMutex
+ sock4          int
+ sock6          int
+ epElementsPool sync.Pool
+}
+
+func (bind *LinuxSocketBind) GetEndpoint() *LinuxSocketEndpoint {
+ return bind.epElementsPool.Get().(*LinuxSocketEndpoint)
 }

-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1,
sock6: -1} }
-func NewDefaultBind() Bind     { return NewLinuxSocketBind() }
+func (bind *LinuxSocketBind) PutEndpoint(endpoint Endpoint) {
+ bind.epElementsPool.Put(endpoint)
+}
+
+func NewLinuxSocketBind() Bind {
+ return &LinuxSocketBind{sock4: -1, sock6: -1,
+ epElementsPool: sync.Pool{New: func() interface{} {
+ return new(LinuxSocketEndpoint)
+ }}}
+}
+
+func NewDefaultBind() Bind { return NewLinuxSocketBind() }

 var (
  _ Endpoint = (*LinuxSocketEndpoint)(nil)
@@ -224,14 +257,14 @@ func (bind *LinuxSocketBind) Close() error {
 }

 func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
+ end := bind.GetEndpoint()
  bind.mu.RLock()
  defer bind.mu.RUnlock()
  if bind.sock4 == -1 {
  return 0, nil, net.ErrClosed
  }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buf, &end)
- return n, &end, err
+ n, err := receive4(bind.sock4, buf, end)
+ return n, end, err
 }

 func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
@@ -448,11 +481,12 @@ func send4(sock int, end *LinuxSocketEndpoint,
buff []byte) error {
  // clear src and retry

  if err == unix.EINVAL {
- end.ClearSrc()
+ // clear source writing to source ip that can collide with isEqual
read. this is a rare execution code, so we will just
+ // create a copy and use it instead. (avoid write)
+ newEndpoint := end.Copy().(*LinuxSocketEndpoint)
+ newEndpoint.ClearSrc()
  cmsg.pktinfo = unix.Inet4Pktinfo{}
- end.mu.Lock()
- _, err = unix.SendmsgN(sock, buff,
(*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
- end.mu.Unlock()
+ _, err = unix.SendmsgN(sock, buff,
(*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:],
newEndpoint.dst4(), 0)
  }

  return err
diff --git a/conn/bind_std.go b/conn/bind_std.go
index e0f6cdd..4306f7f 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -27,8 +27,23 @@ type StdNetBind struct {

 func NewStdNetBind() Bind { return &StdNetBind{} }

+func (bind *StdNetBind) PutEndpoint(endpoint Endpoint) {
+}
+
 type StdNetEndpoint netip.AddrPort

+func (e *StdNetEndpoint) IsEqual(endpoint Endpoint) bool {
+ addrPort := (*netip.AddrPort)(e)
+ addrPortParam := (*netip.AddrPort)(endpoint.(*StdNetEndpoint))
+ return addrPort.Port() == addrPortParam.Port() &&
addrPort.Addr().Compare(addrPortParam.Addr()) == 0
+}
+
+func (e *StdNetEndpoint) Copy() Endpoint {
+ addrPortString := (*netip.AddrPort)(e).String()
+ copyEndpoint, _ := netip.ParseAddrPort(addrPortString)
+ return (*StdNetEndpoint)(&copyEndpoint)
+}
+
 var (
  _ Bind     = (*StdNetBind)(nil)
  _ Endpoint = (*StdNetEndpoint)(nil)
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 9268bc1..d0a1d66 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -77,6 +77,9 @@ type WinRingBind struct {
  isOpen uint32
 }

+func (bind *WinRingBind) PutEndpoint(endpoint Endpoint) {
+}
+
 func NewDefaultBind() Bind { return NewWinRingBind() }

 func NewWinRingBind() Bind {
@@ -131,6 +134,17 @@ func (*WinRingBind) ParseEndpoint(s string)
(Endpoint, error) {

 func (*WinRingEndpoint) ClearSrc() {}

+func (e *WinRingEndpoint) IsEqual(endpoint Endpoint) bool {
+ winEndpoint := endpoint.(*WinRingEndpoint)
+ return winEndpoint.family == e.family && winEndpoint.data == e.data
+}
+
+func (e *WinRingEndpoint) Copy() Endpoint {
+ return &WinRingEndpoint{
+ family: e.family,
+ data:   e.data,
+ }
+}
 func (e *WinRingEndpoint) DstIP() netip.Addr {
  switch e.family {
  case windows.AF_INET:
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index b38cae6..a04ccc9 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -23,8 +23,19 @@ type ChannelBind struct {
  target4, target6 ChannelEndpoint
 }

+func (c *ChannelBind) PutEndpoint(endpoint conn.Endpoint) {
+}
+
 type ChannelEndpoint uint16

+func (c ChannelEndpoint) IsEqual(endpoint conn.Endpoint) bool {
+ return c == endpoint.(ChannelEndpoint)
+}
+
+func (c ChannelEndpoint) Copy() conn.Endpoint {
+ return c
+}
+
 var (
  _ conn.Bind     = (*ChannelBind)(nil)
  _ conn.Endpoint = (*ChannelEndpoint)(nil)
diff --git a/conn/conn.go b/conn/conn.go
index 5a93b2b..c772b25 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -43,6 +43,9 @@ type Bind interface {

  // ParseEndpoint creates a new endpoint from a string.
  ParseEndpoint(s string) (Endpoint, error)
+
+ // PutEndpoint returns endpoint back to pool
+ PutEndpoint(endpoint Endpoint)
 }

 // BindSocketToInterface is implemented by Bind objects that support being
@@ -70,6 +73,8 @@ type Endpoint interface {
  DstToBytes() []byte  // used for mac2 cookie calculations
  DstIP() netip.Addr
  SrcIP() netip.Addr
+ IsEqual(endpoint Endpoint) bool
+ Copy() Endpoint
 }

 var (
diff --git a/device/peer.go b/device/peer.go
index 5bd52df..eb1cc41 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -271,7 +271,16 @@ func (peer *Peer) SetEndpointFromPacket(endpoint
conn.Endpoint) {
  if peer.disableRoaming {
  return
  }
+
+ peer.RLock()
+ if peer.endpoint.IsEqual(endpoint) {
+ peer.RUnlock()
+ return
+ }
+
+ peer.RUnlock()
+
  peer.Lock()
- peer.endpoint = endpoint
+ peer.endpoint = endpoint.Copy()
  peer.Unlock()
 }
diff --git a/device/pools.go b/device/pools.go
index f40477b..f861c51 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -70,6 +70,7 @@ func (device *Device) GetInboundElement()
*QueueInboundElement {
 }

 func (device *Device) PutInboundElement(elem *QueueInboundElement) {
+ device.net.bind.PutEndpoint(elem.endpoint)
  elem.clearPointers()
  device.pool.inboundElements.Put(elem)
 }
diff --git a/device/receive.go b/device/receive.go
index cc34498..e6cebcd 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -390,6 +390,7 @@ func (device *Device) RoutineHandshake(id int) {
  peer.SendKeepalive()
  }
  skip:
+ device.net.bind.PutEndpoint(elem.endpoint)
  device.PutMessageBuffer(elem.buffer)
  }
 }
-- 
2.30.1 (Apple Git-130)

^ permalink raw reply	[flat|nested] only message in thread

only message in thread, other threads:[~2022-04-24 20:19 UTC | newest]

Thread overview: (only message) (download: mbox.gz / follow: Atom feed)
-- links below jump to the message on this page --
2022-04-24  9:20 [PATCH] [wireguard-go] Pool for endpoint objects Natan Elul

This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox;
as well as URLs for NNTP newsgroup(s).