2121import bdmmprime .parameterization .Parameterization ;
2222import bdmmprime .util .Utils ;
2323import beast .base .core .Citation ;
24- import beast .base .core .Function ;
2524import beast .base .core .Input ;
2625import beast .base .evolution .tree .Node ;
2726import beast .base .evolution .tree .TraitSet ;
2827import beast .base .evolution .tree .Tree ;
29- import beast .base .inference .parameter .RealParameter ;
28+ import beast .base .spec .domain .NonNegativeReal ;
29+ import beast .base .spec .inference .parameter .RealScalarParam ;
30+ import beast .base .spec .type .RealScalar ;
31+ import beast .base .spec .type .Simplex ;
3032import beast .base .util .Randomizer ;
3133import org .apache .commons .math3 .ode .ContinuousOutputModel ;
3234import org .apache .commons .math3 .ode .FirstOrderIntegrator ;
@@ -72,13 +74,13 @@ public class TypeMappedTree extends Tree {
7274 "BDMM parameterization" ,
7375 Input .Validate .REQUIRED );
7476
75- public Input <RealParameter > startTypePriorProbsInput = new Input <>("startTypePriorProbs" ,
77+ public Input <Simplex > startTypePriorProbsInput = new Input <>("startTypePriorProbs" ,
7678 "The prior probabilities for the initial individual type" ,
7779 Input .Validate .REQUIRED );
7880
79- public Input <Function > finalSampleOffsetInput = new Input <>("finalSampleOffset" ,
81+ public Input <RealScalar <? extends NonNegativeReal > > finalSampleOffsetInput = new Input <>("finalSampleOffset" ,
8082 "If provided, the difference in time between the final sample and the end of the BD process." ,
81- new RealParameter ( " 0.0" ));
83+ new RealScalarParam <>( 0.0 , NonNegativeReal . INSTANCE ));
8284
8385 public Input <TraitSet > typeTraitSetInput = new Input <>("typeTraitSet" ,
8486 "Trait information for initializing traits " +
@@ -107,7 +109,7 @@ public class TypeMappedTree extends Tree {
107109 Input .Validate .XOR , parameterizationInput );
108110
109111 private Parameterization param ;
110- private Function finalSampleOffset ;
112+ private RealScalar <? extends NonNegativeReal > finalSampleOffset ;
111113 private Tree untypedTree ;
112114
113115 private ODESystem odeSystem ;
@@ -183,7 +185,7 @@ private void doStochasticMapping() {
183185 double [] startTypeProbs = new double [param .getNTypes ()];
184186
185187 for (int type =0 ; type <param .getNTypes (); type ++)
186- startTypeProbs [type ] = y [type +param .getNTypes ()]* startTypePriorProbsInput .get ().getValue (type );
188+ startTypeProbs [type ] = y [type +param .getNTypes ()]* startTypePriorProbsInput .get ().get (type );
187189
188190 // (startTypeProbs are unnormalized: this is okay for randomChoicePDF.)
189191 int startType = Randomizer .randomChoicePDF (startTypeProbs );
@@ -263,7 +265,7 @@ private void computeRhoSampledLeafStatus() {
263265 rhoSamplingIndex = new int [untypedTree .getLeafNodeCount ()];
264266
265267 for (int nodeNr =0 ; nodeNr < treeInput .get ().getLeafNodeCount (); nodeNr ++) {
266- double nodeTime = param .getNodeTime (untypedTree .getNode (nodeNr ), finalSampleOffset .getArrayValue ());
268+ double nodeTime = param .getNodeTime (untypedTree .getNode (nodeNr ), finalSampleOffset .get ());
267269 rhoSampled [nodeNr ] = false ;
268270 for (double rhoSamplingTime : param .getRhoSamplingTimes ()) {
269271 if (Utils .equalWithPrecision (nodeTime , rhoSamplingTime )) {
@@ -360,7 +362,7 @@ private NodeKind getNodeKind(Node node) {
360362
361363 double delta = 2 *Utils .globalPrecisionThreshold ;
362364
363- double timeOfSubtreeRootEdgeBottom = param .getNodeTime (untypedSubtreeRoot , finalSampleOffset .getArrayValue ());
365+ double timeOfSubtreeRootEdgeBottom = param .getNodeTime (untypedSubtreeRoot , finalSampleOffset .get ());
364366
365367 odeIntegrator .addEventHandler (odeSystem ,
366368 (timeOfSubtreeRootEdgeBottom -timeOfSubtreeRootEdgeTop )/RATE_CHANGE_CHECKS_PER_EDGE ,
@@ -388,7 +390,7 @@ private double[] getLeafState(Node leafNode) {
388390 y [param .getNTypes ()+type ] = 0.0 ;
389391 }
390392
391- double leafTime = param .getNodeTime (leafNode , finalSampleOffset .getArrayValue ());
393+ double leafTime = param .getNodeTime (leafNode , finalSampleOffset .get ());
392394 double T = param .getTotalProcessLength ();
393395
394396 if (Utils .lessThanWithPrecision (leafTime , T )) {
@@ -439,7 +441,7 @@ private double[] getLeafState(Node leafNode) {
439441
440442 } else {
441443
442- int nodeInterval = param .getNodeIntervalIndex (leafNode , finalSampleOffset .getArrayValue ());
444+ int nodeInterval = param .getNodeIntervalIndex (leafNode , finalSampleOffset .get ());
443445
444446 if (leafTypeKnown ) {
445447 for (int type = 0 ; type < param .getNTypes (); type ++) {
@@ -473,7 +475,7 @@ private double[] getLeafState(Node leafNode) {
473475
474476 private double [] getSAState (Node saNode ) {
475477
476- double saNodeTime = param .getNodeTime (saNode , finalSampleOffset .getArrayValue ());
478+ double saNodeTime = param .getNodeTime (saNode , finalSampleOffset .get ());
477479
478480 double [] y = backwardsIntegrateSubtree (saNode .getNonDirectAncestorChild (), saNodeTime );
479481
@@ -510,7 +512,7 @@ private double[] getSAState(Node saNode) {
510512
511513 } else {
512514
513- int nodeInterval = param .getNodeIntervalIndex (saNode , finalSampleOffset .getArrayValue ());
515+ int nodeInterval = param .getNodeIntervalIndex (saNode , finalSampleOffset .get ());
514516
515517 if (saTypeKnown ) {
516518 for (int type = 0 ; type < param .getNTypes (); type ++) {
@@ -543,7 +545,7 @@ private double[] getSAState(Node saNode) {
543545
544546 private double [] getInternalState (Node internalNode ) {
545547
546- double internalNodeTime = param .getNodeTime (internalNode , finalSampleOffset .getArrayValue ());
548+ double internalNodeTime = param .getNodeTime (internalNode , finalSampleOffset .get ());
547549
548550 double [] yLeft = backwardsIntegrateSubtree (internalNode .getChild (0 ), internalNodeTime );
549551 double [] yRight = backwardsIntegrateSubtree (internalNode .getChild (1 ), internalNodeTime );
@@ -553,7 +555,7 @@ private double[] getInternalState(Node internalNode) {
553555
554556 double [] y = new double [param .getNTypes ()*2 ];
555557
556- int nodeInterval = param .getNodeIntervalIndex (internalNode , finalSampleOffset .getArrayValue ());
558+ int nodeInterval = param .getNodeIntervalIndex (internalNode , finalSampleOffset .get ());
557559
558560 int N = param .getNTypes ();
559561
@@ -618,7 +620,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
618620 int currentType = startType ;
619621 double currentTime = startTime ;
620622
621- double endTime = param .getNodeTime (subtreeRoot , finalSampleOffset .getArrayValue ());
623+ double endTime = param .getNodeTime (subtreeRoot , finalSampleOffset .get ());
622624
623625 double [] rates = new double [param .getNTypes ()];
624626 double [] ratesPrime = new double [param .getNTypes ()];
@@ -672,7 +674,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
672674 }
673675 }
674676
675- currentNode .setHeight (param .getNodeAge (currentTime , finalSampleOffset .getArrayValue ()));
677+ currentNode .setHeight (param .getNodeAge (currentTime , finalSampleOffset .get ()));
676678
677679 // Sample event type
678680
@@ -687,7 +689,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
687689 currentNode = newNode ;
688690 }
689691
690- currentNode .setHeight (param .getNodeAge (endTime , finalSampleOffset .getArrayValue ()));
692+ currentNode .setHeight (param .getNodeAge (endTime , finalSampleOffset .get ()));
691693
692694 switch (getNodeKind (subtreeRoot )) {
693695 case LEAF :
@@ -727,7 +729,7 @@ private Node forwardSimulateSubtree(Node subtreeRoot, double startTime, int star
727729 * @return backward-time integration result at this point on the tree
728730 */
729731 private double [] getBackwardsIntegrationResult (Node node , double time ) {
730- double parentTime = node .isRoot () ? 0.0 : param .getNodeTime (node .getParent (), finalSampleOffset .getArrayValue ());
732+ double parentTime = node .isRoot () ? 0.0 : param .getNodeTime (node .getParent (), finalSampleOffset .get ());
731733 double adjustedTime = Math .max (time , parentTime + 2 *Utils .globalPrecisionThreshold );
732734
733735 ContinuousOutputModel com = integrationResults [node .getNr ()];
@@ -746,7 +748,7 @@ private double[] getBackwardsIntegrationResult(Node node, double time) {
746748
747749 private int [] sampleChildTypes (Node node , int parentType ) {
748750
749- double t = param .getNodeTime (node , finalSampleOffset .getArrayValue ());
751+ double t = param .getNodeTime (node , finalSampleOffset .get ());
750752 int interval = param .getIntervalIndex (t );
751753
752754 double [] y1 = getBackwardsIntegrationResult (node .getChild (0 ), t );
0 commit comments