Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 68 additions & 9 deletions mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,7 @@ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *

for i, call := range m.ExpectedCalls {
if call.Method == method {
_, diffCount := call.Arguments.Diff(arguments)
if diffCount == 0 {
if call.Arguments.matchCount(arguments) == 0 {
expectedCall = call
if call.Repeatability > -1 {
return i, call
Expand All @@ -405,7 +404,6 @@ func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *

type matchCandidate struct {
call *Call
mismatch string
diffCount int
}

Expand All @@ -430,16 +428,14 @@ func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool {
return false
}

func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) {
func (m *Mock) findClosestCall(method string, arguments ...interface{}) *Call {
var bestMatch matchCandidate

for _, call := range m.expectedCalls() {
if call.Method == method {

errInfo, tempDiffCount := call.Arguments.Diff(arguments)
tempDiffCount := call.Arguments.matchCount(arguments)
tempCandidate := matchCandidate{
call: call,
mismatch: errInfo,
diffCount: tempDiffCount,
}
if tempCandidate.isBetterMatchThan(bestMatch) {
Expand All @@ -448,7 +444,7 @@ func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call,
}
}

return bestMatch.call, bestMatch.mismatch
return bestMatch.call
}

func callString(method string, arguments Arguments, includeArgumentValues bool) string {
Expand Down Expand Up @@ -512,10 +508,13 @@ func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Argumen
// a) this is a totally unexpected call to this method,
// b) the arguments are not what was expected, or
// c) the developer has forgotten to add an accompanying On...Return pair.
closestCall, mismatch := m.findClosestCall(methodName, arguments...)
closestCall := m.findClosestCall(methodName, arguments...)
m.mutex.Unlock()

if closestCall != nil {
// Format the diff outside the mutex to avoid deadlocks when
// arguments implement Stringer and call back into MethodCalled.
mismatch, _ := closestCall.Arguments.Diff(arguments)
m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s\nat: %s\n",
callString(methodName, arguments, true),
callString(methodName, closestCall.Arguments, true),
Expand Down Expand Up @@ -953,6 +952,66 @@ func (args Arguments) Is(objects ...interface{}) bool {
return true
}

// matchCount returns the number of argument differences without formatting
// output strings. This is safe to call while holding a mutex because it
// does not invoke user-defined methods like String() or GoString() that
// could call back into MethodCalled and cause a deadlock.
func (args Arguments) matchCount(objects []interface{}) int {
maxArgCount := len(args)
if len(objects) > maxArgCount {
maxArgCount = len(objects)
}

var differences int
for i := 0; i < maxArgCount; i++ {
if len(objects) <= i || len(args) <= i {
differences++
continue
}
actual := objects[i]
expected := args[i]

if matcher, ok := expected.(argumentMatcher); ok {
func() {
defer func() {
if recover() != nil {
differences++
}
}()
if !matcher.Matches(actual) {
differences++
}
}()
} else {
switch expected := expected.(type) {
case anythingOfTypeArgument:
if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
differences++
}
case *IsTypeArgument:
if reflect.TypeOf(actual) != expected.t {
differences++
}
case *FunctionalOptionsArgument:
var name string
if len(expected.values) > 0 {
name = "[]" + reflect.TypeOf(expected.values[0]).String()
}
if name != reflect.TypeOf(actual).String() && len(expected.values) != 0 {
differences++
} else if ef, af := assertOpts(expected.values, actual); ef != "" || af != "" {
differences++
}
default:
if !assert.ObjectsAreEqual(expected, Anything) && !assert.ObjectsAreEqual(actual, Anything) && !assert.ObjectsAreEqual(actual, expected) {
differences++
}
}
}
}
return differences
}

// Diff gets a string describing the differences between the arguments
// and the specified objects.
//
Expand Down
29 changes: 29 additions & 0 deletions mock/mock_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,35 @@ func TestIssue1785ArgumentWithMutatingStringer(t *testing.T) {
m.AssertExpectations(t)
}

// TestIssue1719StringerDeadlock verifies that MethodCalled does not deadlock
// when an argument's String() method calls back into MethodCalled.
// See https://github.com/stretchr/testify/issues/1719
func TestIssue1719StringerDeadlock(t *testing.T) {
done := make(chan struct{})

go func() {
defer close(done)

m := &Mock{}
m.On("String").Return("")
m.On("DoAThing", Anything).Return()

// When DoAThing is called with the mock itself as an argument,
// Diff used to format the argument with %v, triggering String(),
// which calls MethodCalled("String") — deadlock because the mutex
// is already held by the outer MethodCalled("DoAThing").
m.MethodCalled("DoAThing", m)
m.MethodCalled("String")
}()

select {
case <-done:
// Success — no deadlock
case <-time.After(5 * time.Second):
t.Fatal("MethodCalled deadlocked when argument's String() calls MethodCalled")
}
}

func TestIssue1227AssertExpectationsForObjectsWithMock(t *testing.T) {
mockT := &MockTestingT{}
AssertExpectationsForObjects(mockT, Mock{})
Expand Down