Skip to content

Commit 05d38eb

Browse files
giordanoclaude
andauthored
Use Base.Semaphore to control test execution parallelism (#119)
* Use Base.Semaphore to control test execution parallelism Replace the fixed worker-task-per-slot model with a semaphore-based approach: one task per test, with a Base.Semaphore(jobs) limiting concurrency and a Channel-based worker pool for reuse. This decouples the number of tasks from the parallelism level and simplifies the control flow (no inner while loop, tests array is immutable). Co-authored-by: Claude <noreply@anthropic.com> Made-with: Cursor * Deal with 0 test jobs * Remove redundant variable * Avoid creating a `Set` * `@async` -> `Threads.@spawn` -> `@sync` * Move `@sync`ed for loop inside `try`/`catch` block * Require Malt v1.4.1 --------- Co-authored-by: Claude <noreply@anthropic.com>
1 parent 142a58e commit 05d38eb

2 files changed

Lines changed: 115 additions & 99 deletions

File tree

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1717
[compat]
1818
Dates = "1"
1919
IOCapture = "0.2.5, 1"
20-
Malt = "1.4.0"
20+
Malt = "1.4.1"
2121
Printf = "1"
2222
Random = "1"
2323
Scratch = "1.3.0"

src/ParallelTestRunner.jl

Lines changed: 114 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,11 @@ function runtests(mod::Module, args::ParsedArgs;
828828
jobs = clamp(jobs, 1, length(tests))
829829
println(stdout, "Running $(length(tests)) tests using $jobs parallel jobs. If this is too many concurrent jobs, specify the `--jobs=N` argument to the tests, or set the `JULIA_CPU_THREADS` environment variable.")
830830
!isnothing(args.verbose) && println(stdout, "Available memory: $(Base.format_bytes(available_memory()))")
831-
workers = fill(nothing, jobs)
831+
sem = Base.Semaphore(max(1, jobs))
832+
worker_pool = Channel{Union{Nothing, PTRWorker}}(jobs)
833+
for _ in 1:jobs
834+
put!(worker_pool, nothing)
835+
end
832836

833837
t0 = time()
834838
results = []
@@ -890,7 +894,7 @@ function runtests(mod::Module, args::ParsedArgs;
890894
# only draw if we have something to show
891895
isempty(running_tests) && return
892896
completed = Base.@lock results_lock length(results)
893-
total = completed + length(tests) + length(running_tests)
897+
total = length(tests)
894898

895899
# line 1: empty line
896900
line1 = ""
@@ -925,6 +929,9 @@ function runtests(mod::Module, args::ParsedArgs;
925929
end
926930
## yet-to-run
927931
for test in tests
932+
haskey(running_tests, test) && continue
933+
# Test is in any completed test
934+
any(r -> test == r.test, results) && continue
928935
est_remaining += get(historical_durations, test, est_per_test)
929936
end
930937

@@ -1007,130 +1014,131 @@ function runtests(mod::Module, args::ParsedArgs;
10071014
end
10081015
isa(ex, InterruptException) || rethrow()
10091016
finally
1010-
if isempty(tests) && isempty(running_tests)
1017+
if isempty(running_tests) && length(results) >= length(tests)
10111018
# XXX: only erase the status if we completed successfully.
10121019
# in other cases we'll have printed "caught interrupt"
10131020
clear_status()
10141021
end
10151022
end
10161023
end
10171024

1018-
10191025
#
10201026
# execution
10211027
#
10221028

1023-
for p in workers
1024-
push!(worker_tasks, @async begin
1025-
while !done
1026-
# get a test to run
1027-
test, test_t0 = Base.@lock test_lock begin
1028-
isempty(tests) && break
1029-
test = popfirst!(tests)
1030-
1031-
test_t0 = time()
1032-
running_tests[test] = test_t0
1029+
tests_to_start = Threads.Atomic{Int}(length(tests))
1030+
try
1031+
@sync for test in tests
1032+
push!(worker_tasks, Threads.@spawn begin
1033+
local p = nothing
1034+
acquired = false
1035+
try
1036+
Base.acquire(sem)
1037+
acquired = true
1038+
p = take!(worker_pool)
1039+
Threads.atomic_sub!(tests_to_start, 1)
1040+
1041+
done && return
1042+
1043+
test_t0 = Base.@lock test_lock begin
1044+
test_t0 = time()
1045+
running_tests[test] = test_t0
1046+
end
10331047

1034-
test, test_t0
1035-
end
1048+
# pass in init_worker_code to custom worker function if defined
1049+
wrkr = if init_worker_code == :()
1050+
test_worker(test)
1051+
else
1052+
test_worker(test, init_worker_code)
1053+
end
1054+
if wrkr === nothing
1055+
wrkr = p
1056+
end
1057+
# if a worker failed, spawn a new one
1058+
if wrkr === nothing || !Malt.isrunning(wrkr)
1059+
wrkr = p = addworker(; init_worker_code, io_ctx.color)
1060+
end
10361061

1037-
# pass in init_worker_code to custom worker function if defined
1038-
wrkr = if init_worker_code == :()
1039-
test_worker(test)
1040-
else
1041-
test_worker(test, init_worker_code)
1042-
end
1043-
if wrkr === nothing
1044-
wrkr = p
1045-
end
1046-
# if a worker failed, spawn a new one
1047-
if wrkr === nothing || !Malt.isrunning(wrkr)
1048-
wrkr = p = addworker(; init_worker_code, io_ctx.color)
1049-
end
1062+
# run the test
1063+
put!(printer_channel, (:started, test, worker_id(wrkr)))
1064+
result = try
1065+
Malt.remote_eval_wait(Main, wrkr.w, :(import ParallelTestRunner))
1066+
Malt.remote_call_fetch(invokelatest, wrkr.w, runtest,
1067+
testsuite[test], test, init_code, test_t0)
1068+
catch ex
1069+
if isa(ex, InterruptException)
1070+
# the worker got interrupted, signal other tasks to stop
1071+
stop_work()
1072+
return
1073+
end
10501074

1051-
# run the test
1052-
put!(printer_channel, (:started, test, worker_id(wrkr)))
1053-
result = try
1054-
Malt.remote_eval_wait(Main, wrkr.w, :(import ParallelTestRunner))
1055-
Malt.remote_call_fetch(invokelatest, wrkr.w, runtest,
1056-
testsuite[test], test, init_code, test_t0)
1057-
catch ex
1058-
if isa(ex, InterruptException)
1059-
# the worker got interrupted, signal other tasks to stop
1060-
stop_work()
1061-
break
1075+
ex
10621076
end
1077+
test_t1 = time()
1078+
output = Base.@lock wrkr.io_lock String(take!(wrkr.io))
1079+
Base.@lock results_lock push!(results, (; test, result, output, test_t0, test_t1))
1080+
1081+
# act on the results
1082+
if result isa AbstractTestRecord
1083+
put!(printer_channel, (:finished, test, worker_id(wrkr), result))
1084+
if anynonpass(result[]) && args.quickfail !== nothing
1085+
stop_work()
1086+
return
1087+
end
10631088

1064-
ex
1065-
end
1066-
test_t1 = time()
1067-
output = Base.@lock wrkr.io_lock String(take!(wrkr.io))
1068-
Base.@lock results_lock push!(results, (test, result, output, test_t0, test_t1))
1089+
if memory_usage(result) > max_worker_rss
1090+
# the worker has reached the max-rss limit, recycle it
1091+
# so future tests start with a smaller working set
1092+
Malt.stop(wrkr)
1093+
end
1094+
else
1095+
# One of Malt.TerminatedWorkerException, Malt.RemoteException, or ErrorException
1096+
@assert result isa Exception
1097+
put!(printer_channel, (:crashed, test, worker_id(wrkr)))
1098+
if args.quickfail !== nothing
1099+
stop_work()
1100+
return
1101+
end
10691102

1070-
# act on the results
1071-
if result isa AbstractTestRecord
1072-
put!(printer_channel, (:finished, test, worker_id(wrkr), result))
1073-
if anynonpass(result[]) && args.quickfail !== nothing
1074-
stop_work()
1075-
break
1103+
# the worker encountered some serious failure, recycle it
1104+
Malt.stop(wrkr)
10761105
end
10771106

1078-
if memory_usage(result) > max_worker_rss
1079-
# the worker has reached the max-rss limit, recycle it
1080-
# so future tests start with a smaller working set
1107+
# get rid of the custom worker
1108+
if wrkr != p
10811109
Malt.stop(wrkr)
10821110
end
1083-
else
1084-
# One of Malt.TerminatedWorkerException, Malt.RemoteException, or ErrorException
1085-
@assert result isa Exception
1086-
put!(printer_channel, (:crashed, test, worker_id(wrkr)))
1087-
if args.quickfail !== nothing
1088-
stop_work()
1089-
break
1090-
end
1091-
1092-
# the worker encountered some serious failure, recycle it
1093-
Malt.stop(wrkr)
1094-
end
10951111

1096-
# get rid of the custom worker
1097-
if wrkr != p
1098-
Malt.stop(wrkr)
1099-
end
1100-
1101-
Base.@lock test_lock begin
1102-
delete!(running_tests, test)
1112+
Base.@lock test_lock begin
1113+
delete!(running_tests, test)
1114+
end
1115+
catch ex
1116+
isa(ex, InterruptException) || rethrow()
1117+
finally
1118+
if acquired
1119+
# stop the worker if no more tests will need one from the pool
1120+
if tests_to_start[] == 0 && p !== nothing && Malt.isrunning(p)
1121+
Malt.stop(p)
1122+
p = nothing
1123+
end
1124+
put!(worker_pool, p)
1125+
Base.release(sem)
1126+
end
11031127
end
1104-
end
1105-
if p !== nothing
1106-
Malt.stop(p)
1107-
end
1108-
end)
1109-
end
1110-
1111-
1112-
#
1113-
# finalization
1114-
#
1115-
1116-
# monitor worker tasks for failure so that each one doesn't need a try/catch + stop_work()
1117-
try
1118-
while true
1119-
if any(istaskfailed, worker_tasks)
1120-
println(io_ctx.stderr, "\nCaught an error, stopping...")
1121-
break
1122-
elseif done || Base.@lock(test_lock, isempty(tests) && isempty(running_tests))
1123-
break
1124-
end
1125-
sleep(1)
1128+
end)
11261129
end
11271130
catch err
1128-
# in case the sleep got interrupted
1129-
isa(err, InterruptException) || rethrow()
1131+
if !(err isa InterruptException)
1132+
println(io_ctx.stderr, "\nCaught an error, stopping...")
1133+
end
11301134
finally
11311135
stop_work()
11321136
end
11331137

1138+
#
1139+
# finalization
1140+
#
1141+
11341142
# wait for the printer to finish so that all results have been printed
11351143
close(printer_channel)
11361144
wait(printer_task)
@@ -1149,6 +1157,14 @@ function runtests(mod::Module, args::ParsedArgs;
11491157
end
11501158
end
11511159

1160+
# clean up remaining workers in the pool
1161+
close(worker_pool)
1162+
for p in worker_pool
1163+
if p !== nothing && Malt.isrunning(p)
1164+
Malt.stop(p)
1165+
end
1166+
end
1167+
11521168
# print the output generated by each testset
11531169
for (testname, result, output, _start, _stop) in results
11541170
if !isempty(output)
@@ -1230,7 +1246,7 @@ function runtests(mod::Module, args::ParsedArgs;
12301246
end
12311247

12321248
# mark remaining or running tests as interrupted
1233-
for test in [tests; collect(keys(running_tests))]
1249+
for test in tests
12341250
(test in completed_tests) && continue
12351251
testset = create_testset(test)
12361252
Test.record(testset, Test.Error(:test_interrupted, test, nothing, Base.ExceptionStack(NamedTuple[(;exception = "skipped", backtrace = [])]), LineNumberNode(1)))

0 commit comments

Comments
 (0)