Skip to content

Commit 5e372e5

Browse files
authored
fix(compiler): Correct exhaustive check on record adt with when pattern (#2355)
1 parent 8a8a906 commit 5e372e5

File tree

10 files changed

+357
-50
lines changed

10 files changed

+357
-50
lines changed

compiler/src/middle_end/linearize.re

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1496,7 +1496,15 @@ and transl_comp_expression =
14961496
| TExpLambda(
14971497
[
14981498
{
1499-
mb_pat: {pat_desc: TPatConstruct({txt: ident}, _, [])},
1499+
mb_pat:
1500+
{
1501+
pat_desc:
1502+
TPatConstruct(
1503+
{txt: ident},
1504+
_,
1505+
TPatConstrSingleton | TPatConstrTuple([]),
1506+
),
1507+
},
15001508
mb_body: body,
15011509
},
15021510
],

compiler/src/middle_end/matchcomp.re

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,15 @@ let rec specialize_matrix = (cd, cur, mtx) => {
387387
| TPatVar(id, _) =>
388388
let wildcards = Parmatch.omegas(arity);
389389
[(wildcards @ ptl, [(id, cur), ...binds])];
390-
| TPatConstruct(_, pcd, args) when cd == pcd => [(args @ ptl, binds)]
390+
| TPatConstruct(_, pcd, TPatConstrRecord(arg)) when cd == pcd => [
391+
([arg] @ ptl, binds),
392+
]
393+
| TPatConstruct(_, pcd, TPatConstrTuple(args)) when cd == pcd => [
394+
(args @ ptl, binds),
395+
]
396+
| TPatConstruct(_, pcd, TPatConstrSingleton) when cd == pcd => [
397+
(ptl, binds),
398+
]
391399
| TPatOr(p1, p2) =>
392400
specialized_rows([p1, ...ptl], binds)
393401
@ specialized_rows([p2, ...ptl], binds)
@@ -1278,7 +1286,19 @@ module MatchTreeCompiler = {
12781286
| TPatAny
12791287
| TPatConstant(_) => []
12801288
| TPatVar(id, _) => [bind(id)]
1281-
| TPatConstruct(_, _, pats)
1289+
| TPatConstruct(
1290+
_,
1291+
_,
1292+
TPatConstrRecord({pat_desc: TPatRecord(_, _) | TPatAny} as pat),
1293+
) =>
1294+
collect_bindings(pat)
1295+
| TPatConstruct(_, _, TPatConstrRecord(_)) =>
1296+
failwith(
1297+
"Impossible: Invalid record constructor pattern `collect_bindings`",
1298+
)
1299+
| TPatConstruct(_, _, TPatConstrTuple(pats)) =>
1300+
List.flatten @@ List.map(collect_bindings, pats)
1301+
| TPatConstruct(_, _, TPatConstrSingleton) => []
12821302
| TPatTuple(pats)
12831303
| TPatArray(pats) => List.flatten @@ List.map(collect_bindings, pats)
12841304
| TPatRecord(pats, _) =>

compiler/src/typed/parmatch.re

Lines changed: 173 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -357,8 +357,34 @@ module Compat =
357357
| (TPatOr(p1, p2), _) => compat(p1, q) || compat(p2, q)
358358
| (_, TPatOr(q1, q2)) => compat(p, q1) || compat(p, q2)
359359
/* Constructors, with special case for extension */
360-
| (TPatConstruct(_, c1, ps1), TPatConstruct(_, c2, ps2)) =>
360+
| (
361+
TPatConstruct(
362+
_,
363+
c1,
364+
TPatConstrRecord({pat_desc: TPatRecord(_, _) | TPatAny} as ps1),
365+
),
366+
TPatConstruct(
367+
_,
368+
c2,
369+
TPatConstrRecord({pat_desc: TPatRecord(_, _) | TPatAny} as ps2),
370+
),
371+
) =>
372+
Constr.equal(c1, c2) && compat(ps1, ps2)
373+
| (
374+
TPatConstruct(_, c1, TPatConstrRecord(_)),
375+
TPatConstruct(_, c2, TPatConstrRecord(_)),
376+
) =>
377+
failwith("Impossible: Invalid record constructor pattern `compat`")
378+
| (
379+
TPatConstruct(_, c1, TPatConstrTuple(ps1)),
380+
TPatConstruct(_, c2, TPatConstrTuple(ps2)),
381+
) =>
361382
Constr.equal(c1, c2) && compats(ps1, ps2)
383+
| (
384+
TPatConstruct(_, c1, TPatConstrSingleton),
385+
TPatConstruct(_, c2, TPatConstrSingleton),
386+
) =>
387+
Constr.equal(c1, c2)
362388
/* More standard stuff */
363389
| (TPatConstant(c1), TPatConstant(c2)) => const_compare(c1, c2) == 0
364390
| (TPatTuple(ps), TPatTuple(qs)) => compats(ps, qs)
@@ -456,14 +482,18 @@ let extract_fields = (fields, arg) => {
456482
let rec simple_match_args = (p1, p2) =>
457483
switch (p2.pat_desc) {
458484
| TPatAlias(p2, _, _) => simple_match_args(p1, p2)
459-
| TPatConstruct(_, _, args) => args
485+
| TPatConstruct(_, _, TPatConstrRecord(arg)) => [arg]
486+
| TPatConstruct(_, _, TPatConstrTuple(args)) => args
487+
| TPatConstruct(_, _, TPatConstrSingleton) => []
460488
| TPatRecord(args, _) => extract_fields(record_arg(p1), args)
461489
| TPatTuple(args) => args
462490
| TPatArray(args) => args
463491
| TPatAny
464492
| TPatVar(_) =>
465493
switch (p1.pat_desc) {
466-
| TPatConstruct(_, _, args) => omega_list(args)
494+
| TPatConstruct(_, _, TPatConstrRecord(arg)) => omega_list([arg])
495+
| TPatConstruct(_, _, TPatConstrTuple(args)) => omega_list(args)
496+
| TPatConstruct(_, _, TPatConstrSingleton) => []
467497
| TPatRecord(args, _) => omega_list(args)
468498
| TPatTuple(args) => omega_list(args)
469499
| TPatArray(args) => omega_list(args)
@@ -493,8 +523,24 @@ let rec normalize_pat = q =>
493523
q.pat_type,
494524
q.pat_env,
495525
)
496-
| TPatConstruct(lid, c, args) =>
497-
make_pat(TPatConstruct(lid, c, omega_list(args)), q.pat_type, q.pat_env)
526+
| TPatConstruct(lid, c, TPatConstrRecord(_)) =>
527+
make_pat(
528+
TPatConstruct(lid, c, TPatConstrRecord(omega)),
529+
q.pat_type,
530+
q.pat_env,
531+
)
532+
| TPatConstruct(lid, c, TPatConstrTuple(args)) =>
533+
make_pat(
534+
TPatConstruct(lid, c, TPatConstrTuple(omega_list(args))),
535+
q.pat_type,
536+
q.pat_env,
537+
)
538+
| TPatConstruct(lid, c, TPatConstrSingleton) =>
539+
make_pat(
540+
TPatConstruct(lid, c, TPatConstrSingleton),
541+
q.pat_type,
542+
q.pat_env,
543+
)
498544
| TPatOr(_) => fatal_error("Parmatch.normalize_pat")
499545
};
500546

@@ -599,9 +645,41 @@ let do_set_args = (erase_mutable, q, r) =>
599645
let args =
600646
List.map2(((lid, lbl, _), arg) => (lid, lbl, arg), omegas, args);
601647
[make_pat(TPatRecord(args, closed), q.pat_type, q.pat_env), ...rest];
602-
| {pat_desc: TPatConstruct(lid, c, omegas)} =>
648+
| {pat_desc: TPatConstruct(lid, c, TPatConstrRecord(omega))} =>
649+
let (arg, rest) =
650+
switch (read_args([omega], r)) {
651+
| ([arg], rest) => (arg, rest)
652+
| _ =>
653+
failwith(
654+
"Impossible: Invalid record constructor pattern `do_set_args`",
655+
)
656+
};
657+
[
658+
make_pat(
659+
TPatConstruct(lid, c, TPatConstrRecord(arg)),
660+
q.pat_type,
661+
q.pat_env,
662+
),
663+
...rest,
664+
];
665+
| {pat_desc: TPatConstruct(lid, c, TPatConstrTuple(omegas))} =>
603666
let (args, rest) = read_args(omegas, r);
604-
[make_pat(TPatConstruct(lid, c, args), q.pat_type, q.pat_env), ...rest];
667+
[
668+
make_pat(
669+
TPatConstruct(lid, c, TPatConstrTuple(args)),
670+
q.pat_type,
671+
q.pat_env,
672+
),
673+
...rest,
674+
];
675+
| {pat_desc: TPatConstruct(lid, c, TPatConstrSingleton)} => [
676+
make_pat(
677+
TPatConstruct(lid, c, TPatConstrSingleton),
678+
q.pat_type,
679+
q.pat_env,
680+
),
681+
...r,
682+
]
605683
| {pat_desc: TPatArray(omegas)} =>
606684
let (args, rest) = read_args(omegas, r);
607685
[make_pat(TPatArray(args), q.pat_type, q.pat_env), ...rest];
@@ -891,7 +969,15 @@ let pat_of_constr = (ex_pat, cstr) => {
891969
TPatConstruct(
892970
mknoloc(Identifier.IdentName(mknoloc("?pat_of_constr?"))),
893971
cstr,
894-
omegas(cstr.cstr_arity),
972+
switch (cstr.cstr_arity, cstr.cstr_inlined) {
973+
| (0, None) => TPatConstrSingleton
974+
| (_, None) => TPatConstrTuple(omegas(cstr.cstr_arity))
975+
| (1, Some(_)) => TPatConstrRecord(omega)
976+
| (_, Some(_)) =>
977+
failwith(
978+
"Impossible: Invalid record constructor pattern `pat_of_cnstr`",
979+
)
980+
},
895981
),
896982
};
897983

@@ -1200,7 +1286,9 @@ let rec has_instance = p =>
12001286
| TPatConstant(_) => true
12011287
| TPatAlias(p, _, _) => has_instance(p)
12021288
| TPatOr(p1, p2) => has_instance(p1) || has_instance(p2)
1203-
| TPatConstruct(_, _, ps)
1289+
| TPatConstruct(_, _, TPatConstrRecord(p)) => has_instance(p)
1290+
| TPatConstruct(_, _, TPatConstrTuple(ps)) => has_instances(ps)
1291+
| TPatConstruct(_, _, TPatConstrSingleton) => true
12041292
| TPatTuple(ps)
12051293
| TPatArray(ps) => has_instances(ps)
12061294
| TPatRecord(fields, _) =>
@@ -1897,8 +1985,21 @@ let rec le_pat = (p, q) =>
18971985
| (TPatAlias(p, _, _), _) => le_pat(p, q)
18981986
| (_, TPatAlias(q, _, _)) => le_pat(p, q)
18991987
| (TPatConstant(c1), TPatConstant(c2)) => const_compare(c1, c2) == 0
1900-
| (TPatConstruct(_, c1, ps), TPatConstruct(_, c2, qs)) =>
1988+
| (
1989+
TPatConstruct(_, c1, TPatConstrRecord(p)),
1990+
TPatConstruct(_, c2, TPatConstrRecord(q)),
1991+
) =>
1992+
Types.equal_tag(c1.cstr_tag, c2.cstr_tag) && le_pat(p, q)
1993+
| (
1994+
TPatConstruct(_, c1, TPatConstrTuple(ps)),
1995+
TPatConstruct(_, c2, TPatConstrTuple(qs)),
1996+
) =>
19011997
Types.equal_tag(c1.cstr_tag, c2.cstr_tag) && le_pats(ps, qs)
1998+
| (
1999+
TPatConstruct(_, c1, TPatConstrSingleton),
2000+
TPatConstruct(_, c2, TPatConstrSingleton),
2001+
) =>
2002+
Types.equal_tag(c1.cstr_tag, c2.cstr_tag)
19022003
| (TPatArray(ps), TPatArray(qs)) =>
19032004
List.length(ps) == List.length(qs) && le_pats(ps, qs)
19042005
| (TPatTuple(ps), TPatTuple(qs)) => le_pats(ps, qs)
@@ -1945,10 +2046,38 @@ let rec lub = (p, q) =>
19452046
| (TPatTuple(ps), TPatTuple(qs)) =>
19462047
let rs = lubs(ps, qs);
19472048
make_pat(TPatTuple(rs), p.pat_type, p.pat_env);
1948-
| (TPatConstruct(lid, c1, ps1), TPatConstruct(_, c2, ps2))
2049+
| (
2050+
TPatConstruct(lid, c1, TPatConstrRecord(p1)),
2051+
TPatConstruct(_, c2, TPatConstrRecord(p2)),
2052+
)
2053+
when Types.equal_tag(c1.cstr_tag, c2.cstr_tag) =>
2054+
let rs = lub(p1, p2);
2055+
make_pat(
2056+
TPatConstruct(lid, c1, TPatConstrRecord(rs)),
2057+
p.pat_type,
2058+
p.pat_env,
2059+
);
2060+
| (
2061+
TPatConstruct(lid, c1, TPatConstrTuple(ps1)),
2062+
TPatConstruct(_, c2, TPatConstrTuple(ps2)),
2063+
)
19492064
when Types.equal_tag(c1.cstr_tag, c2.cstr_tag) =>
19502065
let rs = lubs(ps1, ps2);
1951-
make_pat(TPatConstruct(lid, c1, rs), p.pat_type, p.pat_env);
2066+
make_pat(
2067+
TPatConstruct(lid, c1, TPatConstrTuple(rs)),
2068+
p.pat_type,
2069+
p.pat_env,
2070+
);
2071+
| (
2072+
TPatConstruct(lid, c1, TPatConstrSingleton),
2073+
TPatConstruct(_, c2, TPatConstrSingleton),
2074+
)
2075+
when Types.equal_tag(c1.cstr_tag, c2.cstr_tag) =>
2076+
make_pat(
2077+
TPatConstruct(lid, c1, TPatConstrSingleton),
2078+
p.pat_type,
2079+
p.pat_env,
2080+
)
19522081
| (TPatRecord(l1, closed), TPatRecord(l2, _)) =>
19532082
let rs = record_lubs(l1, l2);
19542083
make_pat(TPatRecord(rs, closed), p.pat_type, p.pat_env);
@@ -2141,7 +2270,7 @@ module Conv = {
21412270
c,
21422271
),
21432272
)
2144-
| TPatConstruct(cstr_lid, cstr, lst) =>
2273+
| TPatConstruct(cstr_lid, cstr, arg) =>
21452274
let id = fresh(cstr.cstr_name);
21462275
let lid = {
21472276
...cstr_lid,
@@ -2152,25 +2281,33 @@ module Conv = {
21522281
}),
21532282
};
21542283
Hashtbl.add(constrs, id, cstr);
2155-
switch (lst) {
2156-
| [{pat_desc: TPatRecord(fields, closed)}]
2157-
when cstr.cstr_inlined != None =>
2158-
mkpat(
2159-
~loc=pat.pat_loc,
2284+
mkpat(
2285+
~loc=pat.pat_loc,
2286+
switch (arg, cstr.cstr_inlined) {
2287+
| (
2288+
TPatConstrRecord({pat_desc: TPatRecord(fields, closed)}),
2289+
Some(_),
2290+
) =>
21602291
PPatConstruct(
21612292
lid,
21622293
PPatConstrRecord(
21632294
List.map(((id, _, p)) => (id, loop(p)), fields),
21642295
closed,
21652296
),
2166-
),
2167-
)
2168-
| _ =>
2169-
mkpat(
2170-
~loc=pat.pat_loc,
2171-
PPatConstruct(lid, PPatConstrTuple(List.map(loop, lst))),
2172-
)
2173-
};
2297+
)
2298+
| (TPatConstrRecord({pat_desc: TPatAny}), Some(_)) =>
2299+
// TODO: Validate
2300+
PPatConstruct(lid, PPatConstrRecord([], Open))
2301+
| (TPatConstrTuple(args), None) =>
2302+
PPatConstruct(lid, PPatConstrTuple(List.map(loop, args)))
2303+
| (TPatConstrSingleton, None) =>
2304+
PPatConstruct(lid, PPatConstrSingleton)
2305+
| (_, Some(_)) =>
2306+
failwith("Impossible: Invalid record constructor pattern `conv`")
2307+
| (TPatConstrRecord(_), None) =>
2308+
failwith("Impossible: invalid constructor pattern")
2309+
},
2310+
);
21742311
};
21752312

21762313
let ps = loop(typed);
@@ -2300,7 +2437,11 @@ let rec collect_paths_from_pat = (r, p) =>
23002437
} else {
23012438
r;
23022439
},
2303-
ps,
2440+
switch (ps) {
2441+
| TPatConstrSingleton => []
2442+
| TPatConstrTuple(ps) => ps
2443+
| TPatConstrRecord(p) => [p]
2444+
},
23042445
);
23052446
| TPatAny
23062447
| TPatVar(_)
@@ -2489,8 +2630,11 @@ let inactive = (~partial, pat) =>
24892630
| Const_bigint(_)
24902631
| Const_rational(_) => true
24912632
}
2492-
| TPatTuple(ps)
2493-
| TPatConstruct(_, _, ps) => List.for_all(p => loop(p), ps)
2633+
| TPatTuple(ps) => List.for_all(p => loop(p), ps)
2634+
| TPatConstruct(_, _, TPatConstrRecord(p)) => loop(p)
2635+
| TPatConstruct(_, _, TPatConstrTuple(ps)) =>
2636+
List.for_all(p => loop(p), ps)
2637+
| TPatConstruct(_, _, TPatConstrSingleton) => true
24942638
| TPatRecord(fields, _) =>
24952639
List.for_all(((_, _, p)) => loop(p), fields)
24962640
| TPatAlias(p, _, _) => loop(p)

0 commit comments

Comments
 (0)