Skip to content

Commit d44786e

Browse files
committed
68
1 parent 64459be commit d44786e

File tree

1 file changed

+88
-5
lines changed

1 file changed

+88
-5
lines changed

HumanEvalLean/HumanEval68.lean

Lines changed: 88 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,90 @@
11
module
22

3-
def pluck : Unit :=
4-
()
3+
/-! ## Missing API -/
4+
5+
-- The following two instances should not be upstreamed verbatim. Instead, `Bool` should be
6+
-- a genuine linear order.
7+
instance : Std.LawfulOrderOrd Bool where
8+
isLE_compare a b := by cases a <;> cases b <;> simp [compare, LE.le]
9+
isGE_compare a b := by cases a <;> cases b <;> simp [compare, LE.le]
10+
instance : Std.LawfulOrderLT Bool where
11+
lt_iff a b := by cases a <;> cases b <;> simp [LT.lt, LE.le]
12+
13+
-- The lexicographic order on pairs is intentionally not a global instance.
14+
-- However, there should be a `IsLinearOrder` instance for it (etc.)
15+
local instance : Ord (Bool × Nat) := lexOrd
16+
local instance : LE (Bool × Nat) := leOfOrd
17+
local instance : LT (Bool × Nat) := ltOfOrd
18+
local instance : Std.IsLinearOrder (Bool × Nat) := .of_ord
19+
20+
@[simp, grind .]
21+
theorem Bool.not_lt_false {x : Bool} :
22+
¬ x < false := by
23+
simp [LT.lt]
24+
25+
theorem Bool.compare_eq_lt {x y : Bool} :
26+
compare x y = .lt ↔ x = false ∧ y = true := by
27+
simp only [compare]
28+
split <;> grind
29+
30+
theorem compare_lexOrd_eq [Ord α] [Ord β] {x y : α × β} :
31+
haveI : Ord (α × β) := lexOrd
32+
compare x y = (compare x.1 y.1).then (compare x.2 y.2) := by
33+
simp [compare, compareLex, compareOn]
34+
35+
/-! ## Implementation -/
36+
37+
-- (Returning an `Option (Nat × Nat))` would be better, but we will stick to the problem description.)
38+
def pluck (xs : Array Nat) : List Nat :=
39+
let i? := xs.toList.minIdxOn? (fun x => (x % 2 == 1, x))
40+
match h : i? with
41+
| some i =>
42+
have h : i < xs.size := by
43+
simp only [i?, Option.eq_some_iff_get_eq, List.get_minIdxOn?_eq_minIdxOn] at h
44+
simp only [← h.choose_spec]
45+
apply List.minIdxOn_lt_length
46+
let x := xs[i]'h
47+
if x % 2 = 0 then
48+
[x, i]
49+
else
50+
[]
51+
| none => []
52+
53+
/-! ## Tests -/
54+
55+
example : pluck #[4, 2, 3] = [2, 1] := by native_decide
56+
example : pluck #[1, 2, 3] = [2, 1] := by native_decide
57+
example : pluck #[] = [] := by native_decide
58+
example : pluck #[5, 0, 3, 0, 4, 2] = [0, 1] := by native_decide
59+
example : pluck #[1, 2, 3, 0, 5, 3] = [0, 3] := by native_decide
60+
example : pluck #[5, 4, 8, 4, 8] = [4, 1] := by native_decide
61+
example : pluck #[7, 6, 7, 1] = [6, 1] := by native_decide
62+
example : pluck #[7, 9, 7, 1] = [] := by native_decide
63+
64+
/-! ## Verification -/
65+
66+
theorem pluck_empty :
67+
pluck #[] = [] := by
68+
grind [pluck, List.minIdxOn?_nil]
69+
70+
theorem pluck_eq_empty (h : ∀ (i : Nat) (hi : i < xs.size), xs[i] % 2 = 1) :
71+
pluck xs = [] := by
72+
grind [pluck]
73+
74+
theorem pluck_eq_pair {i : Nat} (hi : i < xs.size) (h_even : xs[i] % 2 = 0)
75+
(h : ∀ (j : Nat) (hj : j < xs.size), xs[j] % 2 = 1 ∨ xs[i] ≤ xs[j])
76+
(h' : ∀ (j : Nat) (hj : j < i), xs[j] % 2 = 1 ∨ xs[i] < xs[j]) :
77+
pluck xs = [xs[i], i] := by
78+
rw [pluck]
79+
split
80+
· rename_i j heq
81+
simp only [Option.eq_some_iff_get_eq, List.get_minIdxOn?_eq_minIdxOn] at heq
82+
conv at heq => congr; ext; rw [List.minIdxOn_eq_iff (by grind)]
83+
simp only [LE.le, compare_lexOrd_eq, Ordering.isLE_then_iff_and,
84+
Std.isLE_compare (α := Bool)] at heq
85+
suffices i = j by grind
86+
apply Nat.le_antisymm <;> grind [Std.compare_eq_lt, Std.isLE_compare]
87+
· grind [List.minIdxOn?_eq_if, Array.toList_eq_nil_iff]
588

689
/-!
790
## Prompt
@@ -26,12 +109,12 @@ def pluck(arr):
26109
Example 2:
27110
Input: [1,2,3]
28111
Output: [2, 1]
29-
Explanation: 2 has the smallest even value, and 2 has the smallest index.
112+
Explanation: 2 has the smallest even value, and 2 has the smallest index.
30113
31114
Example 3:
32115
Input: []
33116
Output: []
34-
117+
35118
Example 4:
36119
Input: [5, 0, 3, 0, 4, 2]
37120
Output: [0, 1]
@@ -73,4 +156,4 @@ def check(candidate):
73156
assert candidate([7, 9, 7, 1]) == [], "Error"
74157
75158
```
76-
-/
159+
-/

0 commit comments

Comments
 (0)