Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions cmd/explaintest/r/select.result
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,12 @@ Projection_7 10000.00 root 6_aux_0
│ └─TableScan_9 10000.00 cop table:t1, range:[-inf,+inf], keep order:false, stats:pseudo
└─TableReader_12 10000.00 root data:TableScan_11
└─TableScan_11 10000.00 cop table:t2, range:[-inf,+inf], keep order:false, stats:pseudo
drop table if exists t1;
create table t1 (f2 datetime);
insert into t1 values('2001-01-01 01:01:01');
select f2 from t1 where '2001-04-10 12:34:56' between f2 and '01-05-01';
f2
2001-01-01 01:01:01
select 1 from t1 where 20010410123456 between cast('2001-01-01 12:34:56' as datetime) and 010501;
1
drop table t1;
10 changes: 10 additions & 0 deletions cmd/explaintest/t/select.test
Original file line number Diff line number Diff line change
Expand Up @@ -201,3 +201,13 @@ explain select a in (select a+b from t t2 where t2.b = t1.b) from t t1;
drop table t;
create table t(a int not null, b int);
explain select a in (select a from t t2 where t2.b = t1.b) from t t1;

# issue 9764
drop table if exists t1;
create table t1 (f2 datetime);
insert into t1 values('2001-01-01 01:01:01');
# convert string to DATETIME if one of fields in between is DATETIME
select f2 from t1 where '2001-04-10 12:34:56' between f2 and '01-05-01';
# do not do that if one of them is int, convert them to int instead
select 1 from t1 where 20010410123456 between cast('2001-01-01 12:34:56' as datetime) and 010501;
drop table t1;
17 changes: 13 additions & 4 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -954,9 +954,18 @@ type compareFunctionClass struct {
op opcode.Op
}

// AggCmpType aggregates extends getBaseCmpType to get type when comparing three or more fields.
func AggCmpType(fields ...*types.FieldType) types.EvalType {
ret := fields[0].EvalType()
for i := 1; i < len(fields); i++ {
ret = getBaseCmpType(ret, fields[i].EvalType(), nil, fields[i])
}
return ret
}

// getBaseCmpType gets the EvalType that the two args will be treated as when comparing.
func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.EvalType {
if lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified {
if lft != nil && rft != nil && (lft.Tp == mysql.TypeUnspecified || rft.Tp == mysql.TypeUnspecified) {
if lft.Tp == rft.Tp {
return types.ETString
}
Expand All @@ -968,10 +977,10 @@ func getBaseCmpType(lhs, rhs types.EvalType, lft, rft *types.FieldType) types.Ev
}
if lhs.IsStringKind() && rhs.IsStringKind() {
return types.ETString
} else if (lhs == types.ETInt || lft.Hybrid()) && (rhs == types.ETInt || rft.Hybrid()) {
} else if (lhs == types.ETInt || (lft != nil && lft.Hybrid())) && (rhs == types.ETInt || (rft != nil && rft.Hybrid())) {
return types.ETInt
} else if ((lhs == types.ETInt || lft.Hybrid()) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || rft.Hybrid()) || rhs == types.ETDecimal) {
} else if ((lhs == types.ETInt || (lft != nil && lft.Hybrid())) || lhs == types.ETDecimal) &&
((rhs == types.ETInt || (rft != nil && rft.Hybrid())) || rhs == types.ETDecimal) {
return types.ETDecimal
}
return types.ETReal
Expand Down
22 changes: 20 additions & 2 deletions planner/core/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -1259,11 +1259,29 @@ func (er *expressionRewriter) betweenToExpression(v *ast.BetweenExpr) {
if er.err != nil {
return
}

expr, lexp, rexp := er.ctxStack[stkLen-3], er.ctxStack[stkLen-2], er.ctxStack[stkLen-1]

if expression.AggCmpType(expr.GetType(), lexp.GetType(), rexp.GetType()) == types.ETString {
Comment thread
qw4990 marked this conversation as resolved.
Outdated
containsDateTime := false
for _, v := range []expression.Expression{expr, lexp, rexp} {
if v.GetType().EvalType() == types.ETDatetime {
containsDateTime = true
break
}
}
if containsDateTime {
expr = expression.WrapWithCastAsTime(er.ctx, expr, types.NewFieldType(mysql.TypeDatetime))
lexp = expression.WrapWithCastAsTime(er.ctx, lexp, types.NewFieldType(mysql.TypeDatetime))
rexp = expression.WrapWithCastAsTime(er.ctx, rexp, types.NewFieldType(mysql.TypeDatetime))
}
}

var op string
var l, r expression.Expression
l, er.err = er.newFunction(ast.GE, &v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-2])
l, er.err = er.newFunction(ast.GE, &v.Type, expr, lexp)
if er.err == nil {
r, er.err = er.newFunction(ast.LE, &v.Type, er.ctxStack[stkLen-3], er.ctxStack[stkLen-1])
r, er.err = er.newFunction(ast.LE, &v.Type, expr, rexp)
}
op = ast.LogicAnd
if er.err != nil {
Expand Down