Skip to content

Commit c25b402

Browse files
author
ceciliav
committed
Add combine trajectories script
1 parent 1866ea9 commit c25b402

2 files changed

Lines changed: 161 additions & 34 deletions

File tree

workflow/rules/phylodynamics.smk

Lines changed: 49 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ checkpoint beast:
1919
folder_name = "results/analysis/{analysis}/chains",
2020
file_name = "{dataset}{sufix,.*}.{chain}.r{i}",
2121
state_file = "{dataset}{sufix,.*}.{chain}.state",
22-
previous_file_name = lambda wildcards: wildcards.dataset + wildcards.sufix + "." + wildcards.chain + ".r" + str(int(wildcards.i) - 1) if wildcards.i != 0 else wildcards.dataset + wildcards.sufix + wildcards.chain + ".r{i}"
22+
previous_file_name = lambda wildcards: wildcards.dataset + wildcards.sufix + "." + wildcards.chain + ".r" + str(int(wildcards.i) - 1) if int(wildcards.i) != 0 else wildcards.dataset + wildcards.sufix + wildcards.chain + ".r{i}"
2323
log:
2424
"logs/beast_{analysis}_{dataset}_{sufix,.*}_{chain}.r{i}.txt"
2525
benchmark:
@@ -45,8 +45,8 @@ checkpoint beast:
4545
4646
else
4747
ACTION="resume"
48-
scp {params.folder_name}/{params.previous_file_name}.log {params.folder_name}/running/{params.file_name}.log
49-
scp {params.folder_name}/{params.previous_file_name}.trees {params.folder_name}/running/{params.file_name}.trees
48+
cp {params.folder_name}/{params.previous_file_name}.log {params.folder_name}/running/{params.file_name}.log
49+
cp {params.folder_name}/{params.previous_file_name}.trees {params.folder_name}/running/{params.file_name}.trees
5050
5151
rm {params.folder_name}/{params.previous_file_name}.log {params.folder_name}/{params.previous_file_name}.trees {params.folder_name}/is_converged_{params.previous_file_name}.txt
5252
fi
@@ -90,6 +90,14 @@ def _is_converged(wildcards):
9090
end = "." + wildcards.output
9191
runs = int(trace[0][trace[0].find(start)+len(start):trace[0].find(end)])
9292

93+
ckpt = checkpoints.beast.get(
94+
analysis=wildcards.analysis,
95+
dataset=wildcards.dataset,
96+
sufix=wildcards.sufix,
97+
chain=wildcards.chain,
98+
i=runs
99+
)
100+
93101
with checkpoints.beast.get(analysis = wildcards.analysis, dataset = wildcards.dataset,
94102
sufix = wildcards.sufix,
95103
chain = wildcards.chain, i = runs).output.is_converged.open() as f:
@@ -106,7 +114,8 @@ def _is_converged(wildcards):
106114

107115
rule aggregate_runs:
108116
input:
109-
run = _is_converged
117+
#run = _is_converged
118+
run = "results/analysis/{analysis}/chains/{dataset}{sufix,.*}.{chain}.r0.{output}",
110119
output:
111120
chain = "results/analysis/{analysis}/chains/{dataset}{sufix,.*}.{chain}.{output}",
112121
log:
@@ -149,67 +158,73 @@ def _get_chains(wildcards):
149158
rule combine_chains:
150159
message:
151160
"""
152-
Combine chain files: {input.chain_files} with LogCombiner.
161+
Combine and downsample chain files: {input.chain_files}.
153162
"""
154163
input:
155164
chain_files = _get_chains
156165
output:
157166
combined_chain = "results/analysis/{analysis}/{dataset}{sufix}.{output}",
158-
# log:
159-
# "logs/combine_trace_{dataset}_{analysis}_{subsampling}.{dseed}.txt"
167+
log:
168+
"logs/combine_chain_{analysis}/{dataset}{sufix}.{output}.txt"
160169
# benchmark:
161170
# "benchmarks/combine_trace_{dataset}_{analysis}_{subsampling}.{dseed}.benchmark.txt"
162171
params:
163172
burnin = lambda wildcards: _get_analysis_param(wildcards, "burnin"),
173+
resample = config["logcombiner"].get("resample"),
174+
resample_traj = config["logcombiner"].get("resample_trajs"),
164175
# input_command = lambda wildcards, input: " -log ".join(input)
165176
input_command = lambda wildcards, input: " -log ".join(
166177
f for f in input.chain_files
167-
if os.path.exists(f) and os.path.getsize(f) > 0)
178+
if os.path.exists(f) and os.path.getsize(f) > 0),
179+
input_traj = lambda wildcards, input: ",".join(
180+
f for f in input.chain_files
181+
if os.path.exists(f) and os.path.getsize(f) > 0
182+
)
168183
shell:
169184
"""
170-
if [ "{wildcards.output}" = "traj" ]; then
171-
172-
first=$(echo "{input.chain_files}" | awk '{{print $1}}')
173-
head -n 1 "$first" > {output.combined_chain}
174-
for f in {input.chain_files}; do
175-
tail -n +2 "$f"
176-
done >> {output.combined_chain}
177-
185+
if [ "{wildcards.output}" = "traj" ]; then
186+
Rscript workflow/snakemake-phylo-beast2/workflow/scripts/combine_trajectories.R \
187+
--input {params.input_traj} \
188+
--output {output.combined_chain} \
189+
--burnin {params.burnin} \
190+
--subsample_n {params.resample_traj} \
191+
2>&1 | tee -a {log}
178192
else
179193
180194
logcombiner \
181195
-log {params.input_command} \
182196
-o {output.combined_chain} \
183197
-b {params.burnin} \
198+
-resample {params.resample} \
184199
2>&1 | tee -a {log}
185200
186201
fi
187202
188203
"""
189204

190-
rule downsample_chains:
191-
input:
192-
file = "results/analysis/{analysis}/{dataset}{sufix}.{output}"
193-
output:
194-
downsampled_file = "results/analysis/{analysis}/{dataset}{sufix}.ds.{output}"
195-
log:
196-
"logs/downsample_{output}_{analysis}_{dataset}{sufix}.txt"
197-
params:
198-
command = config["logcombiner"].get("command"),
199-
resample = config["logcombiner"].get("resample"),
200-
burnin = lambda wildcards: _get_analysis_param(wildcards, "burnin"),
201-
shell:
202-
"""
203-
{params.command} -log {input.file} -o {output.downsampled_file} -b {params.burnin} -resample {params.resample} 2>&1 | tee {log}
204-
"""
205+
# rule downsample_chains:
206+
# input:
207+
# file = "results/analysis/{analysis}/{dataset}{sufix}.{output}"
208+
# output:
209+
# downsampled_file = "results/analysis/{analysis}/{dataset}{sufix}.ds.{output}"
210+
# log:
211+
# "logs/downsample_{output}_{analysis}_{dataset}{sufix}.txt"
212+
# params:
213+
# command = config["logcombiner"].get("command"),
214+
# resample = config["logcombiner"].get("resample"),
215+
# burnin = lambda wildcards: _get_analysis_param(wildcards, "burnin"),
216+
# shell:
217+
# """
218+
# {params.command} -log {input.file} -o {output.downsampled_file} -b {params.burnin} -resample {params.resample} 2>&1 | tee {log}
219+
# """
205220

206221
rule summarize_trees:
207222
message:
208223
"""
209224
Summarize trees to {wildcards.topo} tree with median node heights with TreeAnnotator.
210225
"""
211226
input:
212-
trees = "results/analysis/{analysis}/{dataset}{sufix}.ds.trees"
227+
trees = "results/analysis/{analysis}/{dataset}{sufix}.trees"
213228
output:
214229
summary_tree = "results/analysis/{analysis}/{dataset}{sufix}.{topo}.tree"
215230
log:
@@ -235,8 +250,8 @@ rule stochastic_mapping:
235250
input:
236251
alignment = lambda wildcards: _get_alignment(wildcards),
237252
sm_xml = lambda wildcards: _get_analysis_param(wildcards, "sm_xml"),
238-
trace = "results/analysis/{analysis}/{dataset}{sufix}.ds.log",
239-
trees = "results/analysis/{analysis}/{dataset}{sufix}.ds.trees",
253+
trace = "results/analysis/{analysis}/{dataset}{sufix}.log",
254+
trees = "results/analysis/{analysis}/{dataset}{sufix}.trees",
240255
output:
241256
typed_trees = "results/analysis/{analysis}/{dataset}{sufix}.typed.trees",
242257
typed_node_trees = "results/analysis/{analysis}/{dataset}{sufix}.typed.node.trees",
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
# ------------------------------------------------------------------------------
2+
# ---
3+
# / o o \ Project: snakemake-phylo-beast2
4+
# V\ Y /V Combine BDMM-Prime trajectory files
5+
# (\ / - \
6+
# )) / |
7+
# ((/__) || Code by Ceci VA
8+
# ------------------------------------------------------------------------------
9+
10+
11+
# 0. Libraries -----------------------------------------------------------------
12+
library(tidyverse)
13+
library(data.table)
14+
library(optparse)
15+
16+
source("workflow/scripts/utils.R")
17+
18+
#' Combine trajectory files into one long data.table
#'
#' Each file is read with `fread()`, its burn-in is discarded, and its `Sample`
#' ids are prefixed with the file's position so they stay unique after the
#' files are stacked. Rows with a missing `variable` are dropped. Optionally a
#' random subset of the remaining unique samples is kept.
#'
#' @param traj_files Character vector (or list) of trajectory file paths.
#' @param burnin_percentage Percentage (0-100) of each file's largest `Sample`
#'   value to discard as burn-in. Note: a percentage, not a fraction.
#' @param subsample_n Optional number of unique samples to keep across all
#'   files. `NULL` (or `NA`) keeps everything; values larger than the number
#'   of available samples are capped with a message.
#' @return A data.table with columns `Sample` (character, formatted as
#'   "<file index>_<original Sample>"), `type`, `type2`, `variable`, `value`
#'   and `age`.
combine_trajectories_dt <- function(traj_files, burnin_percentage = 10, subsample_n = NULL) {
  traj_cols <- c("Sample", "type", "type2", "variable", "value", "age")

  # seq_along() instead of 1:length(): an empty input must yield zero
  # iterations, not the indices c(1, 0).
  dt_l <- lapply(seq_along(traj_files), function(i) {
    dt <- fread(
      traj_files[[i]],
      fill = TRUE,
      select = traj_cols,
      showProgress = TRUE
    )

    if (nrow(dt) == 0L) {
      # Nothing to filter; just keep the Sample column type consistent.
      dt[, Sample := as.character(Sample)]
      return(dt)
    }

    # na.rm = TRUE: with fill = TRUE a ragged row may carry an NA Sample, and
    # a single NA would otherwise turn the threshold into NA and drop all rows.
    burnin_from <- max(dt$Sample, na.rm = TRUE) * (burnin_percentage / 100)
    dt <- dt[Sample >= burnin_from]

    # Make Sample unique across files.
    dt[, Sample := as.character(paste0(i, "_", Sample))]

    dt[!is.na(variable)]
  })

  dt_combined <- rbindlist(dt_l, use.names = TRUE)

  # NA is treated like NULL so an unparsable CLI value (presumably what the
  # option parser yields for e.g. "None" - TODO confirm) means "keep all"
  # rather than crashing in the comparison below.
  if (!is.null(subsample_n) && !is.na(subsample_n)) {
    available_samples <- unique(dt_combined$Sample)
    n_available <- length(available_samples)

    if (subsample_n > n_available) {
      message(
        "Requested subsample_n = ", subsample_n,
        " but only ", n_available,
        " unique trajectories are available after burnin. Keeping all trajectories."
      )
      subsample_n <- n_available
    }

    kept_samples <- sample(available_samples, subsample_n)
    dt_combined <- dt_combined[Sample %in% kept_samples]
  }

  dt_combined
}
56+
57+
58+
# Command-line interface.
# --input is a single comma-separated string (the script splits it on ",").
# --burnin is a percentage (0-100): it is forwarded as `burnin_percentage`,
# which combine_trajectories_dt() divides by 100. The default therefore
# matches that function's default of 10.
option_list <- list(
  make_option(
    c("--input"),
    type = "character",
    action = "store",
    help = "Input .traj files as a single comma-separated string",
    metavar = "files",
    dest = "input"
  ),
  make_option(
    c("--output"),
    type = "character",
    action = "store",
    help = "Output combined .traj file",
    metavar = "file",
    dest = "output"
  ),
  make_option(
    c("--burnin"),
    type = "double",
    default = 10,
    help = "Burn-in percentage (0-100) [default %default]",
    dest = "burnin"
  ),
  make_option(
    c("--subsample_n"),
    type = "integer",
    default = NULL,
    help = "Optional number of posterior samples to keep",
    dest = "subsample_n"
  )
)
90+
91+
# Parse the command line; --input and --output are mandatory.
parser <- OptionParser(option_list = option_list)
opt <- parse_args(parser)

if (is.null(opt$input) || is.null(opt$output)) {
  print_help(parser)
  stop("Both --input and --output must be provided.", call. = FALSE)
}

# --input is one comma-separated string. Tolerate stray whitespace and empty
# entries (e.g. a trailing comma) instead of passing bogus paths to fread().
traj_files <- trimws(strsplit(opt$input, ",", fixed = TRUE)[[1]])
traj_files <- traj_files[nzchar(traj_files)]

# Fail early with a clear message rather than deep inside the reader.
if (length(traj_files) == 0L) {
  stop("--input did not contain any trajectory file paths.", call. = FALSE)
}

missing_files <- traj_files[!file.exists(traj_files)]
if (length(missing_files) > 0L) {
  stop(
    "Trajectory file(s) not found: ", paste(missing_files, collapse = ", "),
    call. = FALSE
  )
}

message("Combining ", length(traj_files), " trajectory files.")
message("Burn-in percentage: ", opt$burnin)

dt_combined <- combine_trajectories_dt(
  traj_files = traj_files,
  burnin_percentage = opt$burnin,
  subsample_n = opt$subsample_n
)

# Tab-separated output.
fwrite(dt_combined, file = opt$output, sep = "\t")
message("Wrote combined trajectories to: ", opt$output)
112+

0 commit comments

Comments
 (0)