Skip to content

Commit f0e837e

Browse files
committed
Further propagation of fso and startTypePriorProbs type change
1 parent c047661 commit f0e837e

2 files changed

Lines changed: 34 additions & 29 deletions

File tree

src/main/java/bdmmprime/mapping/TypeMappedTree.java

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,14 @@
2121
import bdmmprime.parameterization.Parameterization;
2222
import bdmmprime.util.Utils;
2323
import beast.base.core.Citation;
24-
import beast.base.core.Function;
2524
import beast.base.core.Input;
2625
import beast.base.evolution.tree.Node;
2726
import beast.base.evolution.tree.TraitSet;
2827
import beast.base.evolution.tree.Tree;
29-
import beast.base.inference.parameter.RealParameter;
28+
import beast.base.spec.domain.NonNegativeReal;
29+
import beast.base.spec.inference.parameter.RealScalarParam;
30+
import beast.base.spec.type.RealScalar;
31+
import beast.base.spec.type.Simplex;
3032
import beast.base.util.Randomizer;
3133
import org.apache.commons.math3.ode.ContinuousOutputModel;
3234
import org.apache.commons.math3.ode.FirstOrderIntegrator;
@@ -72,13 +74,13 @@ public class TypeMappedTree extends Tree {
7274
"BDMM parameterization",
7375
Input.Validate.REQUIRED);
7476

75-
public Input<RealParameter> startTypePriorProbsInput = new Input<>("startTypePriorProbs",
77+
public Input<Simplex> startTypePriorProbsInput = new Input<>("startTypePriorProbs",
7678
"The prior probabilities for the initial individual type",
7779
Input.Validate.REQUIRED);
7880

79-
public Input<Function> finalSampleOffsetInput = new Input<>("finalSampleOffset",
81+
public Input<RealScalar<? extends NonNegativeReal>> finalSampleOffsetInput = new Input<>("finalSampleOffset",
8082
"If provided, the difference in time between the final sample and the end of the BD process.",
81-
new RealParameter("0.0"));
83+
new RealScalarParam<>(0.0, NonNegativeReal.INSTANCE));
8284

8385
public Input<TraitSet> typeTraitSetInput = new Input<>("typeTraitSet",
8486
"Trait information for initializing traits " +
@@ -107,7 +109,7 @@ public class TypeMappedTree extends Tree {
107109
Input.Validate.XOR, parameterizationInput);
108110

109111
private Parameterization param;
110-
private Function finalSampleOffset;
112+
private RealScalar<? extends NonNegativeReal> finalSampleOffset;
111113
private Tree untypedTree;
112114

113115
private ODESystem odeSystem;
@@ -183,7 +185,7 @@ private void doStochasticMapping() {
183185
double[] startTypeProbs = new double[param.getNTypes()];
184186

185187
for (int type=0; type<param.getNTypes(); type++)
186-
startTypeProbs[type] = y[type+param.getNTypes()]* startTypePriorProbsInput.get().getValue(type);
188+
startTypeProbs[type] = y[type+param.getNTypes()]* startTypePriorProbsInput.get().get(type);
187189

188190
// (startTypeProbs are unnormalized: this is okay for randomChoicePDF.)
189191
int startType = Randomizer.randomChoicePDF(startTypeProbs);
@@ -263,7 +265,7 @@ private void computeRhoSampledLeafStatus() {
263265
rhoSamplingIndex = new int[untypedTree.getLeafNodeCount()];
264266

265267
for (int nodeNr=0; nodeNr < treeInput.get().getLeafNodeCount(); nodeNr++) {
266-
double nodeTime = param.getNodeTime(untypedTree.getNode(nodeNr), finalSampleOffset.getArrayValue());
268+
double nodeTime = param.getNodeTime(untypedTree.getNode(nodeNr), finalSampleOffset.get());
267269
rhoSampled[nodeNr] = false;
268270
for (double rhoSamplingTime : param.getRhoSamplingTimes()) {
269271
if (Utils.equalWithPrecision(nodeTime, rhoSamplingTime)) {
@@ -360,7 +362,7 @@ private NodeKind getNodeKind(Node node) {
360362

361363
double delta = 2*Utils.globalPrecisionThreshold;
362364

363-
double timeOfSubtreeRootEdgeBottom = param.getNodeTime(untypedSubtreeRoot, finalSampleOffset.getArrayValue());
365+
double timeOfSubtreeRootEdgeBottom = param.getNodeTime(untypedSubtreeRoot, finalSampleOffset.get());
364366

365367
odeIntegrator.addEventHandler(odeSystem,
366368
(timeOfSubtreeRootEdgeBottom-timeOfSubtreeRootEdgeTop)/RATE_CHANGE_CHECKS_PER_EDGE,
@@ -388,7 +390,7 @@ private double[] getLeafState(Node leafNode) {
388390
y[param.getNTypes()+type] = 0.0;
389391
}
390392

391-
double leafTime = param.getNodeTime(leafNode, finalSampleOffset.getArrayValue());
393+
double leafTime = param.getNodeTime(leafNode, finalSampleOffset.get());
392394
double T = param.getTotalProcessLength();
393395

394396
if (Utils.lessThanWithPrecision(leafTime, T)) {
@@ -439,7 +441,7 @@ private double[] getLeafState(Node leafNode) {
439441

440442
} else {
441443

442-
int nodeInterval = param.getNodeIntervalIndex(leafNode, finalSampleOffset.getArrayValue());
444+
int nodeInterval = param.getNodeIntervalIndex(leafNode, finalSampleOffset.get());
443445

444446
if (leafTypeKnown) {
445447
for (int type = 0; type < param.getNTypes(); type++) {
@@ -473,7 +475,7 @@ private double[] getLeafState(Node leafNode) {
473475

474476
private double[] getSAState(Node saNode) {
475477

476-
double saNodeTime = param.getNodeTime(saNode, finalSampleOffset.getArrayValue());
478+
double saNodeTime = param.getNodeTime(saNode, finalSampleOffset.get());
477479

478480
double[] y = backwardsIntegrateSubtree(saNode.getNonDirectAncestorChild(), saNodeTime);
479481

@@ -510,7 +512,7 @@ private double[] getSAState(Node saNode) {
510512

511513
} else {
512514

513-
int nodeInterval = param.getNodeIntervalIndex(saNode, finalSampleOffset.getArrayValue());
515+
int nodeInterval = param.getNodeIntervalIndex(saNode, finalSampleOffset.get());
514516

515517
if (saTypeKnown) {
516518
for (int type = 0; type < param.getNTypes(); type++) {
@@ -543,7 +545,7 @@ private double[] getSAState(Node saNode) {
543545

544546
private double[] getInternalState(Node internalNode) {
545547

546-
double internalNodeTime = param.getNodeTime(internalNode, finalSampleOffset.getArrayValue());
548+
double internalNodeTime = param.getNodeTime(internalNode, finalSampleOffset.get());
547549

548550
double[] yLeft = backwardsIntegrateSubtree(internalNode.getChild(0), internalNodeTime);
549551
double[] yRight = backwardsIntegrateSubtree(internalNode.getChild(1), internalNodeTime);
@@ -553,7 +555,7 @@ private double[] getInternalState(Node internalNode) {
553555

554556
double[] y = new double[param.getNTypes()*2];
555557

556-
int nodeInterval = param.getNodeIntervalIndex(internalNode, finalSampleOffset.getArrayValue());
558+
int nodeInterval = param.getNodeIntervalIndex(internalNode, finalSampleOffset.get());
557559

558560
int N = param.getNTypes();
559561

@@ -618,7 +620,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
618620
int currentType = startType;
619621
double currentTime = startTime;
620622

621-
double endTime = param.getNodeTime(subtreeRoot, finalSampleOffset.getArrayValue());
623+
double endTime = param.getNodeTime(subtreeRoot, finalSampleOffset.get());
622624

623625
double[] rates = new double[param.getNTypes()];
624626
double[] ratesPrime = new double[param.getNTypes()];
@@ -672,7 +674,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
672674
}
673675
}
674676

675-
currentNode.setHeight(param.getNodeAge(currentTime, finalSampleOffset.getArrayValue()));
677+
currentNode.setHeight(param.getNodeAge(currentTime, finalSampleOffset.get()));
676678

677679
// Sample event type
678680

@@ -687,7 +689,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
687689
currentNode = newNode;
688690
}
689691

690-
currentNode.setHeight(param.getNodeAge(endTime, finalSampleOffset.getArrayValue()));
692+
currentNode.setHeight(param.getNodeAge(endTime, finalSampleOffset.get()));
691693

692694
switch(getNodeKind(subtreeRoot)) {
693695
case LEAF:
@@ -727,7 +729,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
727729
* @return backward-time integration result at this point on the tree
728730
*/
729731
private double[] getBackwardsIntegrationResult(Node node, double time) {
730-
double parentTime = node.isRoot() ? 0.0 : param.getNodeTime(node.getParent(), finalSampleOffset.getArrayValue());
732+
double parentTime = node.isRoot() ? 0.0 : param.getNodeTime(node.getParent(), finalSampleOffset.get());
731733
double adjustedTime = Math.max(time, parentTime + 2*Utils.globalPrecisionThreshold);
732734

733735
ContinuousOutputModel com = integrationResults[node.getNr()];
@@ -746,7 +748,7 @@ private double[] getBackwardsIntegrationResult(Node node, double time) {
746748

747749
private int[] sampleChildTypes(Node node, int parentType) {
748750

749-
double t = param.getNodeTime(node, finalSampleOffset.getArrayValue());
751+
double t = param.getNodeTime(node, finalSampleOffset.get());
750752
int interval = param.getIntervalIndex(t);
751753

752754
double[] y1 = getBackwardsIntegrationResult(node.getChild(0), t);

src/main/java/bdmmprime/trajectories/SampledTrajectory.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424
import bdmmprime.util.Utils;
2525
import beast.base.core.*;
2626
import beast.base.inference.CalculationNode;
27-
import beast.base.inference.parameter.RealParameter;
2827
import beast.base.evolution.tree.Node;
2928
import beast.base.evolution.tree.Tree;
29+
import beast.base.spec.domain.NonNegativeReal;
30+
import beast.base.spec.inference.parameter.RealScalarParam;
31+
import beast.base.spec.type.RealScalar;
32+
import beast.base.spec.type.Simplex;
3033

3134
import java.io.PrintStream;
3235
import java.util.ArrayList;
@@ -48,11 +51,11 @@ public class SampledTrajectory extends CalculationNode implements Loggable {
4851
public Input<Parameterization> parameterizationInput = new Input<>("parameterization",
4952
"Multi-type birth-death parameterization.", Input.Validate.REQUIRED);
5053

51-
public Input<Function> finalSampleOffsetInput = new Input<>("finalSampleOffset",
54+
public Input<RealScalar<? extends NonNegativeReal>> finalSampleOffsetInput = new Input<>("finalSampleOffset",
5255
"If provided, the difference in time between the final sample and the end of the BD process.",
53-
new RealParameter("0.0"));
56+
new RealScalarParam<>(0.0, NonNegativeReal.INSTANCE));
5457

55-
public Input<Function> startTypePriorProbsInput = new Input<>("startTypePriorProbs",
58+
public Input<Simplex> startTypePriorProbsInput = new Input<>("startTypePriorProbs",
5659
"The prior probabilities for the starting type. Only needed for testing tree prob estimates.");
5760

5861
public Input<Integer> nParticlesInput = new Input<>("nParticles",
@@ -86,7 +89,7 @@ public class SampledTrajectory extends CalculationNode implements Loggable {
8689
Tree mappedTree;
8790
String typeLabel;
8891
Parameterization param;
89-
Function finalSampleOffset;
92+
RealScalar<? extends NonNegativeReal> finalSampleOffset;
9093

9194
int nTypes, nParticles;
9295
double resampThresh;
@@ -261,7 +264,7 @@ public double getLogTreeProbEstimate() {
261264
// distribution having the observed state. However, because _all_ of the
262265
// particles have this initial value, it doesn't affect the weight distribution.
263266
// It affects the tree prob estimate though, which is important for testing.
264-
logTreeProbEstimate += Math.log(startTypePriorProbsInput.get().getArrayValue(rootType));
267+
logTreeProbEstimate += Math.log(startTypePriorProbsInput.get().get(rootType));
265268
}
266269

267270
return logTreeProbEstimate - logGamma(mappedTree.getLeafNodeCount() + 1);
@@ -294,7 +297,7 @@ List<ObservedEvent> getObservedEventList(Tree tree) {
294297
List<ObservedEvent> eventList = new ArrayList<>();
295298
ObservedSamplingEvent[] thisSamplingEvent = new ObservedSamplingEvent[param.getNTypes()];
296299
for (Node node : sampleNodes) {
297-
double t = param.getNodeTime(node, finalSampleOffset.getArrayValue());
300+
double t = param.getNodeTime(node, finalSampleOffset.get());
298301
int type = getNodeType(node, typeLabel);
299302

300303
if (thisSamplingEvent[type] == null || !Utils.equalWithPrecision(t,thisSamplingEvent[type].time)) {
@@ -316,13 +319,13 @@ List<ObservedEvent> getObservedEventList(Tree tree) {
316319
if (node.getChildCount() == 1) {
317320
// Observed type change
318321

319-
eventList.add(new TypeChangeEvent(param.getNodeTime(node, finalSampleOffset.getArrayValue()),
322+
eventList.add(new TypeChangeEvent(param.getNodeTime(node, finalSampleOffset.get()),
320323
getNodeType(node, typeLabel),
321324
getNodeType(node.getChild(0), typeLabel),1));
322325
} else {
323326
// Coalescence
324327

325-
eventList.add(new CoalescenceEvent(param.getNodeTime(node, finalSampleOffset.getArrayValue()),
328+
eventList.add(new CoalescenceEvent(param.getNodeTime(node, finalSampleOffset.get()),
326329
getNodeType(node, typeLabel),
327330
getNodeType(node.getChild(0), typeLabel),
328331
getNodeType(node.getChild(1), typeLabel), 1));

0 commit comments

Comments
 (0)