@@ -26,12 +26,15 @@ import (
2626 "github.com/stretchr/testify/suite"
2727 "go.uber.org/goleak"
2828 "go.uber.org/zap/zapcore"
29+ "google.golang.org/grpc"
2930
3031 "github.com/pingcap/failpoint"
3132 "github.com/pingcap/log"
3233
34+ "github.com/tikv/pd/client/errs"
3335 "github.com/tikv/pd/client/opt"
3436 cctx "github.com/tikv/pd/client/pkg/connectionctx"
37+ "github.com/tikv/pd/client/pkg/retry"
3538 "github.com/tikv/pd/client/pkg/utils/testutil"
3639 sd "github.com/tikv/pd/client/servicediscovery"
3740)
@@ -44,6 +47,7 @@ type mockTSOServiceProvider struct {
4447 option * opt.Option
4548 createStream func (ctx context.Context ) * tsoStream
4649 conCtxMgr * cctx.Manager [* tsoStream ]
50+ svcDiscovery sd.ServiceDiscovery
4751}
4852
4953func newMockTSOServiceProvider (option * opt.Option , createStream func (ctx context.Context ) * tsoStream ) * mockTSOServiceProvider {
@@ -58,7 +62,10 @@ func (m *mockTSOServiceProvider) getOption() *opt.Option {
5862 return m .option
5963}
6064
61- func (* mockTSOServiceProvider ) getServiceDiscovery () sd.ServiceDiscovery {
65+ func (m * mockTSOServiceProvider ) getServiceDiscovery () sd.ServiceDiscovery {
66+ if m .svcDiscovery != nil {
67+ return m .svcDiscovery
68+ }
6269 return sd .NewMockServiceDiscovery ([]string {mockStreamURL }, nil )
6370}
6471
@@ -108,7 +115,7 @@ func (s *testTSODispatcherSuite) SetupTest() {
108115 created .Store (true )
109116 return s .stream
110117 }
111- s .dispatcher = newTSODispatcher (context .Background (), defaultMaxTSOBatchSize , newMockTSOServiceProvider (s .option , createStream ))
118+ s .dispatcher = newTSODispatcher (context .Background (), newMockTSOServiceProvider (s .option , createStream ))
112119 s .reqPool = & sync.Pool {
113120 New : func () any {
114121 return & Request {
@@ -183,6 +190,37 @@ func TestTSODispatcherTestSuite(t *testing.T) {
183190 suite .Run (t , new (testTSODispatcherSuite ))
184191}
185192
193+ type countingServiceDiscovery struct {
194+ removeCount atomic.Int32
195+ scheduleCount atomic.Int32
196+ }
197+
198+ func (* countingServiceDiscovery ) Init () error { return nil }
199+ func (* countingServiceDiscovery ) Close () {}
200+ func (* countingServiceDiscovery ) GetClusterID () uint64 { return 0 }
201+ func (* countingServiceDiscovery ) GetKeyspaceID () uint32 { return 0 }
202+ func (* countingServiceDiscovery ) SetKeyspaceID (uint32 ) {}
203+ func (* countingServiceDiscovery ) GetKeyspaceGroupID () uint32 { return 0 }
204+ func (* countingServiceDiscovery ) GetServiceURLs () []string { return nil }
205+ func (* countingServiceDiscovery ) GetServingEndpointClientConn () * grpc.ClientConn { return nil }
206+ func (* countingServiceDiscovery ) GetClientConns () * sync.Map { return & sync.Map {} }
207+ func (* countingServiceDiscovery ) GetServingURL () string { return "" }
208+ func (* countingServiceDiscovery ) GetBackupURLs () []string { return nil }
209+ func (* countingServiceDiscovery ) GetServiceClient () sd.ServiceClient { return nil }
210+ func (* countingServiceDiscovery ) GetServiceClientByKind (sd.APIKind ) sd.ServiceClient { return nil }
211+ func (* countingServiceDiscovery ) GetAllServiceClients () []sd.ServiceClient { return nil }
212+ func (* countingServiceDiscovery ) GetOrCreateGRPCConn (string ) (* grpc.ClientConn , error ) {
213+ return nil , nil
214+ }
215+ func (s * countingServiceDiscovery ) RemoveClientConn (string ) { s .removeCount .Add (1 ) }
216+ func (s * countingServiceDiscovery ) ScheduleCheckMemberChanged () {
217+ s .scheduleCount .Add (1 )
218+ }
219+ func (* countingServiceDiscovery ) CheckMemberChanged () error { return nil }
220+ func (* countingServiceDiscovery ) ExecAndAddLeaderSwitchedCallback (sd.LeaderSwitchedCallbackFunc ) {}
221+ func (* countingServiceDiscovery ) AddLeaderSwitchedCallback (sd.LeaderSwitchedCallbackFunc ) {}
222+ func (* countingServiceDiscovery ) AddMembersChangedCallback (func ()) {}
223+
186224func (s * testTSODispatcherSuite ) TestBasic () {
187225 ctx := context .Background ()
188226 req := s .sendReq (ctx )
@@ -196,6 +234,29 @@ func (s *testTSODispatcherSuite) TestBasic() {
196234 s .reqMustNotReady (req )
197235}
198236
237+ func (s * testTSODispatcherSuite ) TestHandleProcessRequestErrorRemoveConnOnCalleeMismatch () {
238+ svcDiscovery := & countingServiceDiscovery {}
239+ td := newTSODispatcher (context .Background (), & mockTSOServiceProvider {
240+ option : s .option ,
241+ conCtxMgr : cctx .NewManager [* tsoStream ](),
242+ svcDiscovery : svcDiscovery ,
243+ })
244+ defer td .close ()
245+ bo := retry .InitialBackoffer (time .Millisecond , time .Millisecond , time .Millisecond )
246+
247+ ok := td .handleProcessRequestError (
248+ context .Background (),
249+ bo ,
250+ cctx .NewManager [* tsoStream ](),
251+ mockStreamURL ,
252+ fmt .Errorf ("%s" , errs .MismatchCalleeIDErr ),
253+ )
254+
255+ s .re .True (ok )
256+ s .re .Equal (int32 (1 ), svcDiscovery .removeCount .Load ())
257+ s .re .Zero (svcDiscovery .scheduleCount .Load ())
258+ }
259+
199260func (s * testTSODispatcherSuite ) checkIdleTokenCount (expectedTotal int ) {
200261 // When the tsoDispatcher is idle, the dispatcher loop will acquire a token and wait for requests. Therefore
201262 // there should be N-1 free tokens remaining.
@@ -358,7 +419,7 @@ func BenchmarkTSODispatcherHandleRequests(b *testing.B) {
358419 return req
359420 }
360421
361- dispatcher := newTSODispatcher (ctx , defaultMaxTSOBatchSize , newMockTSOServiceProvider (opt .NewOption (), nil ))
422+ dispatcher := newTSODispatcher (ctx , newMockTSOServiceProvider (opt .NewOption (), nil ))
362423 var wg sync.WaitGroup
363424 wg .Add (1 )
364425
0 commit comments