Skip to content

Commit 1066f6a

Browse files
committed
fix: make EventuallyWithT concurrency safe
1 parent 882382d commit 1066f6a

2 files changed

Lines changed: 67 additions & 22 deletions

File tree

assert/assertions.go

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,23 +1868,18 @@ func (c *CollectT) Errorf(format string, args ...interface{}) {
18681868
}
18691869

18701870
// FailNow panics.
1871-
func (c *CollectT) FailNow() {
1871+
func (*CollectT) FailNow() {
18721872
panic("Assertion failed")
18731873
}
18741874

1875-
// Reset clears the collected errors.
1876-
func (c *CollectT) Reset() {
1877-
c.errors = nil
1875+
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
1876+
func (*CollectT) Reset() {
1877+
panic("Reset() is deprecated")
18781878
}
18791879

1880-
// Copy copies the collected errors to the supplied t.
1881-
func (c *CollectT) Copy(t TestingT) {
1882-
if tt, ok := t.(tHelper); ok {
1883-
tt.Helper()
1884-
}
1885-
for _, err := range c.errors {
1886-
t.Errorf("%v", err)
1887-
}
1880+
// Deprecated: That was a method for internal usage that should not have been published. Now just panics.
1881+
func (*CollectT) Copy(TestingT) {
1882+
panic("Copy() is deprecated")
18881883
}
18891884

18901885
// EventuallyWithT asserts that given condition will be met in waitFor time,
@@ -1910,8 +1905,8 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
19101905
h.Helper()
19111906
}
19121907

1913-
collect := new(CollectT)
1914-
ch := make(chan bool, 1)
1908+
var lastFinishedTickErrs []error
1909+
ch := make(chan []error, 1)
19151910

19161911
timer := time.NewTimer(waitFor)
19171912
defer timer.Stop()
@@ -1922,19 +1917,23 @@ func EventuallyWithT(t TestingT, condition func(collect *CollectT), waitFor time
19221917
for tick := ticker.C; ; {
19231918
select {
19241919
case <-timer.C:
1925-
collect.Copy(t)
1920+
for _, err := range lastFinishedTickErrs {
1921+
t.Errorf("%v", err)
1922+
}
19261923
return Fail(t, "Condition never satisfied", msgAndArgs...)
19271924
case <-tick:
19281925
tick = nil
1929-
collect.Reset()
19301926
go func() {
1927+
collect := new(CollectT)
19311928
condition(collect)
1932-
ch <- len(collect.errors) == 0
1929+
ch <- collect.errors
19331930
}()
1934-
case v := <-ch:
1935-
if v {
1931+
case errs := <-ch:
1932+
if len(errs) == 0 {
19361933
return true
19371934
}
1935+
// Keep the errors from the last ended condition, so that they can be copied to t if timeout is reached.
1936+
lastFinishedTickErrs = errs
19381937
tick = ticker.C
19391938
}
19401939
}

assert/assertions_test.go

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2760,19 +2760,30 @@ func TestEventuallyTrue(t *testing.T) {
27602760
True(t, Eventually(t, condition, 100*time.Millisecond, 20*time.Millisecond))
27612761
}
27622762

2763+
// errorsCapturingT is a mock implementation of TestingT that captures errors reported with Errorf.
2764+
type errorsCapturingT struct {
2765+
errors []error
2766+
}
2767+
2768+
func (t *errorsCapturingT) Errorf(format string, args ...interface{}) {
2769+
t.errors = append(t.errors, fmt.Errorf(format, args...))
2770+
}
2771+
2772+
func (t *errorsCapturingT) Helper() {}
2773+
27632774
func TestEventuallyWithTFalse(t *testing.T) {
2764-
mockT := new(CollectT)
2775+
mockT := new(errorsCapturingT)
27652776

27662777
condition := func(collect *CollectT) {
2767-
True(collect, false)
2778+
Fail(collect, "condition fixed failure")
27682779
}
27692780

27702781
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
27712782
Len(t, mockT.errors, 2)
27722783
}
27732784

27742785
func TestEventuallyWithTTrue(t *testing.T) {
2775-
mockT := new(CollectT)
2786+
mockT := new(errorsCapturingT)
27762787

27772788
state := 0
27782789
condition := func(collect *CollectT) {
@@ -2786,6 +2797,41 @@ func TestEventuallyWithTTrue(t *testing.T) {
27862797
Len(t, mockT.errors, 0)
27872798
}
27882799

2800+
func TestEventuallyWithT_ConcurrencySafe(t *testing.T) {
2801+
mockT := new(errorsCapturingT)
2802+
2803+
condition := func(collect *CollectT) {
2804+
Fail(collect, "condition fixed failure")
2805+
}
2806+
2807+
// To trigger race conditions, we run EventuallyWithT with a nanosecond tick.
2808+
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, time.Nanosecond))
2809+
Len(t, mockT.errors, 2)
2810+
}
2811+
2812+
func TestEventuallyWithT_ReturnsTheLatestFinishedConditionErrors(t *testing.T) {
2813+
// We'll use a channel to control whether a condition should sleep or not.
2814+
mustSleep := make(chan bool, 2)
2815+
mustSleep <- false
2816+
mustSleep <- true
2817+
close(mustSleep)
2818+
2819+
condition := func(collect *CollectT) {
2820+
if <-mustSleep {
2821+
// Sleep to ensure that the second condition runs longer than timeout.
2822+
time.Sleep(time.Second)
2823+
return
2824+
}
2825+
2826+
// The first condition will fail. We expect to get this error as a result.
2827+
Fail(collect, "condition fixed failure")
2828+
}
2829+
2830+
mockT := new(errorsCapturingT)
2831+
False(t, EventuallyWithT(mockT, condition, 100*time.Millisecond, 20*time.Millisecond))
2832+
Len(t, mockT.errors, 2)
2833+
}
2834+
27892835
func TestNeverFalse(t *testing.T) {
27902836
condition := func() bool {
27912837
return false

0 commit comments

Comments
 (0)