Development discussion of WireGuard
 help / color / mirror / Atom feed
From: Natan Elul <elul.natan@gmail.com>
To: wireguard@lists.zx2c4.com
Subject: [PATCH] [wireguard-go] Pool for endpoint objects
Date: Sun, 24 Apr 2022 12:20:08 +0300	[thread overview]
Message-ID: <CALDoV20ghXUW8DGL=yt6a8_cjGstCbEhXO9SbP6QcVLa63t7eA@mail.gmail.com> (raw)

Use sync.Pool for endpoints to avoid a memory allocation on each receive.
When the Linux bind returns an endpoint, Go allocates it on the heap.
By using sync.Pool, these allocations can be reused, which can be
dramatically more efficient.

This patch includes the changes for the Linux bind, and an optimization
for SetEndpointFromPacket that takes the write lock only when the endpoint
has changed.

Signed-off-by: Natan Elul <elul.natan@gmail.com>
---
 conn/bind_linux.go        | 58 +++++++++++++++++++++++++++++++--------
 conn/bind_std.go          | 15 ++++++++++
 conn/bind_windows.go      | 14 ++++++++++
 conn/bindtest/bindtest.go | 11 ++++++++
 conn/conn.go              |  5 ++++
 device/peer.go            | 11 +++++++-
 device/pools.go           |  1 +
 device/receive.go         |  1 +
 8 files changed, 103 insertions(+), 13 deletions(-)

diff --git a/conn/bind_linux.go b/conn/bind_linux.go
index f11f031..5630481 100644
--- a/conn/bind_linux.go
+++ b/conn/bind_linux.go
@@ -38,6 +38,28 @@ func (endpoint *LinuxSocketEndpoint) Src4() *ipv4Source         { return endpoin
 func (endpoint *LinuxSocketEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
 func (endpoint *LinuxSocketEndpoint) IsV6() bool                { return endpoint.isV6 }

+// IsEqual reports whether ep is a *LinuxSocketEndpoint with the same dst
+// and src bytes. Any other endpoint type compares unequal rather than
+// panicking on the type assertion.
+func (endpoint *LinuxSocketEndpoint) IsEqual(ep Endpoint) bool {
+ linuxEp, ok := ep.(*LinuxSocketEndpoint)
+ if !ok || linuxEp == nil {
+ return false
+ }
+ // Protect from mutable sendmsg: send4 passes dst4() to SendmsgN under mu.
+ endpoint.mu.Lock()
+ defer endpoint.mu.Unlock()
+ return endpoint.dst == linuxEp.dst && endpoint.src == linuxEp.src
+}
+
+// Copy returns an independent endpoint with the same dst, src and isV6.
+// It holds mu for the same reason IsEqual does.
+func (endpoint *LinuxSocketEndpoint) Copy() Endpoint {
+ endpoint.mu.Lock()
+ defer endpoint.mu.Unlock()
+ return &LinuxSocketEndpoint{dst: endpoint.dst, src: endpoint.src, isV6: endpoint.isV6}
+}
+
 func (endpoint *LinuxSocketEndpoint) src4() *ipv4Source {
  return (*ipv4Source)(unsafe.Pointer(&endpoint.src[0]))
 }
@@ -58,13 +76,40 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
 type LinuxSocketBind struct {
  // mu guards sock4 and sock6 and the associated fds.
  // As long as someone holds mu (read or write), the associated fds are valid.
- mu    sync.RWMutex
- sock4 int
- sock6 int
+ mu             sync.RWMutex
+ sock4          int
+ sock6          int
+ epElementsPool sync.Pool
+}
+
+// GetEndpoint takes an endpoint from the pool, allocating a new one when the
+// pool is empty. Callers hand it back with PutEndpoint.
+func (bind *LinuxSocketBind) GetEndpoint() *LinuxSocketEndpoint {
+ return bind.epElementsPool.Get().(*LinuxSocketEndpoint)
 }

-func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
-func NewDefaultBind() Bind     { return NewLinuxSocketBind() }
+// PutEndpoint returns an endpoint handed out by a receive to the pool.
+// Anything other than a non-nil *LinuxSocketEndpoint is dropped, so the
+// type assertion in GetEndpoint cannot fail. The endpoint is zeroed first so
+// a recycled value never carries a previous packet's src/dst bytes (e.g. the
+// tail of a larger IPv6 sockaddr) into the next receive.
+func (bind *LinuxSocketBind) PutEndpoint(endpoint Endpoint) {
+ ep, ok := endpoint.(*LinuxSocketEndpoint)
+ if !ok || ep == nil {
+ return
+ }
+ *ep = LinuxSocketEndpoint{}
+ bind.epElementsPool.Put(ep)
+}
+
+func NewLinuxSocketBind() Bind {
+ return &LinuxSocketBind{sock4: -1, sock6: -1,
+ epElementsPool: sync.Pool{New: func() interface{} {
+ return new(LinuxSocketEndpoint)
+ }}}
+}
+
+func NewDefaultBind() Bind { return NewLinuxSocketBind() }

 var (
  _ Endpoint = (*LinuxSocketEndpoint)(nil)
@@ -224,14 +257,14 @@ func (bind *LinuxSocketBind) Close() error {
 }

 func (bind *LinuxSocketBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
  bind.mu.RLock()
  defer bind.mu.RUnlock()
  if bind.sock4 == -1 {
  return 0, nil, net.ErrClosed
  }
- var end LinuxSocketEndpoint
- n, err := receive4(bind.sock4, buf, &end)
- return n, &end, err
+ end := bind.GetEndpoint() // taken only once the socket is known to be open
+ n, err := receive4(bind.sock4, buf, end)
+ return n, end, err
 }

 func (bind *LinuxSocketBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
@@ -448,11 +481,13 @@ func send4(sock int, end *LinuxSocketEndpoint, buff []byte) error {
  // clear src and retry

  if err == unix.EINVAL {
- end.ClearSrc()
- cmsg.pktinfo = unix.Inet4Pktinfo{}
+ // Clear the stale source on end itself (not on a copy) so later sends do
+ // not reuse it; do it under mu, which IsEqual and Copy also take.
  end.mu.Lock()
+ end.ClearSrc()
+ cmsg.pktinfo = unix.Inet4Pktinfo{}
  _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
  end.mu.Unlock()
  }

  return err
diff --git a/conn/bind_std.go b/conn/bind_std.go
index e0f6cdd..4306f7f 100644
--- a/conn/bind_std.go
+++ b/conn/bind_std.go
@@ -27,8 +27,25 @@ type StdNetBind struct {

 func NewStdNetBind() Bind { return &StdNetBind{} }

+// PutEndpoint is a no-op: StdNetBind does not pool its endpoints.
+func (bind *StdNetBind) PutEndpoint(endpoint Endpoint) {}
+
 type StdNetEndpoint netip.AddrPort

+// IsEqual reports whether endpoint is a *StdNetEndpoint with the same
+// address (including zone) and port. Other endpoint types compare unequal.
+func (e *StdNetEndpoint) IsEqual(endpoint Endpoint) bool {
+ other, ok := endpoint.(*StdNetEndpoint)
+ return ok && other != nil && *e == *other
+}
+
+// Copy returns an independent copy. netip.AddrPort is a plain comparable
+// value, so dereferencing copies it without a String/Parse round trip.
+func (e *StdNetEndpoint) Copy() Endpoint {
+ c := *e
+ return &c
+}
+
 var (
  _ Bind     = (*StdNetBind)(nil)
  _ Endpoint = (*StdNetEndpoint)(nil)
diff --git a/conn/bind_windows.go b/conn/bind_windows.go
index 9268bc1..d0a1d66 100644
--- a/conn/bind_windows.go
+++ b/conn/bind_windows.go
@@ -77,6 +77,9 @@ type WinRingBind struct {
  isOpen uint32
 }

+// PutEndpoint is a no-op: WinRingBind does not pool its endpoints.
+func (bind *WinRingBind) PutEndpoint(endpoint Endpoint) {}
+
 func NewDefaultBind() Bind { return NewWinRingBind() }

 func NewWinRingBind() Bind {
@@ -131,6 +134,21 @@ func (*WinRingBind) ParseEndpoint(s string) (Endpoint, error) {

 func (*WinRingEndpoint) ClearSrc() {}

+// IsEqual reports whether endpoint is a *WinRingEndpoint with the same
+// address family and raw address data. Other endpoint types compare unequal.
+func (e *WinRingEndpoint) IsEqual(endpoint Endpoint) bool {
+ other, ok := endpoint.(*WinRingEndpoint)
+ return ok && other != nil && other.family == e.family && other.data == e.data
+}
+
+// Copy returns an independent endpoint with the same family and data.
+func (e *WinRingEndpoint) Copy() Endpoint {
+ return &WinRingEndpoint{
+ family: e.family,
+ data:   e.data,
+ }
+}
+
 func (e *WinRingEndpoint) DstIP() netip.Addr {
  switch e.family {
  case windows.AF_INET:
diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go
index b38cae6..a04ccc9 100644
--- a/conn/bindtest/bindtest.go
+++ b/conn/bindtest/bindtest.go
@@ -23,8 +23,22 @@ type ChannelBind struct {
  target4, target6 ChannelEndpoint
 }

+// PutEndpoint is a no-op: ChannelBind does not pool its endpoints.
+func (c *ChannelBind) PutEndpoint(endpoint conn.Endpoint) {}
+
 type ChannelEndpoint uint16

+// IsEqual reports whether endpoint is a ChannelEndpoint with the same port.
+func (c ChannelEndpoint) IsEqual(endpoint conn.Endpoint) bool {
+ other, ok := endpoint.(ChannelEndpoint)
+ return ok && c == other
+}
+
+// Copy returns c itself; a ChannelEndpoint is an immutable value.
+func (c ChannelEndpoint) Copy() conn.Endpoint {
+ return c
+}
+
 var (
  _ conn.Bind     = (*ChannelBind)(nil)
  _ conn.Endpoint = (*ChannelEndpoint)(nil)
diff --git a/conn/conn.go b/conn/conn.go
index 5a93b2b..c772b25 100644
--- a/conn/conn.go
+++ b/conn/conn.go
@@ -43,6 +43,9 @@ type Bind interface {

  // ParseEndpoint creates a new endpoint from a string.
  ParseEndpoint(s string) (Endpoint, error)
+
+ // PutEndpoint hands an endpoint returned by a receive back to the Bind's pool (if it has one); callers must not use it afterwards.
+ PutEndpoint(endpoint Endpoint)
 }

 // BindSocketToInterface is implemented by Bind objects that support being
@@ -70,6 +73,8 @@ type Endpoint interface {
  DstToBytes() []byte  // used for mac2 cookie calculations
  DstIP() netip.Addr
  SrcIP() netip.Addr
+ IsEqual(endpoint Endpoint) bool // reports whether endpoint has the same addresses
+ Copy() Endpoint                 // returns an independent copy that outlives PutEndpoint
 }

 var (
diff --git a/device/peer.go b/device/peer.go
index 5bd52df..eb1cc41 100644
--- a/device/peer.go
+++ b/device/peer.go
@@ -271,7 +271,20 @@ func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
  if peer.disableRoaming {
  return
  }
+
+ // Fast path: when the peer is already at this endpoint, skip the write
+ // lock and the copy. peer.endpoint can be nil (no endpoint set yet), and
+ // calling IsEqual on a nil interface would panic, so guard it.
+ peer.RLock()
+ unchanged := peer.endpoint != nil && peer.endpoint.IsEqual(endpoint)
+ peer.RUnlock()
+ if unchanged {
+ return
+ }
+
+ // The bind may recycle endpoint once the packet is done, so keep a copy.
+ newEndpoint := endpoint.Copy()
  peer.Lock()
- peer.endpoint = endpoint
+ peer.endpoint = newEndpoint
  peer.Unlock()
 }
diff --git a/device/pools.go b/device/pools.go
index f40477b..f861c51 100644
--- a/device/pools.go
+++ b/device/pools.go
@@ -70,6 +70,7 @@ func (device *Device) GetInboundElement()
*QueueInboundElement {
 }

 func (device *Device) PutInboundElement(elem *QueueInboundElement) {
+ device.net.bind.PutEndpoint(elem.endpoint) // return the receive endpoint to the bind's pool; SetEndpointFromPacket stores a Copy, not this value
  elem.clearPointers()
  device.pool.inboundElements.Put(elem)
 }
diff --git a/device/receive.go b/device/receive.go
index cc34498..e6cebcd 100644
--- a/device/receive.go
+++ b/device/receive.go
@@ -390,6 +390,7 @@ func (device *Device) RoutineHandshake(id int) {
  peer.SendKeepalive()
  }
  skip:
+ device.net.bind.PutEndpoint(elem.endpoint) // return the receive endpoint to the bind's pool; SetEndpointFromPacket stores a Copy, not this value
  device.PutMessageBuffer(elem.buffer)
  }
 }
-- 
2.30.1 (Apple Git-130)

                 reply	other threads:[~2022-04-24 20:19 UTC|newest]

Thread overview: [no followups] expand[flat|nested]  mbox.gz  Atom feed

Reply instructions:

You may reply publicly to this message via plain-text email
using any one of the following methods:

* Save the following mbox file, import it into your mail client,
  and reply-to-all from there: mbox

  Avoid top-posting and favor interleaved quoting:
  https://en.wikipedia.org/wiki/Posting_style#Interleaved_style

* Reply using the --to, --cc, and --in-reply-to
  switches of git-send-email(1):

  git send-email \
    --in-reply-to='CALDoV20ghXUW8DGL=yt6a8_cjGstCbEhXO9SbP6QcVLa63t7eA@mail.gmail.com' \
    --to=elul.natan@gmail.com \
    --cc=wireguard@lists.zx2c4.com \
    /path/to/YOUR_REPLY

  https://kernel.org/pub/software/scm/git/docs/git-send-email.html

* If your mail client supports setting the In-Reply-To header
  via mailto: links, try the mailto: link
Be sure your reply has a Subject: header at the top and a blank line before the message body.
This is a public inbox, see mirroring instructions
for how to clone and mirror all data and code used for this inbox;
as well as URLs for NNTP newsgroup(s).