Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1687,6 +1687,23 @@ func (i inCastContext) String() string {
// @see BuildCastFunction4Union
const inUnionCastContext inCastContext = 0

// hasSpecialCast checks if this expr has its own special cast function.
// for example(#9713): when doing arithmetic using results of function DayName,
// "Monday" should be regarded as 0, "Tuesday" should be regarded as 1 and so on.
func hasSpecialCast(ctx sessionctx.Context, expr Expression, tp *types.FieldType) bool {
switch f := expr.(type) {
case *ScalarFunction:
switch f.FuncName.L {
case ast.DayName:
switch tp.EvalType() {
case types.ETInt, types.ETReal:
Comment thread
alivxxx marked this conversation as resolved.
return true
}
}
}
return false
}

// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union
// Expression.
func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
Expand All @@ -1699,6 +1716,10 @@ func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types.

// BuildCastFunction builds a CAST ScalarFunction from the Expression.
func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
if hasSpecialCast(ctx, expr, tp) {
Comment thread
zz-jason marked this conversation as resolved.
return expr
}

var fc functionClass
switch tp.EvalType() {
case types.ETInt:
Expand Down
36 changes: 30 additions & 6 deletions expression/builtin_time.go
Original file line number Diff line number Diff line change
Expand Up @@ -1079,21 +1079,45 @@ func (b *builtinDayNameSig) Clone() builtinFunc {
return newSig
}

// evalString evals a builtinDayNameSig.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayname
func (b *builtinDayNameSig) evalString(row chunk.Row) (string, bool, error) {
func (b *builtinDayNameSig) evalIndex(row chunk.Row) (int64, bool, error) {
arg, isNull, err := b.args[0].EvalTime(b.ctx, row)
if isNull || err != nil {
return "", isNull, err
return 0, isNull, err
}
if arg.InvalidZero() {
return "", true, handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenWithStackByArgs(arg.String()))
return 0, true, handleInvalidTimeError(b.ctx, types.ErrIncorrectDatetimeValue.GenWithStackByArgs(arg.String()))
}
// Monday is 0, ... Sunday = 6 in MySQL
// but in go, Sunday is 0, ... Saturday is 6
// w will do a conversion.
res := (int64(arg.Time.Weekday()) + 6) % 7
return types.WeekdayNames[res], false, nil
return res, false, nil
}

// evalString evals a builtinDayNameSig.
// See https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_dayname
func (b *builtinDayNameSig) evalString(row chunk.Row) (string, bool, error) {
idx, isNull, err := b.evalIndex(row)
if isNull || err != nil {
return "", isNull, err
}
return types.WeekdayNames[idx], false, nil
}

func (b *builtinDayNameSig) evalReal(row chunk.Row) (float64, bool, error) {
Comment thread
qw4990 marked this conversation as resolved.
idx, isNull, err := b.evalIndex(row)
if isNull || err != nil {
return 0, isNull, err
}
return float64(idx), false, nil
}

func (b *builtinDayNameSig) evalInt(row chunk.Row) (int64, bool, error) {
idx, isNull, err := b.evalIndex(row)
if isNull || err != nil {
return 0, isNull, err
}
return idx, false, nil
}

type dayOfMonthFunctionClass struct {
Expand Down
1 change: 1 addition & 0 deletions expression/function_traits.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ var unFoldableFunctions = map[string]struct{}{
ast.GetVar: {},
ast.GetParam: {},
ast.Benchmark: {},
ast.DayName: {},
Comment thread
alivxxx marked this conversation as resolved.
}

// DisableFoldFunctions stores functions which prevent child scope functions from being constant folded.
Expand Down
48 changes: 48 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4134,3 +4134,51 @@ func (s *testIntegrationSuite) TestDecimalConvertToTime(c *C) {
tk.MustExec("insert t values (20010101100000.123456, 20110707101112.123456)")
tk.MustQuery("select * from t").Check(testkit.Rows("2001-01-01 10:00:00.123456 2011-07-07 10:11:12"))
}

func (s *testIntegrationSuite) TestDaynameArithmetic(c *C) {
tk := testkit.NewTestKit(c, s.store)
defer s.cleanEnv(c)

cases := []struct {
sql string
result string
}{
{`select dayname("1962-03-01")+0;`, "3"},
{`select dayname("1962-03-02")+0;`, "4"},
{`select dayname("1962-03-03")+0;`, "5"},
{`select dayname("1962-03-04")+0;`, "6"},
{`select dayname("1962-03-05")+0;`, "0"},
{`select dayname("1962-03-06")+0;`, "1"},
{`select dayname("1962-03-07")+0;`, "2"},
{`select dayname("1962-03-08")+0;`, "3"},
{`select dayname("1962-03-01")+1;`, "4"},
{`select dayname("1962-03-01")+2;`, "5"},
{`select dayname("1962-03-01")+3;`, "6"},
{`select dayname("1962-03-01")+4;`, "7"},
{`select dayname("1962-03-01")+5;`, "8"},
{`select dayname("1962-03-01")+6;`, "9"},
{`select dayname("1962-03-01")+7;`, "10"},
{`select dayname("1962-03-01")+2333;`, "2336"},
{`select dayname("1962-03-01")+2.333;`, "5.333"},
{`select dayname("1962-03-01")>2;`, "1"},
{`select dayname("1962-03-01")<2;`, "0"},
{`select dayname("1962-03-01")=3;`, "1"},
{`select dayname("1962-03-01")!=3;`, "0"},
{`select dayname("1962-03-01")<4;`, "1"},
{`select dayname("1962-03-01")>4;`, "0"},
{`select !dayname("1962-03-01");`, "0"},
{`select dayname("1962-03-01")&1;`, "1"},
{`select dayname("1962-03-01")&3;`, "3"},
{`select dayname("1962-03-01")&7;`, "3"},
{`select dayname("1962-03-01")|1;`, "3"},
{`select dayname("1962-03-01")|3;`, "3"},
{`select dayname("1962-03-01")|7;`, "7"},
{`select dayname("1962-03-01")^1;`, "2"},
{`select dayname("1962-03-01")^3;`, "0"},
{`select dayname("1962-03-01")^7;`, "4"},
}

for _, c := range cases {
tk.MustQuery(c.sql).Check(testkit.Rows(c.result))
}
}