Skip to content

Commit 9ead470

Browse files
authored
*: fix potential bug by ai in pkg/schedule/schedulers (#10312)
close #10449

- It removes several unsafe type assertions on JSON input (like `input["x"].(string)`), replacing them with checked parsing and clear 400 Bad Request responses, so malformed payloads no longer panic PD.
  - Key places: pkg/schedule/schedulers/balance_range.go, pkg/schedule/schedulers/evict_leader.go, pkg/schedule/schedulers/grant_hot_region.go, pkg/schedule/schedulers/grant_leader.go
- It fixes a risky type cast in transfer witness handling by making RecvRegionInfo safe (nil on wrong scheduler type) and guarding before channel send.
  - pkg/schedule/schedulers/transfer_witness_leader.go, pkg/schedule/schedulers/scheduler_controller.go
- It corrects HTTP error behavior in scheduler API paths:
  - invalid user options now return 400 instead of 500
  - avoids nil dereference when delete redirect fails and response is nil
  - handles “handler unavailable” cases explicitly
  - server/api/scheduler.go
- It adds regression tests specifically asserting no panic and expected status codes for bad/odd inputs.
  - pkg/schedule/schedulers/balance_range_test.go, pkg/schedule/schedulers/evict_leader_test.go, pkg/schedule/schedulers/grant_hot_region_test.go, pkg/schedule/schedulers/grant_leader_test.go, pkg/schedule/schedulers/transfer_witness_leader_test.go

Signed-off-by: tongjian <1045931706@qq.com>
Signed-off-by: bufferflies <1045931706@qq.com>
1 parent 3095bb6 commit 9ead470

File tree

12 files changed

+221
-28
lines changed

12 files changed

+221
-28
lines changed

errors.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,11 @@ error = '''
5656
the option %s does not exist
5757
'''
5858

59+
["PD:apiutil:ErrOptionTypeInvalid"]
60+
error = '''
61+
the type of option %s is invalid
62+
'''
63+
5964
["PD:apiutil:ErrRedirect"]
6065
error = '''
6166
redirect failed

pkg/errs/errno.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,7 @@ var (
311311
var (
312312
ErrRedirect = errors.Normalize("redirect failed", errors.RFCCodeText("PD:apiutil:ErrRedirect"))
313313
ErrOptionNotExist = errors.Normalize("the option %s does not exist", errors.RFCCodeText("PD:apiutil:ErrOptionNotExist"))
314+
ErrOptionTypeInvalid = errors.Normalize("the type of option %s is invalid", errors.RFCCodeText("PD:apiutil:ErrOptionTypeInvalid"))
314315
ErrRedirectNoLeader = errors.Normalize("redirect finds no leader", errors.RFCCodeText("PD:apiutil:ErrRedirectNoLeader"))
315316
ErrRedirectToNotLeader = errors.Normalize("redirect to not leader", errors.RFCCodeText("PD:apiutil:ErrRedirectToNotLeader"))
316317
ErrRedirectToNotPrimary = errors.Normalize("redirect to not primary", errors.RFCCodeText("PD:apiutil:ErrRedirectToNotPrimary"))

pkg/schedule/schedulers/balance_range.go

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -90,27 +90,52 @@ func (handler *balanceRangeSchedulerHandler) addJob(w http.ResponseWriter, r *ht
9090
Status: pending,
9191
Timeout: defaultJobTimeout,
9292
}
93-
job.Engine = input["engine"].(string)
93+
engine, ok := input["engine"].(string)
94+
if !ok || len(engine) == 0 {
95+
handler.rd.JSON(w, http.StatusBadRequest, "engine is required and must be a string")
96+
return
97+
}
98+
job.Engine = engine
9499
if job.Engine != core.EngineTiFlash && job.Engine != core.EngineTiKV {
95-
handler.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("engine:%s must be tikv or tiflash", input["engine"].(string)))
100+
handler.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("engine:%s must be tikv or tiflash", job.Engine))
96101
return
97102
}
98-
job.Rule = core.NewRule(input["rule"].(string))
103+
ruleStr, ok := input["rule"].(string)
104+
if !ok || len(ruleStr) == 0 {
105+
handler.rd.JSON(w, http.StatusBadRequest, "rule is required and must be a string")
106+
return
107+
}
108+
job.Rule = core.NewRule(ruleStr)
99109
if job.Rule != core.LeaderScatter && job.Rule != core.PeerScatter && job.Rule != core.LearnerScatter {
100110
handler.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("rule:%s must be leader-scatter, learner-scatter or peer-scatter",
101-
input["engine"].(string)))
111+
ruleStr))
102112
return
103113
}
104114

105-
job.Alias = input["alias"].(string)
106-
timeoutStr, ok := input["timeout"].(string)
107-
if ok && len(timeoutStr) > 0 {
108-
timeout, err := time.ParseDuration(timeoutStr)
109-
if err != nil {
110-
handler.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("timeout:%s is invalid", input["timeout"].(string)))
115+
alias, ok := input["alias"].(string)
116+
if !ok || len(alias) == 0 {
117+
handler.rd.JSON(w, http.StatusBadRequest, "alias is required and must be a string")
118+
return
119+
}
120+
job.Alias = alias
121+
if timeoutVal, exists := input["timeout"]; exists {
122+
timeoutStr, ok := timeoutVal.(string)
123+
if !ok {
124+
handler.rd.JSON(w, http.StatusBadRequest, "timeout must be a string")
111125
return
112126
}
113-
job.Timeout = timeout
127+
if len(timeoutStr) > 0 {
128+
timeout, err := time.ParseDuration(timeoutStr)
129+
if err != nil {
130+
handler.rd.JSON(w, http.StatusBadRequest, fmt.Sprintf("timeout:%s is invalid", timeoutStr))
131+
return
132+
}
133+
if timeout <= 0 {
134+
handler.rd.JSON(w, http.StatusBadRequest, "timeout must be positive")
135+
return
136+
}
137+
job.Timeout = timeout
138+
}
114139
}
115140

116141
keys, err := keyutil.DecodeHTTPKeyRanges(input)

pkg/schedule/schedulers/balance_range_test.go

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515
package schedulers
1616

1717
import (
18+
"bytes"
19+
"encoding/json"
1820
"fmt"
21+
"net/http"
22+
"net/http/httptest"
1923
"strconv"
2024
"testing"
2125
"time"
2226

2327
"github.com/stretchr/testify/require"
28+
"github.com/unrolled/render"
2429

2530
"github.com/pingcap/failpoint"
2631

@@ -516,3 +521,79 @@ func TestPersistFail(t *testing.T) {
516521
re.ErrorContains(conf.gcLocked(), errMsg)
517522
re.Len(conf.jobs, 1)
518523
}
524+
525+
func TestAddBalanceRangeJobWithInvalidFieldType(t *testing.T) {
526+
re := require.New(t)
527+
conf := &balanceRangeSchedulerConfig{
528+
schedulerConfig: &baseSchedulerConfig{},
529+
jobs: make([]*balanceRangeSchedulerJob, 0),
530+
}
531+
conf.init("test", storage.NewStorageWithMemoryBackend(), conf)
532+
handler := &balanceRangeSchedulerHandler{
533+
config: conf,
534+
rd: render.New(render.Options{IndentJSON: true}),
535+
}
536+
count := 0
537+
checkFn := func(data []byte, pass bool) {
538+
req := httptest.NewRequest(http.MethodPut, "/job", bytes.NewReader(data))
539+
resp := httptest.NewRecorder()
540+
re.NotPanics(func() {
541+
handler.addJob(resp, req)
542+
})
543+
if pass {
544+
re.Equal(http.StatusOK, resp.Code)
545+
count++
546+
re.Len(conf.jobs, count)
547+
} else {
548+
re.Equal(http.StatusBadRequest, resp.Code)
549+
}
550+
re.Len(conf.jobs, count)
551+
}
552+
553+
// invalid engine type
554+
body, err := json.Marshal(map[string]any{
555+
"alias": "a",
556+
"engine": 1,
557+
"rule": "leader-scatter",
558+
"start-key": "100",
559+
"end-key": "200",
560+
})
561+
re.NoError(err)
562+
checkFn(body, false)
563+
564+
// invalid timeout value: "123" is a string but has no unit, so it is not a valid duration
565+
body, err = json.Marshal(map[string]any{
566+
"alias": "a",
567+
"engine": "tikv",
568+
"rule": "leader-scatter",
569+
"start-key": "100",
570+
"end-key": "200",
571+
"timeout": "123",
572+
})
573+
re.NoError(err)
574+
checkFn(body, false)
575+
576+
// normal case
577+
body, err = json.Marshal(map[string]any{
578+
"alias": "a",
579+
"engine": "tikv",
580+
"rule": "leader-scatter",
581+
"start-key": "100",
582+
"end-key": "200",
583+
"timeout": "123s",
584+
})
585+
re.NoError(err)
586+
checkFn(body, true)
587+
588+
// invalid case: timeout must be positive, so "0s" is rejected
589+
body, err = json.Marshal(map[string]any{
590+
"alias": "a",
591+
"engine": "tikv",
592+
"rule": "leader-scatter",
593+
"start-key": "100",
594+
"end-key": "200",
595+
"timeout": "0s",
596+
})
597+
re.NoError(err)
598+
checkFn(body, false)
599+
}

pkg/schedule/schedulers/grant_hot_region.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,12 @@ func (handler *grantHotRegionHandler) updateConfig(w http.ResponseWriter, r *htt
189189
}
190190
storeIDs = append(storeIDs, id)
191191
}
192-
leaderID, err := strconv.ParseUint(input["store-leader-id"].(string), 10, 64)
192+
leaderStr, ok := input["store-leader-id"].(string)
193+
if !ok {
194+
handler.rd.JSON(w, http.StatusBadRequest, errs.ErrSchedulerConfig)
195+
return
196+
}
197+
leaderID, err := strconv.ParseUint(leaderStr, 10, 64)
193198
if err != nil {
194199
handler.rd.JSON(w, http.StatusBadRequest, errs.ErrBytesToUint64)
195200
return
pkg/schedule/schedulers/grant_hot_region_test.go

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright 2026 TiKV Project Authors.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package schedulers
16+
17+
import (
18+
"bytes"
19+
"encoding/json"
20+
"net/http"
21+
"net/http/httptest"
22+
"testing"
23+
24+
"github.com/stretchr/testify/require"
25+
"github.com/unrolled/render"
26+
)
27+
28+
func TestGrantHotRegionUpdateConfigWithInvalidLeaderIDType(t *testing.T) {
29+
re := require.New(t)
30+
handler := &grantHotRegionHandler{
31+
config: &grantHotRegionSchedulerConfig{},
32+
rd: render.New(render.Options{IndentJSON: true}),
33+
}
34+
body, err := json.Marshal(map[string]any{
35+
"store-id": "1,2",
36+
"store-leader-id": 1,
37+
})
38+
re.NoError(err)
39+
req := httptest.NewRequest(http.MethodPost, "/config", bytes.NewReader(body))
40+
resp := httptest.NewRecorder()
41+
re.NotPanics(func() {
42+
handler.updateConfig(resp, req)
43+
})
44+
re.Equal(http.StatusBadRequest, resp.Code)
45+
}

pkg/schedule/schedulers/init.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -596,6 +596,9 @@ func schedulersRegister() {
596596
if err != nil {
597597
return errs.ErrURLParse.Wrap(err)
598598
}
599+
if duration <= 0 {
600+
return errs.ErrURLParse.FastGenByArgs("timeout must be greater than 0")
601+
}
599602
alias, err := url.QueryUnescape(args[3])
600603
if err != nil {
601604
return errs.ErrURLParse.Wrap(err)

pkg/schedule/schedulers/scheduler_controller.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,13 @@ func (c *Controller) CheckTransferWitnessLeader(region *core.RegionInfo) {
429429
s, ok := c.schedulers[types.TransferWitnessLeaderScheduler.String()]
430430
c.RUnlock()
431431
if ok {
432+
regionC := RecvRegionInfo(s.Scheduler)
433+
if regionC == nil {
434+
log.Warn("invalid scheduler type for transfer witness leader", zap.String("scheduler", s.GetName()))
435+
return
436+
}
432437
select {
433-
case RecvRegionInfo(s.Scheduler) <- region:
438+
case regionC <- region:
434439
default:
435440
log.Warn("drop transfer witness leader due to recv region channel full", zap.Uint64("region-id", region.GetID()))
436441
}

pkg/schedule/schedulers/transfer_witness_leader.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,5 +114,8 @@ func scheduleTransferWitnessLeader(name string, cluster sche.SchedulerCluster, r
114114

115115
// RecvRegionInfo receives a checked region from coordinator
116116
func RecvRegionInfo(s Scheduler) chan<- *core.RegionInfo {
117-
return s.(*transferWitnessLeaderScheduler).regions
117+
if scheduler, ok := s.(*transferWitnessLeaderScheduler); ok {
118+
return scheduler.regions
119+
}
120+
return nil
118121
}

pkg/utils/apiutil/apiutil.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,13 +335,18 @@ func CollectEscapeStringOption(option string, input map[string]any, collectors .
335335

336336
// CollectStringOption is used to collect string using from input map for given option
337337
func CollectStringOption(option string, input map[string]any, collectors ...func(v string)) error {
338-
if v, ok := input[option].(string); ok {
339-
for _, c := range collectors {
340-
c(v)
341-
}
342-
return nil
338+
v, exist := input[option]
339+
if !exist {
340+
return errs.ErrOptionNotExist.FastGenByArgs(option)
343341
}
344-
return errs.ErrOptionNotExist.FastGenByArgs(option)
342+
str, ok := v.(string)
343+
if !ok {
344+
return errs.ErrOptionTypeInvalid.FastGenByArgs(option)
345+
}
346+
for _, c := range collectors {
347+
c(str)
348+
}
349+
return nil
345350
}
346351

347352
// ParseKey is used to parse interface into []byte and string

0 commit comments

Comments
 (0)