Skip to content

Commit 1250dd6

Browse files
authored
[ADDED] Automatic reconnect on write error option (#2055)
Signed-off-by: Piotr Piotrowski <piotr@synadia.com>
1 parent e4a8c79 commit 1250dd6

2 files changed

Lines changed: 237 additions & 63 deletions

File tree

nats.go

Lines changed: 59 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,34 @@ type Options struct {
410410
// Defaults to 1m.
411411
FlusherTimeout time.Duration
412412

413+
// ReconnectOnFlusherError, when set to true, causes the client to
414+
// trigger a reconnect if the background flusher fails to write to the
415+
// underlying connection for any reason (timeout, broken pipe,
416+
// connection reset, EOF etc.).
417+
//
418+
// This is an advanced option. Most applications do not need to enable
419+
// it: the server-side stale connection detection (via PingInterval /
420+
// MaxPingsOut) and the read loop's own error handling will eventually
421+
// notice a dead connection and the client will reconnect. Enable this
422+
// only if you need faster recovery from a stalled or broken TCP write
423+
// — for example, in latency-sensitive setups where waiting for a ping
424+
// timeout is unacceptable.
425+
//
426+
// Messages buffered at the time of the error are lost, as they are
427+
// with any flusher write error. The purpose of this option is to
428+
// limit the blast radius by preventing further messages from being
429+
// buffered into a potentially corrupted connection, not to recover
430+
// the in-flight data.
431+
//
432+
// When triggered, the standard DisconnectErrHandler and
433+
// ReconnectHandler callbacks are invoked as with any other reconnect.
434+
// The first reconnect attempt bypasses the configured ReconnectWait
435+
// so that recovery is as fast as possible; if that attempt fails,
436+
// subsequent attempts obey the normal backoff.
437+
//
438+
// Defaults to false.
439+
ReconnectOnFlusherError bool
440+
413441
// PingInterval is the period at which the client will be sending ping
414442
// commands to the server, disabled if 0 or negative.
415443
// Defaults to 2m.
@@ -1211,6 +1239,17 @@ func FlusherTimeout(t time.Duration) Option {
12111239
}
12121240
}
12131241

1242+
// ReconnectOnFlusherError is an Option to automatically trigger a
1243+
// reconnect when the background flusher hits any write error. See
1244+
// [Options.ReconnectOnFlusherError] for details. This is an
1245+
// advanced option and is usually not required.
1246+
func ReconnectOnFlusherError() Option {
1247+
return func(o *Options) error {
1248+
o.ReconnectOnFlusherError = true
1249+
return nil
1250+
}
1251+
}
1252+
12141253
// DrainTimeout is an Option to set the timeout for draining a connection.
12151254
// Defaults to 30s.
12161255
func DrainTimeout(t time.Duration) Option {
@@ -2440,8 +2479,10 @@ func (nc *Conn) ForceReconnect() error {
24402479
// Stop ping timer if set.
24412480
nc.stopPingTimer()
24422481

2443-
// Go ahead and make sure we have flushed the outbound
2482+
// flush any pending data and switch to pending mode to buffer new outgoing
2483+
// data until we reconnect and can flush it.
24442484
nc.bw.flush()
2485+
nc.bw.switchToPending()
24452486
nc.conn.Close()
24462487

24472488
nc.changeConnStatus(RECONNECTING)
@@ -3342,8 +3383,10 @@ func (nc *Conn) doReconnect(err error, forceReconnect bool) {
33423383
}
33433384

33443385
// processOpErr handles errors from reading, writing, or parsing the protocol.
3345-
// The lock should not be held entering this function.
3346-
func (nc *Conn) processOpErr(err error) bool {
3386+
// The lock should not be held entering this function. If forceReconnect
3387+
// is true, the first reconnect attempt will bypass the configured
3388+
// ReconnectWait; subsequent attempts still obey the normal backoff.
3389+
func (nc *Conn) processOpErr(err error, forceReconnect bool) bool {
33473390
nc.mu.Lock()
33483391
defer nc.mu.Unlock()
33493392
if nc.isConnecting() || nc.isClosed() || nc.isReconnecting() {
@@ -3366,7 +3409,7 @@ func (nc *Conn) processOpErr(err error) bool {
33663409
// Clear any queued pongs, e.g. pending flush calls.
33673410
nc.clearPendingFlushCalls()
33683411

3369-
go nc.doReconnect(err, false)
3412+
go nc.doReconnect(err, forceReconnect)
33703413
return false
33713414
}
33723415

@@ -3469,7 +3512,7 @@ func (nc *Conn) readLoop() {
34693512
err = nc.parse(buf)
34703513
}
34713514
if err != nil {
3472-
if shouldClose := nc.processOpErr(err); shouldClose {
3515+
if shouldClose := nc.processOpErr(err, false); shouldClose {
34733516
nc.close(CLOSED, true, nil)
34743517
}
34753518
break
@@ -3917,6 +3960,13 @@ func (nc *Conn) flusher() {
39173960
if asyncErrorCB := nc.Opts.AsyncErrorCB; asyncErrorCB != nil {
39183961
nc.ach.push(func() { asyncErrorCB(nc, nil, err) })
39193962
}
3963+
if nc.Opts.ReconnectOnFlusherError {
3964+
nc.mu.Unlock()
3965+
if shouldClose := nc.processOpErr(err, true); shouldClose {
3966+
nc.close(CLOSED, true, nil)
3967+
}
3968+
return
3969+
}
39203970
}
39213971
}
39223972
nc.mu.Unlock()
@@ -4096,11 +4146,11 @@ func (nc *Conn) processErr(ie string) {
40964146

40974147
// FIXME(dlc) - process Slow Consumer signals special.
40984148
if e == STALE_CONNECTION {
4099-
close = nc.processOpErr(ErrStaleConnection)
4149+
close = nc.processOpErr(ErrStaleConnection, false)
41004150
} else if e == MAX_CONNECTIONS_ERR {
4101-
close = nc.processOpErr(ErrMaxConnectionsExceeded)
4151+
close = nc.processOpErr(ErrMaxConnectionsExceeded, false)
41024152
} else if e == MAX_ACCOUNT_CONNECTIONS_ERR {
4103-
close = nc.processOpErr(ErrMaxAccountConnectionsExceeded)
4153+
close = nc.processOpErr(ErrMaxAccountConnectionsExceeded, false)
41044154
} else if strings.HasPrefix(e, PERMISSIONS_ERR) {
41054155
nc.processTransientError(fmt.Errorf("%w: %s", ErrPermissionViolation, ne))
41064156
} else if strings.HasPrefix(e, MAX_SUBSCRIPTIONS_ERR) {
@@ -5682,7 +5732,7 @@ func (nc *Conn) processPingTimer() {
56825732
nc.pout++
56835733
if nc.pout > nc.Opts.MaxPingsOut {
56845734
nc.mu.Unlock()
5685-
if shouldClose := nc.processOpErr(ErrStaleConnection); shouldClose {
5735+
if shouldClose := nc.processOpErr(ErrStaleConnection, false); shouldClose {
56865736
nc.close(CLOSED, true, nil)
56875737
}
56885738
return

test/conn_test.go

Lines changed: 178 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1931,7 +1931,7 @@ func TestCustomFlusherTimeout(t *testing.T) {
19311931

19321932
errCh := make(chan error, 1)
19331933
wg := sync.WaitGroup{}
1934-
wg.Add(2)
1934+
wg.Add(1)
19351935
go func() {
19361936
defer wg.Done()
19371937
for {
@@ -1949,50 +1949,7 @@ func TestCustomFlusherTimeout(t *testing.T) {
19491949
}()
19501950
defer nc1.Close()
19511951

1952-
l, e := net.Listen("tcp", "127.0.0.1:0")
1953-
if e != nil {
1954-
t.Fatal("Could not listen on an ephemeral port")
1955-
}
1956-
tl := l.(*net.TCPListener)
1957-
defer tl.Close()
1958-
1959-
addr := tl.Addr().(*net.TCPAddr)
1960-
1961-
fsDoneCh := make(chan struct{}, 1)
1962-
fsErrCh := make(chan error, 1)
1963-
go func() {
1964-
defer wg.Done()
1965-
1966-
serverInfo := "INFO {\"server_id\":\"foobar\",\"host\":\"%s\",\"port\":%d,\"auth_required\":false,\"tls_required\":false,\"max_payload\":%d}\r\n"
1967-
conn, err := l.Accept()
1968-
if err != nil {
1969-
fsErrCh <- err
1970-
return
1971-
}
1972-
defer conn.Close()
1973-
// Make it small on purpose
1974-
if err := conn.(*net.TCPConn).SetReadBuffer(1024); err != nil {
1975-
fsErrCh <- err
1976-
return
1977-
}
1978-
1979-
info := fmt.Sprintf(serverInfo, addr.IP, addr.Port, 1024*1024)
1980-
conn.Write([]byte(info))
1981-
1982-
// Read connect and ping commands sent from the client
1983-
line := make([]byte, 100)
1984-
_, err = conn.Read(line)
1985-
if err != nil {
1986-
fsErrCh <- fmt.Errorf("Expected CONNECT and PING from client, got: %v", err)
1987-
return
1988-
}
1989-
conn.Write([]byte("PONG\r\n"))
1990-
1991-
// Don't consume anything at this point and wait to be notified
1992-
// that we are done.
1993-
<-fsDoneCh
1994-
fsErrCh <- nil
1995-
}()
1952+
addr := startStalledMockServer(t)
19961953

19971954
nc2, err := nats.Connect(
19981955
// URL to fake server
@@ -2046,17 +2003,12 @@ forLoop:
20462003
}
20472004
}
20482005

2049-
// Notify fake server that it can stop
2050-
close(fsDoneCh)
2051-
2052-
// Wait for go routines to end
2006+
// Close nc2 to fire its ClosedHandler, which signals the publisher
2007+
// goroutine (via doneCh) to exit. The fake server is torn down by
2008+
// the t.Cleanup registered in startStalledMockServer.
2009+
nc2.Close()
20532010
wg.Wait()
20542011

2055-
// Make sure there were no error in the fake server
2056-
if err := <-fsErrCh; err != nil {
2057-
t.Fatalf("Fake server reported: %v", err)
2058-
}
2059-
20602012
// One of those two are guaranteed to be set.
20612013
err = nc2Err
20622014
if err == nil {
@@ -2082,6 +2034,178 @@ forLoop:
20822034
}
20832035
}
20842036

2037+
// startStalledMockServer starts a fake NATS server on an ephemeral
// loopback port and returns the address it is listening on.
//
// The server accepts a single connection, shrinks that connection's read
// buffer to 1024 bytes, sends an INFO line, performs one read (expected to
// carry the client's CONNECT and PING) and answers with PONG. After that
// it stops reading from the socket, so client writes eventually stall.
//
// Handshake failures, including a failed write of INFO or PONG, are
// reported through t.Errorf. A t.Cleanup is registered that releases the
// server goroutine, closes the listener and waits for the goroutine to
// exit, so callers have nothing to tear down themselves.
func startStalledMockServer(t *testing.T) *net.TCPAddr {
	t.Helper()
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("Could not listen on an ephemeral port: %v", err)
	}
	addr := l.Addr().(*net.TCPAddr)

	// done tells the server goroutine to stop holding the connection open;
	// exited is closed once that goroutine has returned.
	done := make(chan struct{})
	exited := make(chan struct{})

	go func() {
		defer close(exited)
		conn, err := l.Accept()
		if err != nil {
			// Accept fails once the cleanup closes the listener; there is
			// no connection to serve in that case.
			return
		}
		defer conn.Close()

		// A failed assertion would panic inside this goroutine and take
		// the whole test binary down, so report it as a test error instead.
		tcpConn, ok := conn.(*net.TCPConn)
		if !ok {
			t.Errorf("Expected a *net.TCPConn, got: %T", conn)
			return
		}
		// Keep the receive buffer small on purpose so that the client's
		// writes stall quickly once this side stops reading.
		if err := tcpConn.SetReadBuffer(1024); err != nil {
			t.Errorf("Expected SetReadBuffer to succeed, got: %v", err)
			return
		}

		info := fmt.Sprintf(
			"INFO {\"server_id\":\"foobar\",\"host\":\"%s\",\"port\":%d,\"auth_required\":false,\"tls_required\":false,\"max_payload\":%d}\r\n",
			addr.IP, addr.Port, 1024*1024,
		)
		if _, err := conn.Write([]byte(info)); err != nil {
			t.Errorf("Expected INFO to be sent, got: %v", err)
			return
		}

		// A single read is enough to let the client proceed; the content
		// is not inspected.
		line := make([]byte, 100)
		if _, err := conn.Read(line); err != nil {
			t.Errorf("Expected CONNECT+PING, got: %v", err)
			return
		}
		if _, err := conn.Write([]byte("PONG\r\n")); err != nil {
			t.Errorf("Expected PONG to be sent, got: %v", err)
			return
		}

		// Do not consume anything else: hold the connection open without
		// reading until the test is finished.
		<-done
	}()

	t.Cleanup(func() {
		close(done)
		l.Close()
		// Wait for the goroutine so it cannot call t.Errorf after the
		// test has completed.
		<-exited
	})
	return addr
}
2083+
2084+
func TestReconnectOnFlusherError(t *testing.T) {
2085+
if runtime.GOOS == "windows" {
2086+
t.SkipNow()
2087+
}
2088+
2089+
for _, tc := range []struct {
2090+
name string
2091+
withOption bool
2092+
noReconnect bool
2093+
wantReconnect bool
2094+
wantClosed bool
2095+
}{
2096+
{"enabled", true, false, true, false},
2097+
{"disabled", false, false, false, false},
2098+
{"no_reconnect", true, true, false, true},
2099+
} {
2100+
t.Run(tc.name, func(t *testing.T) {
2101+
s := RunDefaultServer()
2102+
defer s.Shutdown()
2103+
2104+
fakeAddr := startStalledMockServer(t)
2105+
2106+
reconnectedCh := make(chan struct{}, 1)
2107+
closedCh := make(chan struct{}, 1)
2108+
asyncErrCh := make(chan error, 1)
2109+
2110+
opts := []nats.Option{
2111+
nats.SetCustomDialer(&lowWriteBufferDialer{}),
2112+
nats.FlusherTimeout(15 * time.Millisecond),
2113+
nats.MaxReconnects(10),
2114+
nats.DontRandomize(),
2115+
nats.ReconnectHandler(func(_ *nats.Conn) {
2116+
select {
2117+
case reconnectedCh <- struct{}{}:
2118+
default:
2119+
}
2120+
}),
2121+
nats.ClosedHandler(func(_ *nats.Conn) {
2122+
select {
2123+
case closedCh <- struct{}{}:
2124+
default:
2125+
}
2126+
}),
2127+
nats.ErrorHandler(func(_ *nats.Conn, _ *nats.Subscription, err error) {
2128+
select {
2129+
case asyncErrCh <- err:
2130+
default:
2131+
}
2132+
}),
2133+
}
2134+
if tc.withOption {
2135+
opts = append(opts, nats.ReconnectOnFlusherError())
2136+
}
2137+
if tc.noReconnect {
2138+
opts = append(opts, nats.NoReconnect())
2139+
}
2140+
2141+
nc, err := nats.Connect(
2142+
fmt.Sprintf("nats://127.0.0.1:%d,%s", fakeAddr.Port, nats.DefaultURL),
2143+
opts...,
2144+
)
2145+
if err != nil {
2146+
t.Fatalf("Connect: %v", err)
2147+
}
2148+
defer nc.Close()
2149+
2150+
if url := nc.ConnectedUrl(); url != fmt.Sprintf("nats://127.0.0.1:%d", fakeAddr.Port) {
2151+
t.Fatalf("Expected initial connection to fake server, got %q", url)
2152+
}
2153+
2154+
stopPub := make(chan struct{})
2155+
pubDone := make(chan struct{})
2156+
go func() {
2157+
defer close(pubDone)
2158+
tick := time.NewTicker(50 * time.Millisecond)
2159+
defer tick.Stop()
2160+
// Slightly under the library's internal buffer size so the
2161+
// flusher goroutine (not the Publish call) triggers the write.
2162+
payload := make([]byte, 32*1024-200)
2163+
for {
2164+
select {
2165+
case <-stopPub:
2166+
return
2167+
case <-tick.C:
2168+
nc.Publish("hello", payload)
2169+
}
2170+
}
2171+
}()
2172+
defer func() {
2173+
close(stopPub)
2174+
<-pubDone
2175+
}()
2176+
2177+
// Confirm the flusher actually observes a write error.
2178+
select {
2179+
case <-asyncErrCh:
2180+
case <-time.After(3 * time.Second):
2181+
t.Fatal("flusher did not report an async write error within 3s")
2182+
}
2183+
2184+
if tc.wantReconnect {
2185+
select {
2186+
case <-reconnectedCh:
2187+
case <-time.After(5 * time.Second):
2188+
t.Fatal("expected reconnect after flusher error")
2189+
}
2190+
if url := nc.ConnectedUrl(); url != s.ClientURL() {
2191+
t.Fatalf("expected to be reconnected to real server %q, got %q", s.ClientURL(), url)
2192+
}
2193+
} else if tc.wantClosed {
2194+
WaitOnChannel(t, closedCh, struct{}{})
2195+
if !nc.IsClosed() {
2196+
t.Fatalf("expected IsClosed() to be true, got status %v", nc.Status())
2197+
}
2198+
} else {
2199+
select {
2200+
case <-reconnectedCh:
2201+
t.Fatal("unexpected reconnect when ReconnectOnFlusherError is disabled")
2202+
case <-time.After(500 * time.Millisecond):
2203+
}
2204+
}
2205+
})
2206+
}
2207+
}
2208+
20852209
func TestNewServers(t *testing.T) {
20862210
s1Opts := test.DefaultTestOptions
20872211
s1Opts.Host = "127.0.0.1"

0 commit comments

Comments
 (0)