Skip to content

Commit 4e6af00

Browse files
authored
planner: store the hints of session variable (#45814) (#46046)
close #45812
1 parent c518a89 commit 4e6af00

File tree

8 files changed

+137
-5
lines changed

8 files changed

+137
-5
lines changed

planner/core/plan_cache.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,9 @@ func getPointQueryPlan(stmt *ast.Prepared, sessVars *variable.SessionVars, stmtC
209209
}
210210
sessVars.FoundInPlanCache = true
211211
stmtCtx.PointExec = true
212+
if pointGetPlan, ok := plan.(*PointGetPlan); ok && pointGetPlan != nil && pointGetPlan.stmtHints != nil {
213+
sessVars.StmtCtx.StmtHints = *pointGetPlan.stmtHints
214+
}
212215
return plan, names, true, nil
213216
}
214217

@@ -251,6 +254,7 @@ func getGeneralPlan(sctx sessionctx.Context, isGeneralPlanCache bool, cacheKey k
251254
planCacheCounter.Inc()
252255
}
253256
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
257+
stmtCtx.StmtHints = *cachedVal.stmtHints
254258
return cachedVal.Plan, cachedVal.OutPutNames, true, nil
255259
}
256260

@@ -289,7 +293,7 @@ func generateNewPlan(ctx context.Context, sctx sessionctx.Context, isGeneralPlan
289293
}
290294
sessVars.IsolationReadEngines[kv.TiFlash] = struct{}{}
291295
}
292-
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes)
296+
cached := NewPlanCacheValue(p, names, stmtCtx.TblInfo2UnionScan, paramTypes, &stmtCtx.StmtHints)
293297
stmt.NormalizedPlan, stmt.PlanDigest = NormalizePlan(p)
294298
stmtCtx.SetPlan(p)
295299
stmtCtx.SetPlanDigest(stmt.NormalizedPlan, stmt.PlanDigest)
@@ -687,12 +691,15 @@ func tryCachePointPlan(_ context.Context, sctx sessionctx.Context,
687691
names types.NameSlice
688692
)
689693

690-
if _, _ok := p.(*PointGetPlan); _ok {
694+
if plan, _ok := p.(*PointGetPlan); _ok {
691695
ok, err = IsPointGetWithPKOrUniqueKeyByAutoCommit(sctx, p)
692696
names = p.OutputNames()
693697
if err != nil {
694698
return err
695699
}
700+
if ok {
701+
plan.stmtHints = sctx.GetSessionVars().StmtCtx.StmtHints.Clone()
702+
}
696703
}
697704

698705
if ok {

planner/core/plan_cache_utils.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import (
2828
"github.com/pingcap/tidb/parser/model"
2929
"github.com/pingcap/tidb/parser/mysql"
3030
"github.com/pingcap/tidb/sessionctx"
31+
"github.com/pingcap/tidb/sessionctx/stmtctx"
3132
"github.com/pingcap/tidb/sessionctx/variable"
3233
"github.com/pingcap/tidb/types"
3334
driver "github.com/pingcap/tidb/types/parser_driver"
@@ -348,6 +349,8 @@ type PlanCacheValue struct {
348349
TblInfo2UnionScan map[*model.TableInfo]bool
349350
ParamTypes FieldSlice
350351
memoryUsage int64
352+
// stmtHints stores the hints which set session variables, because the hints won't be processed using cached plan.
353+
stmtHints *stmtctx.StmtHints
351354
}
352355

353356
func (v *PlanCacheValue) varTypesUnchanged(txtVarTps []*types.FieldType) bool {
@@ -395,7 +398,7 @@ func (v *PlanCacheValue) MemoryUsage() (sum int64) {
395398

396399
// NewPlanCacheValue creates a SQLCacheValue.
397400
func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.TableInfo]bool,
398-
paramTypes []*types.FieldType) *PlanCacheValue {
401+
paramTypes []*types.FieldType, stmtHints *stmtctx.StmtHints) *PlanCacheValue {
399402
dstMap := make(map[*model.TableInfo]bool)
400403
for k, v := range srcMap {
401404
dstMap[k] = v
@@ -409,6 +412,7 @@ func NewPlanCacheValue(plan Plan, names []*types.FieldName, srcMap map[*model.Ta
409412
OutPutNames: names,
410413
TblInfo2UnionScan: dstMap,
411414
ParamTypes: userParamTypes,
415+
stmtHints: stmtHints.Clone(),
412416
}
413417
}
414418

planner/core/point_get_plan.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ type PointGetPlan struct {
9696
// probeParents records the IndexJoins and Applys with this operator in their inner children.
9797
// Please see comments in PhysicalPlan for details.
9898
probeParents []PhysicalPlan
99+
// stmtHints should restore in executing context.
100+
stmtHints *stmtctx.StmtHints
99101
}
100102

101103
func (p *PointGetPlan) getEstRowCountForDisplay() float64 {

session/session_test/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ go_test(
2626
"//privilege/privileges",
2727
"//session",
2828
"//sessionctx",
29+
"//sessionctx/stmtctx",
2930
"//sessionctx/variable",
3031
"//store/copr",
3132
"//store/mockstore",

session/session_test/session_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import (
4242
"github.com/pingcap/tidb/privilege/privileges"
4343
"github.com/pingcap/tidb/session"
4444
"github.com/pingcap/tidb/sessionctx"
45+
"github.com/pingcap/tidb/sessionctx/stmtctx"
4546
"github.com/pingcap/tidb/sessionctx/variable"
4647
"github.com/pingcap/tidb/store/copr"
4748
"github.com/pingcap/tidb/store/mockstore"
@@ -4139,3 +4140,63 @@ func TestSQLModeOp(t *testing.T) {
41394140
a = mysql.SetSQLMode(s, mysql.ModeAllowInvalidDates)
41404141
require.Equal(t, mysql.ModeNoBackslashEscapes|mysql.ModeOnlyFullGroupBy|mysql.ModeAllowInvalidDates, a)
41414142
}
4143+
4144+
func TestPrepareExecuteWithSQLHints(t *testing.T) {
4145+
store := testkit.CreateMockStore(t)
4146+
tk := testkit.NewTestKit(t, store)
4147+
se := tk.Session()
4148+
se.SetConnectionID(1)
4149+
tk.MustExec("use test")
4150+
tk.MustExec("create table t(a int primary key)")
4151+
4152+
type hintCheck struct {
4153+
hint string
4154+
check func(*stmtctx.StmtHints)
4155+
}
4156+
4157+
hintChecks := []hintCheck{
4158+
{
4159+
hint: "MEMORY_QUOTA(1024 MB)",
4160+
check: func(stmtHint *stmtctx.StmtHints) {
4161+
require.True(t, stmtHint.HasMemQuotaHint)
4162+
require.Equal(t, int64(1024*1024*1024), stmtHint.MemQuotaQuery)
4163+
},
4164+
},
4165+
{
4166+
hint: "READ_CONSISTENT_REPLICA()",
4167+
check: func(stmtHint *stmtctx.StmtHints) {
4168+
require.True(t, stmtHint.HasReplicaReadHint)
4169+
require.Equal(t, byte(kv.ReplicaReadFollower), stmtHint.ReplicaRead)
4170+
},
4171+
},
4172+
{
4173+
hint: "MAX_EXECUTION_TIME(1000)",
4174+
check: func(stmtHint *stmtctx.StmtHints) {
4175+
require.True(t, stmtHint.HasMaxExecutionTime)
4176+
require.Equal(t, uint64(1000), stmtHint.MaxExecutionTime)
4177+
},
4178+
},
4179+
{
4180+
hint: "USE_TOJA(TRUE)",
4181+
check: func(stmtHint *stmtctx.StmtHints) {
4182+
require.True(t, stmtHint.HasAllowInSubqToJoinAndAggHint)
4183+
require.True(t, stmtHint.AllowInSubqToJoinAndAgg)
4184+
},
4185+
},
4186+
}
4187+
4188+
for i, check := range hintChecks {
4189+
// common path
4190+
tk.MustExec(fmt.Sprintf("prepare stmt%d from 'select /*+ %s */ * from t'", i, check.hint))
4191+
for j := 0; j < 10; j++ {
4192+
tk.MustQuery(fmt.Sprintf("execute stmt%d", i))
4193+
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
4194+
}
4195+
// fast path
4196+
tk.MustExec(fmt.Sprintf("prepare fast%d from 'select /*+ %s */ * from t where a = 1'", i, check.hint))
4197+
for j := 0; j < 10; j++ {
4198+
tk.MustQuery(fmt.Sprintf("execute fast%d", i))
4199+
check.check(&tk.Session().GetSessionVars().StmtCtx.StmtHints)
4200+
}
4201+
}
4202+
}

sessionctx/stmtctx/BUILD.bazel

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ go_test(
3434
],
3535
embed = [":stmtctx"],
3636
flaky = True,
37-
shard_count = 5,
37+
shard_count = 6,
3838
deps = [
3939
"//kv",
4040
"//sessionctx/variable",

sessionctx/stmtctx/stmtctx.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,6 @@ type StatementContext struct {
389389
type StmtHints struct {
390390
// Hint Information
391391
MemQuotaQuery int64
392-
ApplyCacheCapacity int64
393392
MaxExecutionTime uint64
394393
ReplicaRead byte
395394
AllowInSubqToJoinAndAgg bool
@@ -418,6 +417,41 @@ func (sh *StmtHints) TaskMapNeedBackUp() bool {
418417
return sh.ForceNthPlan != -1
419418
}
420419

420+
// Clone the StmtHints struct and returns the pointer of the new one.
421+
func (sh *StmtHints) Clone() *StmtHints {
422+
var (
423+
vars map[string]string
424+
tableHints []*ast.TableOptimizerHint
425+
)
426+
if len(sh.SetVars) > 0 {
427+
vars = make(map[string]string, len(sh.SetVars))
428+
for k, v := range sh.SetVars {
429+
vars[k] = v
430+
}
431+
}
432+
if len(sh.OriginalTableHints) > 0 {
433+
tableHints = make([]*ast.TableOptimizerHint, len(sh.OriginalTableHints))
434+
copy(tableHints, sh.OriginalTableHints)
435+
}
436+
return &StmtHints{
437+
MemQuotaQuery: sh.MemQuotaQuery,
438+
MaxExecutionTime: sh.MaxExecutionTime,
439+
ReplicaRead: sh.ReplicaRead,
440+
AllowInSubqToJoinAndAgg: sh.AllowInSubqToJoinAndAgg,
441+
NoIndexMergeHint: sh.NoIndexMergeHint,
442+
StraightJoinOrder: sh.StraightJoinOrder,
443+
EnableCascadesPlanner: sh.EnableCascadesPlanner,
444+
ForceNthPlan: sh.ForceNthPlan,
445+
HasAllowInSubqToJoinAndAggHint: sh.HasAllowInSubqToJoinAndAggHint,
446+
HasMemQuotaHint: sh.HasMemQuotaHint,
447+
HasReplicaReadHint: sh.HasReplicaReadHint,
448+
HasMaxExecutionTime: sh.HasMaxExecutionTime,
449+
HasEnableCascadesPlannerHint: sh.HasEnableCascadesPlannerHint,
450+
SetVars: vars,
451+
OriginalTableHints: tableHints,
452+
}
453+
}
454+
421455
// StmtCacheKey represents the key type in the StmtCache.
422456
type StmtCacheKey int
423457

sessionctx/stmtctx/stmtctx_test.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"encoding/json"
2020
"fmt"
2121
"math/rand"
22+
"reflect"
2223
"sort"
2324
"testing"
2425
"time"
@@ -272,3 +273,25 @@ func TestApproxRuntimeInfo(t *testing.T) {
272273
require.Equal(t, d.TotBackoffTime[backoff], timeSum)
273274
}
274275
}
276+
277+
func TestStmtHintsClone(t *testing.T) {
278+
hints := stmtctx.StmtHints{}
279+
value := reflect.ValueOf(&hints).Elem()
280+
for i := 0; i < value.NumField(); i++ {
281+
field := value.Field(i)
282+
switch field.Kind() {
283+
case reflect.Int, reflect.Int32, reflect.Int64:
284+
field.SetInt(1)
285+
case reflect.Uint, reflect.Uint32, reflect.Uint64:
286+
field.SetUint(1)
287+
case reflect.Uint8: // byte
288+
field.SetUint(1)
289+
case reflect.Bool:
290+
field.SetBool(true)
291+
case reflect.String:
292+
field.SetString("test")
293+
default:
294+
}
295+
}
296+
require.Equal(t, hints, *hints.Clone())
297+
}

0 commit comments

Comments
 (0)