Skip to content

Commit 7bcb77c

Browse files
committed
expression: check if period is valid in period_add (pingcap#10380)
1 parent 570a315 commit 7bcb77c

File tree

3 files changed

+22
-10
lines changed

3 files changed

+22
-10
lines changed

expression/builtin_time.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4887,6 +4887,11 @@ func (c *periodAddFunctionClass) getFunction(ctx sessionctx.Context, args []Expr
48874887
return sig, nil
48884888
}
48894889

4890+
// validPeriod checks if this period is valid, it comes from MySQL 8.0+.
4891+
func validPeriod(p int64) bool {
4892+
return !(p < 0 || p%100 == 0 || p%100 > 12)
4893+
}
4894+
48904895
// period2Month converts a period to months, in which period is represented in the format of YYMM or YYYYMM.
48914896
// Note that the period argument is not a date value.
48924897
func period2Month(period uint64) uint64 {
@@ -4938,15 +4943,16 @@ func (b *builtinPeriodAddSig) evalInt(row chunk.Row) (int64, bool, error) {
49384943
return 0, true, errors.Trace(err)
49394944
}
49404945

4941-
if p == 0 {
4942-
return 0, false, nil
4943-
}
4944-
49454946
n, isNull, err := b.args[1].EvalInt(b.ctx, row)
49464947
if isNull || err != nil {
49474948
return 0, true, errors.Trace(err)
49484949
}
49494950

4951+
// in MySQL, if p is invalid but n is NULL, the result is NULL, so we have to check if n is NULL first.
4952+
if !validPeriod(p) {
4953+
return 0, false, errIncorrectArgs.GenWithStackByArgs("period_add")
4954+
}
4955+
49504956
sumMonth := int64(period2Month(uint64(p))) + n
49514957
return int64(month2Period(uint64(sumMonth))), false, nil
49524958
}

expression/builtin_time_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2146,8 +2146,8 @@ func (s *testEvaluatorSuite) TestPeriodAdd(c *C) {
21462146
{201611, -13, true, 201510},
21472147
{1611, 3, true, 201702},
21482148
{7011, 3, true, 197102},
2149-
{12323, 10, true, 12509},
2150-
{0, 3, true, 0},
2149+
{12323, 10, false, 0},
2150+
{0, 3, false, 0},
21512151
}
21522152

21532153
fc := funcs[ast.PeriodAdd]

expression/integration_test.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1438,10 +1438,16 @@ func (s *testIntegrationSuite) TestTimeBuiltin(c *C) {
14381438
result.Check(testkit.Rows("123456 10 <nil> <nil>"))
14391439

14401440
// for period_add
1441-
result = tk.MustQuery(`SELECT period_add(191, 2), period_add(191, -2), period_add(0, 20), period_add(0, 0);`)
1442-
result.Check(testkit.Rows("200809 200805 0 0"))
1443-
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("21aa", "11aa"), period_add("", "");`)
1444-
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208 0"))
1441+
result = tk.MustQuery(`SELECT period_add(200807, 2), period_add(200807, -2);`)
1442+
result.Check(testkit.Rows("200809 200805"))
1443+
result = tk.MustQuery(`SELECT period_add(NULL, 2), period_add(-191, NULL), period_add(NULL, NULL), period_add(12.09, -2), period_add("200207aa", "1aa");`)
1444+
result.Check(testkit.Rows("<nil> <nil> <nil> 200010 200208"))
1445+
for _, errPeriod := range []string{
1446+
"period_add(0, 20)", "period_add(0, 0)", "period_add(-1, 1)", "period_add(200013, 1)", "period_add(-200012, 1)", "period_add('', '')",
1447+
} {
1448+
err := tk.QueryToErr(fmt.Sprintf("SELECT %v;", errPeriod))
1449+
c.Assert(err.Error(), Equals, "[expression:1210]Incorrect arguments to period_add")
1450+
}
14451451

14461452
// for period_diff
14471453
result = tk.MustQuery(`SELECT period_diff(191, 2), period_diff(191, -2), period_diff(0, 0), period_diff(191, 191);`)

0 commit comments

Comments
 (0)