Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
authors = ["Patrick Altmeyer <[email protected]>"]
name = "TaijaParallel"
uuid = "bf1c2c22-5e42-4e78-8b6b-92e6c673eeb0"
version = "1.1.3"
version = "1.2.0"

[compat]
Aqua = "0.8"
Expand Down
1 change: 1 addition & 0 deletions ext/MPIExt/MPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using MPI
using ProgressMeter
using TaijaBase
using TaijaParallel
using TaijaParallel: load_with_retry

"The `MPIParallelizer` type is used to parallelize the evaluation of a function using `MPI.jl`."
struct MPIParallelizer <: TaijaParallel.AbstractParallelizer
Expand Down
2 changes: 1 addition & 1 deletion ext/MPIExt/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function TaijaBase.parallelize(
if parallelizer.rank == 0
outputs = []
for i = 1:length(chunks)
output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
output = load_with_retry(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
# Collect output from all processes in rank 0:
Expand Down
2 changes: 1 addition & 1 deletion ext/MPIExt/generate_counterfactual.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ function TaijaBase.parallelize(
if parallelizer.rank == 0
outputs = []
for i = 1:length(chunks)
output = Serialization.deserialize(joinpath(storage_path, "output_$i.jls"))
output = load_with_retry(joinpath(storage_path, "output_$i.jls"))
push!(outputs, output)
end
# Collect output from all processes in rank 0:
Expand Down
22 changes: 22 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Serialization

"""
chunk_obs(obs::AbstractVector, n_each::Integer, n_groups::Integer)

Expand Down Expand Up @@ -38,3 +40,23 @@ function split_obs(obs::AbstractVector, n::Integer)
N_counts = split_count(N, n)
return split_by_counts(obs, N_counts)
end

"""
    load_with_retry(filepath; max_attempts=5, delay=1.0)

Load a file using `Serialization.deserialize`, retrying up to `max_attempts` times
with exponential backoff: the wait before retry `k` is `delay * 2^(k - 1)` seconds.

Retries only on `EOFError`, which here signals a partially written file (e.g. another
MPI rank has not finished serializing its output yet — TODO confirm against callers
in `ext/MPIExt`). Any other exception, or an `EOFError` on the final attempt, is
rethrown immediately.
"""
function load_with_retry(filepath; max_attempts = 5, delay = 1.0)
    for attempt in 1:max_attempts
        try
            return Serialization.deserialize(filepath)
        catch e
            if isa(e, EOFError) && attempt < max_attempts
                # Exponential backoff: 1x, 2x, 4x, ... the base delay.
                # (The original slept `delay * attempt`, i.e. linear backoff,
                # contradicting the documented contract.)
                sleep(delay * 2^(attempt - 1))
                continue
            end
            # Not a partial-write error, or attempts exhausted: propagate.
            # Bare `rethrow()` preserves the original backtrace, unlike `rethrow(e)`.
            rethrow()
        end
    end
    # Unreachable: the final attempt either returns or rethrows. Kept as a
    # defensive safeguard so the function can never fall through silently.
    error("Failed to load $filepath after $max_attempts attempts")
end
Loading