Skip to content

Commit 37e8241

Browse files
authored
expression: check if period is valid in period_add (#10380)
1 parent 3190ad9 commit 37e8241

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

expression/builtin_time.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4734,6 +4734,11 @@ func (c *periodAddFunctionClass) getFunction(ctx sessionctx.Context, args []Expr
47344734
return sig, nil
47354735
}
47364736

4737+
// validPeriod checks if this period is valid, it comes from MySQL 8.0+.
4738+
func validPeriod(p int64) bool {
4739+
return !(p < 0 || p%100 == 0 || p%100 > 12)
4740+
}
4741+
47374742
// period2Month converts a period to months, in which period is represented in the format of YYMM or YYYYMM.
47384743
// Note that the period argument is not a date value.
47394744
func period2Month(period uint64) uint64 {
@@ -4785,15 +4790,16 @@ func (b *builtinPeriodAddSig) evalInt(row chunk.Row) (int64, bool, error) {
47854790
return 0, true, err
47864791
}
47874792

4788-
if p == 0 {
4789-
return 0, false, nil
4790-
}
4791-
47924793
n, isNull, err := b.args[1].EvalInt(b.ctx, row)
47934794
if isNull || err != nil {
47944795
return 0, true, err
47954796
}
47964797

4798+
// in MySQL, if p is invalid but n is NULL, the result is NULL, so we have to check if n is NULL first.
4799+
if !validPeriod(p) {
4800+
return 0, false, errIncorrectArgs.GenWithStackByArgs("period_add")
4801+
}
4802+
47974803
sumMonth := int64(period2Month(uint64(p))) + n
47984804
return int64(month2Period(uint64(sumMonth))), false, nil
47994805
}

expression/builtin_time_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,8 +2171,8 @@ func (s *testEvaluatorSuite) TestPeriodAdd(c *C) {
21712171
{201611, -13, true, 201510},
21722172
{1611, 3, true, 201702},
21732173
{7011, 3, true, 197102},
2174-
{12323, 10, true, 12509},
2175-
{0, 3, true, 0},
2174+
{12323, 10, false, 0},
2175+
{0, 3, false, 0},
21762176
}
21772177

21782178
fc := funcs[ast.PeriodAdd]

expression/integration_test.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1511,10 +1511,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) {
15111511
result.Check(testkit.Rows("123456 10 <nil> <nil>"))
15121512

15131513
// for period_add
1514-
result = tk.MustQuery(`SELECT period_add(191, 2), period_add(191, -2), period_add(0, 20), period_add(0, 0);`)
1515-
result.Check(testkit.Rows("200809 200805 0 0"))
1516-
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("21aa", "11aa"), period_add("", "");`)
1517-
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208 0"))
1514+
result = tk.MustQuery(`SELECT period_add(200807, 2), period_add(200807, -2);`)
1515+
result.Check(testkit.Rows("200809 200805"))
1516+
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("200207aa", "1aa");`)
1517+
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208"))
1518+
for _, errPeriod := range []string{
1519+
"period_add(0, 20)", "period_add(0, 0)", "period_add(-1, 1)", "period_add(200013, 1)", "period_add(-200012, 1)", "period_add('', '')",
1520+
} {
1521+
err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod))
1522+
c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_add")
1523+
}
15181524

15191525
// for period_diff
15201526
result = tk.MustQuery(`SELECT period_diff(191, 2), period_diff(191, -2), period_diff(0, 0), period_diff(191, 191);`)

0 commit comments

Comments
 (0)