Skip to content

Commit 4c7c3e9

Browse files
Base: add record_calls / invalidate_calls for compilation benchmarking
Add `Base.record_calls`, `Base.@record_calls`, and `Base.invalidate_calls` for scoped invalidation of a call's cached native code and inferred IR. These are intended for statistical benchmarking of compilation time (e.g. `@benchmark foo($x) compilation=true` in BenchmarkTools): each sample can invalidate just the target `MethodInstance` and force a fresh codegen on the next dispatch, without bumping the global world counter or propagating to backedges. The C helper `invalidate_method_instance_caches` is renamed to `jl_method_instance_invalidate_caches` and exported via julia.h so the Julia-side wrapper can `ccall` it directly. Co-Authored-By: Claude <[email protected]>
1 parent cacbbe8 commit 4c7c3e9

4 files changed

Lines changed: 92 additions & 4 deletions

File tree

base/reflection.jl

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,63 @@ function method_instance(@nospecialize(f), @nospecialize(t);
8787
return method_instance(tt; world, method_table)
8888
end
8989

"""
    Base.record_calls(f, types::Type{<:Tuple}) -> Vector{Core.MethodInstance}

Collect the `MethodInstance`s whose caches have to be invalidated so that the
call `f(args...)`, with `types === Tuple{map(Core.Typeof, args)...}`, is
compiled again.

The result contains only the entry-point `MethodInstance`. The next dispatch
to it re-runs inference and codegen for `f` together with every callee that
gets inlined into it; callees that are not inlined stay cached. Invalidating
the returned instances therefore measures the cost of compiling `f` itself
(including its inlining decisions), not of rebuilding the whole transitive
call tree.

Throws an error if no `MethodInstance` can be resolved for `f` and `types`.

See also [`Base.invalidate_calls`](@ref), [`Base.@record_calls`](@ref).
"""
function record_calls(@nospecialize(f), @nospecialize(types::Type{<:Tuple}))
    entry = method_instance(f, types)
    if entry === nothing
        error("could not resolve MethodInstance for ", f, " with argument types ", types)
    end
    # A one-element vector: only the entry point is recorded.
    return Core.MethodInstance[entry]
end
109+
"""
    Base.@record_calls f(args...)

Convenience macro equivalent to `Base.record_calls(f, Tuple{map(Core.Typeof, args)...})`.

The argument expressions are evaluated in the caller's scope to obtain their
types. Splatted positional arguments (`@record_calls f(xs...)`) are expanded
before the types are taken, so each splatted element contributes its own type.

Keyword arguments are not supported: `record_calls` takes only a positional
argument-type tuple, so a call written with keyword arguments is rejected with
an error at macro-expansion time.
"""
macro record_calls(ex)
    Meta.isexpr(ex, :call) || error("@record_calls expects a call expression, got $(ex)")
    f = ex.args[1]
    args = ex.args[2:end]
    for a in args
        # `f(x; k=1)` yields a `:parameters` node, `f(x, k=1)` a `:kw` node.
        if Meta.isexpr(a, :parameters) || Meta.isexpr(a, :kw)
            error("@record_calls does not support keyword arguments, got $(ex)")
        end
    end
    # Collect the arguments into a tuple first so that splats such as `xs...`
    # are expanded before `Core.Typeof` is applied to each element.
    argtuple = Expr(:tuple, map(esc, args)...)
    types = :(Tuple{map(Core.Typeof, $argtuple)...})
    return :(record_calls($(esc(f)), $types))
end
122+
"""
    Base.invalidate_calls(mis)

Invalidate the cached native code and inferred IR for each `MethodInstance`
in `mis`, forcing a recompile the next time each is dispatched. This is a
pure cache flush: it does not bump the global world counter and does not
propagate to backedges, so other code compiled against `mis` is unaffected.

`mis` is any iterable of `Core.MethodInstance`s, such as the vector returned
by [`Base.record_calls`](@ref). Every element is checked before any cache is
touched: if an element is not a `Core.MethodInstance`, an `ArgumentError` is
thrown and nothing is invalidated.

Returns `nothing`.

Intended for statistical benchmarking of compilation time. See also
[`Base.record_calls`](@ref).
"""
function invalidate_calls(mis)
    # Validate (and materialize) the whole input up front so that a bad
    # element cannot leave the caches half-invalidated.
    targets = Core.MethodInstance[]
    for mi in mis
        mi isa Core.MethodInstance ||
            throw(ArgumentError(string("expected Core.MethodInstance, got ", typeof(mi))))
        push!(targets, mi)
    end
    world = get_world_counter()
    # Cap cache validity *below* the current world so dispatch at this world
    # re-enters codegen. We do NOT bump the global world counter (that would
    # force everything else to reinfer too).
    cap = world == 0 ? UInt(0) : world - UInt(1)
    for mi in targets
        ccall(:jl_method_instance_invalidate_caches, Cvoid, (Any, Csize_t), mi, cap)
    end
    return nothing
end
146+
90147
default_debug_info_kind() = unsafe_load(cglobal(:jl_default_debug_info_kind, Cint))
91148

92149
# this type mirrors jl_cgparams_t (documented in julia.h)

src/gf.c

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5268,8 +5268,11 @@ JL_DLLEXPORT void jl_extern_c(jl_value_t *name, jl_value_t *declrt, jl_tupletype
52685268
JL_GC_POP();
52695269
}
52705270

5271-
// Drop all method caches and increment world age as if adding a method that intersects everything
5272-
static void invalidate_method_instance_caches(jl_method_instance_t *mi, size_t world)
5271+
// Drop cached native code and inferred IR for a single MethodInstance by
5272+
// setting `max_world` on each live CodeInstance in its cache chain. This does
5273+
// NOT bump the global world counter and does NOT propagate to backedges,
5274+
// making it a pure cache flush suitable for compilation benchmarking.
5275+
JL_DLLEXPORT void jl_method_instance_invalidate_caches(jl_method_instance_t *mi, size_t world)
52735276
{
52745277
if ((jl_value_t*)mi == jl_nothing)
52755278
return;
@@ -5295,12 +5298,12 @@ static int invalidate_all_specializations(jl_typemap_entry_t *def, void *closure
52955298
size_t i, l = jl_svec_len(specializations);
52965299
for (i = 0; i < l; i++) {
52975300
jl_method_instance_t *mi = (jl_method_instance_t*)jl_svecref(specializations, i);
5298-
invalidate_method_instance_caches(mi, world);
5301+
jl_method_instance_invalidate_caches(mi, world);
52995302
}
53005303
}
53015304
else if (specializations != NULL) {
53025305
jl_method_instance_t *mi = (jl_method_instance_t*)specializations;
5303-
invalidate_method_instance_caches(mi, world);
5306+
jl_method_instance_invalidate_caches(mi, world);
53045307
}
53055308
JL_UNLOCK(&method->writelock);
53065309
return 1;

src/julia.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1932,6 +1932,7 @@ JL_DLLEXPORT jl_code_info_t *jl_copy_code_info(jl_code_info_t *src);
19321932
JL_DLLEXPORT size_t jl_get_world_counter(void) JL_NOTSAFEPOINT;
19331933
JL_DLLEXPORT size_t jl_get_tls_world_age(void) JL_NOTSAFEPOINT;
19341934
JL_DLLEXPORT void jl_drop_all_caches(void);
1935+
JL_DLLEXPORT void jl_method_instance_invalidate_caches(jl_method_instance_t *mi, size_t world);
19351936
JL_DLLEXPORT jl_value_t *jl_box_bool(int8_t x) JL_NOTSAFEPOINT;
19361937
JL_DLLEXPORT jl_value_t *jl_box_int8(int8_t x) JL_NOTSAFEPOINT;
19371938
JL_DLLEXPORT jl_value_t *jl_box_uint8(uint8_t x) JL_NOTSAFEPOINT;

test/worlds.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,3 +627,30 @@ let
627627
@test rettype_side_effect == "blah"
628628
ccall(:strlen, rettype_with_side_effect(), (Cstring,), "xx")
629629
end
630+
631+
# Base.record_calls / Base.invalidate_calls: scoped cache invalidation
632+
# used by BenchmarkTools-style compilation benchmarks. Must drop cached native
633+
# code without bumping world age or touching backedges.
634+
module CompileRecordTest
635+
using Test
636+
# Small target function whose cached code is recorded and then invalidated below.
recorded_target(x) = x * x + 1
637+
recorded_target(1.0) # force compile
638+
# Record the call; exactly one MethodInstance (the entry point) is expected.
let mis = Base.@record_calls recorded_target(1.0)
639+
@test length(mis) == 1
640+
mi = mis[1]
641+
@test mi isa Core.MethodInstance
642+
# Head of the MethodInstance's CodeInstance cache chain before invalidation.
ci = @atomic mi.cache
643+
@test ci !== nothing
644+
# Before invalidation the cached CodeInstance has an uncapped `max_world`.
@test (@atomic ci.max_world) == typemax(UInt)
645+
world_before = Base.get_world_counter()
646+
Base.invalidate_calls(mis)
647+
# World counter must not move (this is the key distinction from ordinary invalidation)
@test Base.get_world_counter() == world_before
648+
# Cached CodeInstance is now capped
@test (@atomic ci.max_world) < typemax(UInt)
649+
# Next call re-specializes; a fresh CodeInstance appears on the cache chain
recorded_target(1.0)
650+
ci2 = @atomic mi.cache
651+
# NOTE(review): `ci` was asserted to be capped just above, so if `ci2 === ci`
# the right-hand clause is false as well; this check effectively requires a
# different CodeInstance at the head of the cache.
@test ci2 !== ci || (@atomic ci2.max_world) == typemax(UInt)
652+
end
653+
end

0 commit comments

Comments
 (0)