Skip to content

Commit f17ba02

Browse files
committed
fix: make EventuallyWithT concurrency safe
1 parent a23f5db commit f17ba02

2 files changed

Lines changed: 67 additions & 25 deletions

File tree

assert/assertions.go

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1870,23 +1870,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
18701870
}
18711871

18721872
// FailNow panics.
1873-
func (c *CollectT) FailNow() {
1873+
func (*CollectT) FailNow() {
18741874
panic("Assertion failed")
18751875
}
18761876

1877-
// Reset clears the collected errors.
1878-
func (c *CollectT) Reset() {
1879-
c.errors = nil
1877+
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
1878+
func (*CollectT) Reset() {
1879+
panic("Reset() is deprecated")
18801880
}
18811881

1882-
// Copy copies the collected errors to the supplied t.
1883-
func (c *CollectT) Copy(t TestingT) {
1884-
if tt, ok := t.(tHelper); ok {
1885-
tt.Helper()
1886-
}
1887-
for _, err := range c.errors {
1888-
t.Errorf("%v", err)
1889-
}
1882+
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
1883+
func (*CollectT) Copy(TestingT) {
1884+
panic("Copy() is deprecated")
18901885
}
18911886

18921887
// EventuallyWithT asserts that given condition will be met in waitFor time,
@@ -1912,8 +1907,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
19121907
h.Helper()
19131908
}
19141909

1915-
collect := new(CollectT)
1916-
ch := make(chan bool, 1)
1910+
var lastFinishedTickErrs []error
1911+
ch := make(chan []error, 1)
19171912

19181913
timer := time.NewTimer(waitFor)
19191914
defer timer.Stop()
@@ -1924,19 +1919,23 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
19241919
for tick := ticker.C; ; {
19251920
select {
19261921
case <-timer.C:
1927-
collect.Copy(t)
1922+
for _, err := range lastFinishedTickErrs {
1923+
t.Errorf("%v", err)
1924+
}
19281925
return Fail(t, "Condition never satisfied", msgAndArgs...)
19291926
case <-tick:
19301927
tick = nil
1931-
collect.Reset()
19321928
go func() {
1929+
collect := new(CollectT)
19331930
condition(collect)
1934-
ch <- len(collect.errors) == 0
1931+
ch <- collect.errors
19351932
}()
1936-
case v := <-ch:
1937-
if v {
1933+
case errs := <-ch:
1934+
if len(errs) == 0 {
19381935
return true
19391936
}
1937+
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
1938+
lastFinishedTickErrs = errs
19401939
tick = ticker.C
19411940
}
19421941
}

assert/assertions_test.go

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"regexp"
1515
"runtime"
1616
"strings"
17+
"sync/atomic"
1718
"testing"
1819
"time"
1920
"unsafe"
@@ -2760,32 +2761,74 @@ func TestEventuallyTrue(t *testing.T) {
27602761
True(t, Eventually(t, condition, 100*time.Millisecond, 20*time.Millisecond))
27612762
}
27622763

2764+
// errorsCapturingT is a mock implementation of TestingT that captures errors reported with Errorf.
2765+
type errorsCapturingT struct {
2766+
errors []error
2767+
}
2768+
2769+
func (t *errorsCapturingT) Errorf(format string, args ...interface{}) {
2770+
t.errors = append(t.errors, fmt.Errorf(format, args...))
2771+
}
2772+
2773+
func (t *errorsCapturingT) Helper() {}
2774+
27632775
func TestEventuallyWithTFalse(t *testing.T) {
2764-
mockT := new(CollectT)
2776+
mockT := new(errorsCapturingT)
27652777

27662778
condition := func(collect *CollectT) {
2767-
True(collect, false)
2779+
Fail(collect, "condition fixed failure")
27682780
}
27692781

27702782
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
27712783
Len(t, mockT.errors, 2)
27722784
}
27732785

27742786
func TestEventuallyWithTTrue(t *testing.T) {
2775-
mockT := new(CollectT)
2787+
mockT := new(errorsCapturingT)
27762788

2777-
state := 0
2789+
var state atomic.Int32
27782790
condition := func(collect *CollectT) {
27792791
defer func() {
2780-
state += 1
2792+
state.Add(1)
27812793
}()
2782-
True(collect, state == 2)
2794+
True(collect, state.Load() == 2)
27832795
}
27842796

27852797
True(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
27862798
Len(t, mockT.errors, 0)
27872799
}
27882800

2801+
func TestEventuallyWithT_ConcurrencySafe(t *testing.T) {
2802+
mockT := new(errorsCapturingT)
2803+
2804+
condition := func(collect *CollectT) {
2805+
Fail(collect, "condition fixed failure")
2806+
}
2807+
2808+
// To trigger race conditions, we run EventuallyWithT with a nanosecond tick.
2809+
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, time.Nanosecond))
2810+
Len(t, mockT.errors, 2)
2811+
}
2812+
2813+
func TestEventuallyWithT_ReturnsTheLatestFinishedConditionErrors(t *testing.T) {
2814+
var calledOnce atomic.Bool
2815+
condition := func(collect *CollectT) {
2816+
if calledOnce.Load() {
2817+
// Sleep to ensure that the second condition runs longer than timeout.
2818+
time.Sleep(time.Second)
2819+
return
2820+
}
2821+
2822+
// The first condition will fail. We expect to get this error as a result.
2823+
Fail(collect, "condition fixed failure")
2824+
calledOnce.Store(true)
2825+
}
2826+
2827+
mockT := new(errorsCapturingT)
2828+
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
2829+
Len(t, mockT.errors, 2)
2830+
}
2831+
27892832
func TestNeverFalse(t *testing.T) {
27902833
condition := func() bool {
27912834
return false

0 commit comments

Comments
 (0)