Skip to content

Commit 797028d

Browse files
committed
feat: ✨ add thread-safe handlers registration
1 parent a0ab5a6 commit 797028d

5 files changed

Lines changed: 74 additions & 38 deletions

File tree

go.mod

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ module github.com/mehdihadeli/go-mediatr
33
go 1.24
44

55
require (
6-
github.com/goccy/go-reflect v1.2.0
76
github.com/pkg/errors v0.9.1
87
github.com/stretchr/testify v1.10.0
98
)

internal/examples/cqrs_example/go.mod

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module cqrsexample
22

3-
go 1.18
3+
go 1.24
4+
5+
toolchain go1.24.2
46

57
replace github.com/mehdihadeli/go-mediatr => ../../../
68

mediatr.go

Lines changed: 51 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ package mediatr
22

33
import (
44
"context"
5+
"reflect"
6+
"sync"
57

6-
"github.com/goccy/go-reflect"
78
"github.com/pkg/errors"
89
)
910

@@ -27,16 +28,22 @@ type NotificationHandler[TNotification any] interface {
2728

2829
type NotificationHandlerFactory[TNotification any] func() NotificationHandler[TNotification]
2930

30-
var requestHandlersRegistrations = map[reflect.Type]interface{}{}
31-
var notificationHandlersRegistrations = map[reflect.Type][]interface{}{}
32-
var pipelineBehaviours []interface{} = []interface{}{}
31+
var (
32+
requestHandlersRegistrations = map[reflect.Type]interface{}{}
33+
notificationHandlersRegistrations = map[reflect.Type][]interface{}{}
34+
pipelineBehaviours []interface{}
35+
registryMutex sync.RWMutex
36+
)
3337

3438
type Unit struct{}
3539

3640
func registerRequestHandler[TRequest any, TResponse any](handler any) error {
3741
var request TRequest
3842
requestType := reflect.TypeOf(request)
3943

44+
registryMutex.Lock()
45+
defer registryMutex.Unlock()
46+
4047
_, exist := requestHandlersRegistrations[requestType]
4148
if exist {
4249
// each request in request/response strategy should have just one handler
@@ -60,6 +67,9 @@ func RegisterRequestHandlerFactory[TRequest any, TResponse any](factory RequestH
6067

6168
// RegisterRequestPipelineBehaviors register the request behaviors to mediatr registry.
6269
func RegisterRequestPipelineBehaviors(behaviours ...PipelineBehavior) error {
70+
registryMutex.Lock()
71+
defer registryMutex.Unlock()
72+
6373
for _, behavior := range behaviours {
6474
behaviorType := reflect.TypeOf(behavior)
6575

@@ -78,6 +88,9 @@ func registerNotificationHandler[TEvent any](handler any) error {
7888
var event TEvent
7989
eventType := reflect.TypeOf(event)
8090

91+
registryMutex.Lock()
92+
defer registryMutex.Unlock()
93+
8194
handlers, exist := notificationHandlersRegistrations[eventType]
8295
if !exist {
8396
notificationHandlersRegistrations[eventType] = []interface{}{handler}
@@ -157,7 +170,15 @@ func buildRequestHandler[TRequest any, TResponse any](handler any) (RequestHandl
157170
func Send[TRequest any, TResponse any](ctx context.Context, request TRequest) (TResponse, error) {
158171
requestType := reflect.TypeOf(request)
159172
var response TResponse
173+
174+
registryMutex.RLock()
160175
handler, ok := requestHandlersRegistrations[requestType]
176+
// without copying, another goroutine could modify the original slice
177+
behavioursCopy := make([]interface{}, len(pipelineBehaviours))
178+
// deep copy of elements
179+
copy(behavioursCopy, pipelineBehaviours)
180+
registryMutex.RUnlock()
181+
161182
if !ok {
162183
// request-response strategy should have exactly one handler and if we can't find a corresponding handler, we should return an error
163184
return *new(TResponse), errors.Errorf("no handler for request %T", request)
@@ -168,29 +189,29 @@ func Send[TRequest any, TResponse any](ctx context.Context, request TRequest) (T
168189
return *new(TResponse), errors.Errorf("handler for request %T is not a Handler", request)
169190
}
170191

171-
if len(pipelineBehaviours) > 0 {
172-
var reversPipes = reversOrder(pipelineBehaviours)
192+
if len(behavioursCopy) > 0 {
193+
var reversPipes = reversOrder(behavioursCopy)
173194

174195
var lastHandler RequestHandlerFunc = func(ctx context.Context) (interface{}, error) {
175196
return handlerValue.Handle(ctx, request)
176197
}
177-
178-
aggregateResult := lastHandler
179-
for _, pipe := range reversPipes {
180-
pipeValue := pipe.(PipelineBehavior)
181-
currentNext := aggregateResult
182-
183-
aggregateResult = func(ctx context.Context) (interface{}, error) {
184-
return pipeValue.Handle(ctx, request, currentNext)
185-
}
186-
}
187-
188-
response, err := aggregateResult(ctx)
189-
if err != nil {
190-
return *new(TResponse), errors.Wrap(err, "error handling request")
191-
}
192-
193-
return response.(TResponse), nil
198+
199+
aggregateResult := lastHandler
200+
for _, pipe := range reversPipes {
201+
pipeValue := pipe.(PipelineBehavior)
202+
currentNext := aggregateResult
203+
204+
aggregateResult = func(ctx context.Context) (interface{}, error) {
205+
return pipeValue.Handle(ctx, request, currentNext)
206+
}
207+
}
208+
209+
response, err := aggregateResult(ctx)
210+
if err != nil {
211+
return *new(TResponse), errors.Wrap(err, "error handling request")
212+
}
213+
214+
return response.(TResponse), nil
194215
} else {
195216
res, err := handlerValue.Handle(ctx, request)
196217
if err != nil {
@@ -221,13 +242,19 @@ func buildNotificationHandler[TNotification any](handler any) (NotificationHandl
221242
func Publish[TNotification any](ctx context.Context, notification TNotification) error {
222243
eventType := reflect.TypeOf(notification)
223244

245+
registryMutex.RLock()
224246
handlers, ok := notificationHandlersRegistrations[eventType]
247+
// without copying, another goroutine could modify the original slice
248+
handlersCopy := make([]interface{}, len(handlers))
249+
// deep copy of elements
250+
copy(handlersCopy, handlers)
251+
registryMutex.RUnlock()
225252
if !ok {
226253
// notification strategy should have zero or more handlers, so it should run without any error if we can't find a corresponding handler
227254
return nil
228255
}
229256

230-
for _, handler := range handlers {
257+
for _, handler := range handlersCopy {
231258
handlerValue, ok := buildNotificationHandler[TNotification](handler)
232259

233260
if !ok {

mediatr_benchmarks_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ package mediatr
22

33
import (
44
"context"
5-
"github.com/goccy/go-reflect"
5+
"reflect"
66
"testing"
77
)
88

mediatr_test.go

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,17 @@ package mediatr
33
import (
44
"context"
55
"fmt"
6+
"reflect"
7+
"sync"
68
"testing"
79

8-
"github.com/goccy/go-reflect"
910
"github.com/pkg/errors"
1011
"github.com/stretchr/testify/assert"
1112
"github.com/stretchr/testify/require"
1213
)
1314

1415
var testData []string
16+
var testMutex sync.Mutex
1517

1618
func TestRunner(t *testing.T) {
1719
//https://pkg.go.dev/testing@master#hdr-Subtests_and_Sub_benchmarks
@@ -186,20 +188,19 @@ func (t *MediatRTests) Test_Send_Should_Dispatch_Request_To_Handler_And_Get_Resp
186188
pip1 := &PipelineBehaviourTest{}
187189
pip2 := &PipelineBehaviourTest2{}
188190
err := RegisterRequestPipelineBehaviors(pip1, pip2)
189-
if err != nil {
190-
t.Errorf("error registering request pipeline behaviors: %s", err)
191-
}
191+
assert.Nil(t, err)
192192

193193
handler := &RequestTestHandler{}
194194
errRegister := RegisterRequestHandler[*RequestTest, *ResponseTest](handler)
195-
if errRegister != nil {
196-
t.Error(errRegister)
197-
}
195+
assert.Nil(t, errRegister)
198196

199197
response, err := Send[*RequestTest, *ResponseTest](context.Background(), &RequestTest{Data: "test"})
200198
assert.Nil(t, err)
201199
assert.IsType(t, &ResponseTest{}, response)
202200
assert.Equal(t, "test", response.Data)
201+
202+
testMutex.Lock()
203+
defer testMutex.Unlock()
203204
assert.Contains(t, testData, "PipelineBehaviourTest")
204205
assert.Contains(t, testData, "PipelineBehaviourTest2")
205206
}
@@ -328,7 +329,7 @@ func (t *MediatRTests) Test_Clear_Request_Registrations() {
328329
err2 := RegisterRequestHandler[*RequestTest2, *ResponseTest2](handler2)
329330
require.NoError(t, err1, err2)
330331

331-
ClearRequestRegistrations()
332+
cleanup()
332333

333334
count := len(requestHandlersRegistrations)
334335
assert.Equal(t, 0, count)
@@ -481,7 +482,14 @@ func (c *PipelineBehaviourTest2) Handle(ctx context.Context, request interface{}
481482

482483
// /////////////////////////////////////////////////////////////////////////////////////////////
483484
func cleanup() {
484-
requestHandlersRegistrations = map[reflect.Type]interface{}{}
485-
notificationHandlersRegistrations = map[reflect.Type][]interface{}{}
486-
pipelineBehaviours = []interface{}{}
485+
testMutex.Lock()
486+
defer testMutex.Unlock()
487+
488+
// Reset package-level registrations
489+
requestHandlersRegistrations = make(map[reflect.Type]interface{})
490+
notificationHandlersRegistrations = make(map[reflect.Type][]interface{})
491+
pipelineBehaviours = make([]interface{}, 0)
492+
493+
// Reset test data
494+
testData = nil
487495
}

0 commit comments

Comments
 (0)